[Model] Fix Qwen3VL and Qwen3Omni after torch.compile changes (#27705)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
commit 0d8161b075 (parent d2c33c397a)
Author: Lukas Geiger
Date: 2025-10-29 05:28:20 +00:00 (committed by GitHub)
3 changed files with 17 additions and 16 deletions
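
Context for the change: under torch.compile, calling `.item()` or `.tolist()` on a tensor forces a device-to-host sync and, with default Dynamo settings, a graph break, so the attention-mask seqlen helpers are rewritten to keep `max_seqlen` and `seqlens` as tensors end to end. A minimal sketch of the difference, on toy inputs rather than vLLM code (exact break counts are an assumption about default Dynamo settings):

```python
import torch
import torch._dynamo

def old_style(cu_seqlens: torch.Tensor) -> int:
    # .item() pulls the value to the host -> sync point and, by default,
    # a Dynamo graph break
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

def new_style(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # stays a 0-dim tensor, so torch.compile can capture it end to end
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max()

cu = torch.tensor([0, 3, 7, 12])
print(old_style(cu), new_style(cu))  # 5 tensor(5)

# torch._dynamo.explain reports graph breaks:
print(torch._dynamo.explain(old_style)(cu).graph_break_count)  # expect >= 1
print(torch._dynamo.explain(new_style)(cu).graph_break_count)  # expect 0
```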


@@ -836,10 +836,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self,
         cu_seqlens: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        max_seqlen, seqlens = (
-            torch.zeros(1, device=cu_seqlens.device),
-            torch.zeros(1, device=cu_seqlens.device),
-        )
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
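
One detail worth noting in this hunk: `torch.zeros([])` is a 0-dim scalar tensor, matching what `.max()` returns, while the old `torch.zeros(1)` placeholder was a 1-element 1-D tensor. Keeping the placeholder and the computed value at the same rank means `max_seqlen` has a consistent shape on every backend branch. A quick check:

```python
import torch

cu = torch.tensor([0, 3, 7, 12])
placeholder = torch.zeros([])            # shape torch.Size([]) - 0-dim
computed = (cu[1:] - cu[:-1]).max()      # also 0-dim
assert placeholder.shape == computed.shape == torch.Size([])
assert torch.zeros(1).shape == torch.Size([1])  # the old 1-D placeholder
```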


@@ -223,8 +223,8 @@ class Qwen3_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
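
The signature change here (optional Python `int`/`list` defaults dropped in favor of required tensors) follows the same motivation: Dynamo guards on the concrete value of Python scalars, so each distinct `max_seqlen` int can trigger a recompile, whereas a 0-dim tensor is just another graph input. A hedged illustration, not part of this commit; expected counts assume default settings in recent PyTorch:

```python
import torch
from torch._dynamo.testing import CompileCounter

cnt_int = CompileCounter()

@torch.compile(backend=cnt_int, dynamic=False)
def with_int(x: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    return x.new_full((), max_seqlen)  # the int is burned in as a constant

x = torch.zeros(4)
with_int(x, 3); with_int(x, 5)
print(cnt_int.frame_count)  # expect 2: one compile per distinct int value

cnt_tensor = CompileCounter()

@torch.compile(backend=cnt_tensor)
def with_tensor(x: torch.Tensor, max_seqlen: torch.Tensor) -> torch.Tensor:
    return x * 0 + max_seqlen  # the value rides in as a graph input

with_tensor(x, torch.tensor(3)); with_tensor(x, torch.tensor(5))
print(cnt_tensor.frame_count)  # expect 1: no value-dependent recompile
```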
@@ -488,12 +488,13 @@ class Qwen3Omni_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if self.attn_backend == _Backend.FLASH_ATTN:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens
 
     def forward(
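
For reference, the arithmetic in this helper: `cu_seqlens` holds cumulative sequence boundaries, so adjacent differences recover the per-sequence lengths (the xFormers path), and their max is the longest sequence (the FlashAttention path). A worked example:

```python
import torch

cu_seqlens = torch.tensor([0, 3, 7, 12])    # three sequences: lengths 3, 4, 5
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([3, 4, 5]) - xFormers path
max_seqlen = seqlens.max()                  # tensor(5) - FlashAttention path
```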
@@ -1114,6 +1115,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        self.vllm_config = vllm_config  # needed for torch compile forward context
         thinker_config: Qwen3OmniMoeThinkerConfig = (
             vllm_config.model_config.hf_config.thinker_config
         )
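
The stored `self.vllm_config` feeds vLLM's forward-context machinery, which compiled forwards rely on. A hedged sketch of the intended usage pattern; the exact `set_forward_context` signature is an assumption based on vLLM around this time, and `run_visual` is a hypothetical caller:

```python
from vllm.forward_context import set_forward_context

def run_visual(self, pixel_values):
    # Opening a forward context needs the vllm_config kept in __init__;
    # attn_metadata is None here for a profiling-style encoder call.
    with set_forward_context(None, self.vllm_config):
        return self.visual(pixel_values)
```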


@@ -231,8 +231,8 @@ class Qwen3_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -512,15 +512,16 @@ class Qwen3_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
         ):
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens
 
     def forward(
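
With `.item()` and `.tolist()` gone, the helper body is fully capturable. A sanity sketch with a toy stand-in (the backend check becomes a trace-time Python branch here; this is not vLLM code):

```python
import torch

def compute_attn_mask_seqlen(cu_seqlens: torch.Tensor, use_fa: bool = True):
    max_seqlen = torch.zeros([], device=cu_seqlens.device)
    seqlens = torch.zeros(1, device=cu_seqlens.device)
    if use_fa:  # stand-in for the attn-backend check
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
    else:
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    return max_seqlen, seqlens

# fullgraph=True would raise on any graph break; this compiles cleanly.
compiled = torch.compile(compute_attn_mask_seqlen, fullgraph=True)
print(compiled(torch.tensor([0, 3, 7, 12])))  # (tensor(5), tensor([0.]))
```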