From 6f287915d8e4c2c09e7db2eb5cb670036d33f478 Mon Sep 17 00:00:00 2001 From: ZiTian Zhao Date: Fri, 8 Aug 2025 11:18:50 +0800 Subject: [PATCH] Optimize MiniCPMO mask creation with vectorized implementation (#22464) Signed-off-by: zitian.zhao Signed-off-by: zitian zhao --- vllm/model_executor/models/minicpmo.py | 32 ++++++++++++++++++-------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 4e4fc3d5c7621..fd91c7fcc12bd 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -587,15 +587,29 @@ class MiniCPMO(MiniCPMV2_6): num_lookhead: int = 0, ) -> torch.Tensor: ret = torch.zeros(size, size, device=device, dtype=torch.bool) - for i in range(size): - if num_left_chunks < 0: - start = 0 - else: - start = max((i // chunk_size - num_left_chunks) * chunk_size, - 0) - ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, - size) - ret[i, start:ending] = True + # Vectorized computation of row indices and chunk boundaries + row_indices = torch.arange(size, device=device) + chunk_indices = row_indices // chunk_size + if num_left_chunks < 0: + # If num_left_chunks < 0, start is always 0 for all rows + start_indices = torch.zeros_like(row_indices) + else: + # Compute start indices vectorially + start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, + min=0) + start_indices = start_chunk_indices * chunk_size + # Compute ending indices vectorially + end_chunk_indices = chunk_indices + 1 + end_indices = torch.clamp(end_chunk_indices * chunk_size + + num_lookhead, + max=size) + # Create column indices for broadcasting + col_indices = torch.arange(size, device=device).unsqueeze(0) + row_indices = row_indices.unsqueeze(1) + start_indices = start_indices.unsqueeze(1) + end_indices = end_indices.unsqueeze(1) + # Vectorized mask creation + ret = (col_indices >= start_indices) & (col_indices < end_indices) return ret def _get_feat_extract_output_lengths(self,