mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 18:35:58 +08:00
Optimize MiniCPMO mask creation with vectorized implementation (#22464)
Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com>
This commit is contained in:
parent
c152e2a8a0
commit
6f287915d8
@ -587,15 +587,29 @@ class MiniCPMO(MiniCPMV2_6):
|
|||||||
num_lookhead: int = 0,
|
num_lookhead: int = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
||||||
for i in range(size):
|
# Vectorized computation of row indices and chunk boundaries
|
||||||
if num_left_chunks < 0:
|
row_indices = torch.arange(size, device=device)
|
||||||
start = 0
|
chunk_indices = row_indices // chunk_size
|
||||||
else:
|
if num_left_chunks < 0:
|
||||||
start = max((i // chunk_size - num_left_chunks) * chunk_size,
|
# If num_left_chunks < 0, start is always 0 for all rows
|
||||||
0)
|
start_indices = torch.zeros_like(row_indices)
|
||||||
ending = min((i // chunk_size + 1) * chunk_size + num_lookhead,
|
else:
|
||||||
size)
|
# Compute start indices vectorially
|
||||||
ret[i, start:ending] = True
|
start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks,
|
||||||
|
min=0)
|
||||||
|
start_indices = start_chunk_indices * chunk_size
|
||||||
|
# Compute ending indices vectorially
|
||||||
|
end_chunk_indices = chunk_indices + 1
|
||||||
|
end_indices = torch.clamp(end_chunk_indices * chunk_size +
|
||||||
|
num_lookhead,
|
||||||
|
max=size)
|
||||||
|
# Create column indices for broadcasting
|
||||||
|
col_indices = torch.arange(size, device=device).unsqueeze(0)
|
||||||
|
row_indices = row_indices.unsqueeze(1)
|
||||||
|
start_indices = start_indices.unsqueeze(1)
|
||||||
|
end_indices = end_indices.unsqueeze(1)
|
||||||
|
# Vectorized mask creation
|
||||||
|
ret = (col_indices >= start_indices) & (col_indices < end_indices)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def _get_feat_extract_output_lengths(self,
|
def _get_feat_extract_output_lengths(self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user