[Model] Fix Qwen3VL and Qwen3Omni after torch.compile changes (#27705)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Lukas Geiger, 2025-10-29 05:28:20 +00:00, committed by GitHub
commit 0d8161b075 (parent d2c33c397a)
3 changed files with 17 additions and 16 deletions


@@ -836,10 +836,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self,
         cu_seqlens: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        max_seqlen, seqlens = (
-            torch.zeros(1, device=cu_seqlens.device),
-            torch.zeros(1, device=cu_seqlens.device),
-        )
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
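
Note on the placeholder change in this hunk (a minimal sketch with made-up
sequence lengths, not code from the commit): torch.zeros([]) is a 0-dim
scalar tensor, the same shape that (cu_seqlens[1:] - cu_seqlens[:-1]).max()
returns on the flash-attention path, while torch.zeros(1) stays a 1-D
placeholder for the per-sequence seqlens. Keeping both as tensors avoids the
host copies that .item() / .tolist() would need.

import torch

# Hypothetical cu_seqlens for three sequences of lengths 4, 2 and 6.
cu_seqlens = torch.tensor([0, 4, 6, 12])

seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([4, 2, 6]), stays on device
max_seqlen = seqlens.max()                  # tensor(6), a 0-dim scalar tensor

# The new max_seqlen placeholder torch.zeros([]) matches the 0-dim shape of
# .max(); the old torch.zeros(1) placeholder was a 1-element vector instead.
assert max_seqlen.shape == torch.zeros([]).shape == ()
assert torch.zeros(1).shape == (1,)

# .max().item() / .tolist() would copy results back to the host, which
# typically forces a graph break when traced by torch.compile.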


@@ -223,8 +223,8 @@ class Qwen3_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -488,12 +488,13 @@ class Qwen3Omni_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if self.attn_backend == _Backend.FLASH_ATTN:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens

     def forward(
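
For context, a standalone sketch of the new tensor-only logic (the _Backend
enum and self are replaced by a plain flag here, so this is illustrative
rather than the vLLM code): because nothing calls .item() or .tolist(), the
whole helper can be captured as a single torch.compile graph.

import torch

def attn_mask_seqlen(cu_seqlens: torch.Tensor, use_flash: bool):
    # Same structure as the new compute_attn_mask_seqlen above.
    max_seqlen = torch.zeros([], device=cu_seqlens.device)
    seqlens = torch.zeros(1, device=cu_seqlens.device)
    if use_flash:
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
    else:
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    return max_seqlen, seqlens

cu = torch.tensor([0, 4, 6, 12])
# fullgraph=True rejects graph breaks; the old .item()/.tolist() version
# would typically fail here, while the tensor-only version compiles cleanly.
compiled = torch.compile(attn_mask_seqlen, fullgraph=True)
print(compiled(cu, True))   # (tensor(6), tensor([0.]))
print(compiled(cu, False))  # (tensor(0.), tensor([4, 2, 6]))
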
@@ -1114,6 +1115,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        self.vllm_config = vllm_config  # needed for torch compile forward context
         thinker_config: Qwen3OmniMoeThinkerConfig = (
             vllm_config.model_config.hf_config.thinker_config
         )
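
The stored reference is what lets code running outside the main engine loop
re-enter vLLM's forward context with the right config, as the inline comment
notes. A rough sketch of the pattern (the wrapper class and call site below
are assumptions for illustration, not code from this commit; only
set_forward_context from vllm.forward_context is an existing vLLM API):

from vllm.forward_context import set_forward_context

class VisualWrapper:
    # Hypothetical wrapper: the config captured at construction time is
    # reused later to open the forward context around a compiled submodule.
    def __init__(self, visual, vllm_config):
        self.visual = visual
        self.vllm_config = vllm_config

    def run(self, pixel_values):
        with set_forward_context(None, self.vllm_config):
            return self.visual(pixel_values)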


@@ -231,8 +231,8 @@ class Qwen3_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -512,15 +512,16 @@ class Qwen3_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
         ):
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens

     def forward(
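
Because max_seqlen and seqlens no longer have Optional defaults in
Qwen3_VisionBlock.forward, the vision transformer computes them once per
forward pass and threads them through every block. A simplified sketch of
that caller pattern (fragment only; names follow the diff, surrounding code
assumed):

# Inside the vision transformer's forward: the placeholders are always
# tensors now, so the same call works for every attention backend.
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
    x = blk(
        x,
        cu_seqlens=cu_seqlens,
        rotary_pos_emb=rotary_pos_emb,
        max_seqlen=max_seqlen,
        seqlens=seqlens,
    )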