diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index ef3d28c8087d2..ae48c779481f7 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -255,10 +255,12 @@ class Qwen2_5_VisionAttention(nn.Module):
         return q, k, v
 
     def forward(
-            self,
-            x: torch.Tensor,
-            cu_seqlens: torch.Tensor,
-            rotary_pos_emb: torch.Tensor,
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+        seqlens: Optional[list[int]] = None,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -285,7 +287,6 @@ class Qwen2_5_VisionAttention(nn.Module):
             q, k, v = (rearrange(x, "b s ... -> (b s) ...")
                        for x in [q, k, v])
 
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             output = flash_attn_varlen_func(q,
                                             k,
                                             v,
@@ -321,7 +322,6 @@ class Qwen2_5_VisionAttention(nn.Module):
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
             attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                        kv_seqlen=None,
                                                        device=q.device)
@@ -364,11 +364,20 @@ class Qwen2_5_VisionBlock(nn.Module):
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.mlp")
 
-    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
-                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+        seqlens: Optional[list[int]] = None,  # Only used for xFormers
+    ) -> torch.Tensor:
         x = x + self.attn(self.norm1(x),
                           cu_seqlens=cu_seqlens,
-                          rotary_pos_emb=rotary_pos_emb)
+                          rotary_pos_emb=rotary_pos_emb,
+                          max_seqlen=max_seqlen,
+                          seqlens=seqlens)
+
         x = x + self.mlp(self.norm2(x))
         return x
 
@@ -528,6 +537,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
 
     @property
     def dtype(self) -> torch.dtype:
@@ -633,14 +643,25 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
         # transformers
         hidden_states = hidden_states.unsqueeze(1)
+
+        max_seqlen = None
+        seqlens = None
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        elif self.attn_backend == _Backend.XFORMERS:
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
             else:
                 cu_seqlens_now = cu_window_seqlens
-            hidden_states = blk(hidden_states,
-                                cu_seqlens=cu_seqlens_now,
-                                rotary_pos_emb=rotary_pos_emb)
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens_now,
+                rotary_pos_emb=rotary_pos_emb,
+                max_seqlen=max_seqlen,
+                seqlens=seqlens,
+            )
 
         # For Qwen2.5-VL-3B, float16 will overflow at last block
         # for long visual tokens sequences.
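
Editor's note (illustration only, not part of the patch): both quantities hoisted out of the per-layer attention call above are cheap reductions over cu_seqlens. A minimal sketch with made-up token counts shows what the moved expressions compute:

    import torch

    # cu_seqlens holds cumulative sequence lengths with a leading zero, e.g.
    # three images contributing 16, 64 and 4 vision tokens (values invented):
    cu_seqlens = torch.tensor([0, 16, 80, 84], dtype=torch.int32)

    # Per-sequence lengths are the consecutive differences.
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([16, 64, 4])

    # flash_attn_varlen_func expects the largest length as a Python int ...
    max_seqlen = lengths.max().item()            # 64

    # ... while xFormers' BlockDiagonalMask.from_seqlens expects a list of ints.
    seqlens = lengths.tolist()                   # [16, 64, 4]
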
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index ac3d154dd881c..0e9fa7183c89a 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -303,10 +303,12 @@ class Qwen2VisionAttention(nn.Module):
         return q, k, v
 
     def forward(
-            self,
-            x: torch.Tensor,
-            cu_seqlens: torch.Tensor,
-            rotary_pos_emb: torch.Tensor,
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+        seqlens: Optional[list[int]] = None,  # Only used for xFormers
     ) -> torch.Tensor:
 
         # [s, b, c] --> [s, b, 3 * head * head_dim]
@@ -329,7 +331,6 @@ class Qwen2VisionAttention(nn.Module):
             q, k, v = (rearrange(x, "b s ... -> (b s) ...")
                        for x in [q, k, v])
 
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             output = flash_attn_varlen_func(q,
                                             k,
                                             v,
@@ -365,7 +366,6 @@ class Qwen2VisionAttention(nn.Module):
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
             attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                        kv_seqlen=None,
                                                        device=q.device)
@@ -409,11 +409,22 @@ class Qwen2VisionBlock(nn.Module):
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
 
-    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
-                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x),
-                          cu_seqlens=cu_seqlens,
-                          rotary_pos_emb=rotary_pos_emb)
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+        seqlens: Optional[list[int]] = None,  # Only used for xFormers
+    ) -> torch.Tensor:
+        x = x + self.attn(
+            self.norm1(x),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+            max_seqlen=max_seqlen,
+            seqlens=seqlens,
+        )
+
         x = x + self.mlp(self.norm2(x))
         return x
 
@@ -570,6 +581,7 @@ class Qwen2VisionTransformer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
 
     @property
     def dtype(self) -> torch.dtype:
@@ -624,8 +636,21 @@ class Qwen2VisionTransformer(nn.Module):
 
         # transformers
         x = x.unsqueeze(1)
+
+        max_seqlen = None
+        seqlens = None
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        elif self.attn_backend == _Backend.XFORMERS:
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         for blk in self.blocks:
-            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            x = blk(
+                x,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb=rotary_pos_emb,
+                max_seqlen=max_seqlen,
+                seqlens=seqlens,
+            )
 
         # adapter
         x = self.merger(x)
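
Editor's note (illustration only, not part of the patch): the change is identical in both files. The backend-dependent metadata is now computed once per vision-encoder forward pass, keyed on the backend returned by get_vit_attn_backend, and threaded through every block instead of being recomputed inside each attention layer; since .item() and .tolist() copy from device to host, this also avoids a per-layer GPU sync. A rough sketch of the pattern, using a hypothetical precompute_seqlen_metadata helper rather than vLLM's actual code:

    from typing import Optional

    import torch


    def precompute_seqlen_metadata(
        cu_seqlens: torch.Tensor,
        use_flash_attn: bool,
        use_xformers: bool,
    ) -> tuple[Optional[int], Optional[list[int]]]:
        """Hypothetical helper: derive backend-specific seqlen metadata once,
        so each transformer block can simply pass it down to its attention."""
        max_seqlen: Optional[int] = None
        seqlens: Optional[list[int]] = None
        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        if use_flash_attn:
            # flash_attn_varlen_func takes the maximum sequence length as an int.
            max_seqlen = int(lengths.max().item())
        elif use_xformers:
            # BlockDiagonalMask.from_seqlens takes per-sequence lengths as a list.
            seqlens = lengths.tolist()
        return max_seqlen, seqlens
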