diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 30e3d2dff97b..c68115729c42 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -836,10 +836,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self,
         cu_seqlens: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        max_seqlen, seqlens = (
-            torch.zeros(1, device=cu_seqlens.device),
-            torch.zeros(1, device=cu_seqlens.device),
-        )
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index f3b6ad495db4..efcd003fbbda 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -223,8 +223,8 @@ class Qwen3_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -488,12 +488,13 @@ class Qwen3Omni_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if self.attn_backend == _Backend.FLASH_ATTN:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens

     def forward(
@@ -1114,6 +1115,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        self.vllm_config = vllm_config  # needed for torch compile forward context
         thinker_config: Qwen3OmniMoeThinkerConfig = (
             vllm_config.model_config.hf_config.thinker_config
         )
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 10c0eb4eb65e..d611580c7182 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -231,8 +231,8 @@ class Qwen3_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -512,15 +512,16 @@ class Qwen3_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
        self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
         ):
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens

     def forward(
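
For reference, below is a minimal, self-contained sketch (not part of the patch) of the behavior the updated compute_attn_mask_seqlen methods share: the per-sequence lengths and their maximum now stay on-device as tensors rather than being converted with .item()/.tolist(), which forces a GPU-to-CPU sync and, presumably, is what the torch compile forward context noted in the patch needs to avoid. The function name and example values here are illustrative only.

    # Illustrative sketch only; mirrors the tensor-returning logic in the patch.
    import torch


    def compute_attn_mask_seqlen_sketch(
        cu_seqlens: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # cu_seqlens is a cumulative prefix of sequence lengths, e.g. [0, 4, 10, 12].
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # per-sequence lengths, stays a tensor
        max_seqlen = seqlens.max()  # 0-dim tensor; no .item(), so no host sync
        return max_seqlen, seqlens


    if __name__ == "__main__":
        cu = torch.tensor([0, 4, 10, 12])
        max_seqlen, seqlens = compute_attn_mask_seqlen_sketch(cu)
        print(max_seqlen, seqlens)  # tensor(6) tensor([4, 6, 2])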