Optimize MiniCPMO mask creation with vectorized implementation (#22464)

Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com>
This commit is contained in:
ZiTian Zhao 2025-08-08 11:18:50 +08:00 committed by GitHub
parent c152e2a8a0
commit 6f287915d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -587,15 +587,29 @@ class MiniCPMO(MiniCPMV2_6):
num_lookhead: int = 0,
) -> torch.Tensor:
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size,
0)
ending = min((i // chunk_size + 1) * chunk_size + num_lookhead,
size)
ret[i, start:ending] = True
# Vectorized computation of row indices and chunk boundaries
row_indices = torch.arange(size, device=device)
chunk_indices = row_indices // chunk_size
if num_left_chunks < 0:
# If num_left_chunks < 0, start is always 0 for all rows
start_indices = torch.zeros_like(row_indices)
else:
# Compute start indices vectorially
start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks,
min=0)
start_indices = start_chunk_indices * chunk_size
# Compute ending indices vectorially
end_chunk_indices = chunk_indices + 1
end_indices = torch.clamp(end_chunk_indices * chunk_size +
num_lookhead,
max=size)
# Create column indices for broadcasting
col_indices = torch.arange(size, device=device).unsqueeze(0)
row_indices = row_indices.unsqueeze(1)
start_indices = start_indices.unsqueeze(1)
end_indices = end_indices.unsqueeze(1)
# Vectorized mask creation
ret = (col_indices >= start_indices) & (col_indices < end_indices)
return ret
def _get_feat_extract_output_lengths(self,