mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 10:30:37 +08:00
Optimize MiniCPMO mask creation with vectorized implementation (#22464)
Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com>
This commit is contained in:
parent
c152e2a8a0
commit
6f287915d8
@@ -587,15 +587,29 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
num_lookhead: int = 0,
|
||||
) -> torch.Tensor:
|
||||
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
||||
for i in range(size):
|
||||
if num_left_chunks < 0:
|
||||
start = 0
|
||||
else:
|
||||
start = max((i // chunk_size - num_left_chunks) * chunk_size,
|
||||
0)
|
||||
ending = min((i // chunk_size + 1) * chunk_size + num_lookhead,
|
||||
size)
|
||||
ret[i, start:ending] = True
|
||||
# Vectorized computation of row indices and chunk boundaries
|
||||
row_indices = torch.arange(size, device=device)
|
||||
chunk_indices = row_indices // chunk_size
|
||||
if num_left_chunks < 0:
|
||||
# If num_left_chunks < 0, start is always 0 for all rows
|
||||
start_indices = torch.zeros_like(row_indices)
|
||||
else:
|
||||
# Compute start indices vectorially
|
||||
start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks,
|
||||
min=0)
|
||||
start_indices = start_chunk_indices * chunk_size
|
||||
# Compute ending indices vectorially
|
||||
end_chunk_indices = chunk_indices + 1
|
||||
end_indices = torch.clamp(end_chunk_indices * chunk_size +
|
||||
num_lookhead,
|
||||
max=size)
|
||||
# Create column indices for broadcasting
|
||||
col_indices = torch.arange(size, device=device).unsqueeze(0)
|
||||
row_indices = row_indices.unsqueeze(1)
|
||||
start_indices = start_indices.unsqueeze(1)
|
||||
end_indices = end_indices.unsqueeze(1)
|
||||
# Vectorized mask creation
|
||||
ret = (col_indices >= start_indices) & (col_indices < end_indices)
|
||||
return ret
|
||||
|
||||
def _get_feat_extract_output_lengths(self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user