diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index adca97c71c581..1e6ff1fec6d5c 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -608,6 +608,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
         window_index = torch.cat(window_index, dim=0)
         return window_index, cu_window_seqlens
 
+    def compute_attn_mask_seqlen(
+        self,
+        cu_seqlens: torch.Tensor,
+    ) -> tuple[Optional[int], Optional[list[int]]]:
+        max_seqlen, seqlens = None, None
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        elif self.attn_backend == _Backend.XFORMERS:
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        return max_seqlen, seqlens
+
     def forward(
         self,
         x: torch.Tensor,
@@ -645,25 +656,27 @@ class Qwen2_5_VisionTransformer(nn.Module):
         # transformers
         hidden_states = hidden_states.unsqueeze(1)
 
-        max_seqlen = None
-        seqlens = None
+        # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
+        max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
+            cu_seqlens)
+        max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
+            cu_window_seqlens)
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
+                max_seqlen_now = max_seqlen_full
+                seqlens_now = seqlens_full
             else:
                 cu_seqlens_now = cu_window_seqlens
-            # pre-compute cu_seqlens for window attn
-            if self.attn_backend == _Backend.FLASH_ATTN:
-                max_seqlen = (cu_seqlens_now[1:] -
-                              cu_seqlens_now[:-1]).max().item()
-            elif self.attn_backend == _Backend.XFORMERS:
-                seqlens = (cu_seqlens_now[1:] - cu_seqlens_now[:-1]).tolist()
+                max_seqlen_now = max_seqlen_window
+                seqlens_now = seqlens_window
+
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
                 rotary_pos_emb=rotary_pos_emb,
-                max_seqlen=max_seqlen,
-                seqlens=seqlens,
+                max_seqlen=max_seqlen_now,
+                seqlens=seqlens_now,
             )
 
         # For Qwen2.5-VL-3B, float16 will overflow at last block
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index b8ac40b7e7f9b..7537671e1bb82 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -617,6 +617,16 @@ class Qwen2VisionTransformer(nn.Module):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
+    def compute_attn_mask_seqlen(
+        self, cu_seqlens: torch.Tensor
+    ) -> tuple[Optional[int], Optional[list[int]]]:
+        max_seqlen, seqlens = None, None
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        elif self.attn_backend == _Backend.XFORMERS:
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        return max_seqlen, seqlens
+
     def forward(
         self,
         x: torch.Tensor,
@@ -638,12 +648,8 @@ class Qwen2VisionTransformer(nn.Module):
         # transformers
        x = x.unsqueeze(1)
 
-        max_seqlen = None
-        seqlens = None
-        if self.attn_backend == _Backend.FLASH_ATTN:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
+        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
         for blk in self.blocks:
             x = blk(
                 x,
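
Below the diff, a minimal, self-contained sketch of the idea behind this change: the `.max().item()` / `.tolist()` calls force a host-device sync (the cuMemcpy mentioned in the new comments), so the refactor computes them once per `cu_seqlens` tensor instead of once per block. The sketch is illustrative only; the string constants `FLASH_ATTN` / `XFORMERS` stand in for vLLM's `_Backend` enum, and the free function mirrors the new `compute_attn_mask_seqlen` method.

```python
import torch

# hypothetical stand-ins for vLLM's _Backend enum members
FLASH_ATTN = "FLASH_ATTN"
XFORMERS = "XFORMERS"


def compute_attn_mask_seqlen(cu_seqlens: torch.Tensor, attn_backend: str):
    """Derive seqlen metadata from cumulative seqlens in a single pass."""
    max_seqlen, seqlens = None, None
    if attn_backend == FLASH_ATTN:
        # one host<->device sync here, instead of one per transformer block
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    elif attn_backend == XFORMERS:
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return max_seqlen, seqlens


# toy example: three sequences of lengths 4, 2 and 6
cu_seqlens = torch.tensor([0, 4, 6, 12], dtype=torch.int32)
print(compute_attn_mask_seqlen(cu_seqlens, FLASH_ATTN))  # (6, None)
print(compute_attn_mask_seqlen(cu_seqlens, XFORMERS))    # (None, [4, 2, 6])
```

In the Qwen2.5-VL path there are two candidate `cu_seqlens` tensors (full attention and windowed attention), so the forward pass precomputes both pairs up front and each block merely selects the full or window values by layer index, rather than recomputing them inside the loop.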