Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 16:05:42 +08:00)
[Model] Fix Qwen3VL and Qwen3Omni after torch.compile changes (#27705)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Parent: d2c33c397a
Commit: 0d8161b075
@@ -836,10 +836,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self,
         cu_seqlens: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        max_seqlen, seqlens = (
-            torch.zeros(1, device=cu_seqlens.device),
-            torch.zeros(1, device=cu_seqlens.device),
-        )
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
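Note on the hunk above, as a standalone sketch in plain PyTorch (not vLLM code): torch.zeros([]) builds a 0-d scalar tensor, which has the same rank as the value later produced by .max(), while torch.zeros(1) is a 1-element vector matching the per-sequence seqlens shape. The change keeps both placeholders tensor-valued and consistent in rank with the tensor results computed in the branches below.

import torch

# Standalone sketch: shapes of the two placeholder values used above.
max_seqlen = torch.zeros([])   # 0-d scalar tensor, shape torch.Size([])
seqlens = torch.zeros(1)       # 1-d tensor with one element, shape torch.Size([1])

print(max_seqlen.shape, max_seqlen.dim())  # torch.Size([]) 0
print(seqlens.shape, seqlens.dim())        # torch.Size([1]) 1

# A 0-d placeholder has the same rank as a reduction result such as .max():
print(torch.tensor([3, 4, 5]).max().shape)  # torch.Size([])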
@@ -223,8 +223,8 @@ class Qwen3_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
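The block now requires tensor-valued max_seqlen and seqlens instead of optional Python int / list arguments. A minimal standalone sketch (plain torch.compile, not the vLLM code path) of why tensor arguments are friendlier to compilation: Dynamo specializes or guards on Python int arguments, so a new value can trigger a recompile, whereas a 0-d tensor is traced as ordinary graph input data and the same graph can be reused.

import torch

@torch.compile
def scale_by_int(x: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    # The Python int is baked into (or guarded by) the compiled graph.
    return x * max_seqlen

@torch.compile
def scale_by_tensor(x: torch.Tensor, max_seqlen: torch.Tensor) -> torch.Tensor:
    # The 0-d tensor is a regular graph input; its value can change freely.
    return x * max_seqlen

x = torch.ones(4)
scale_by_int(x, 5)
scale_by_int(x, 7)                   # may recompile for the new int value
scale_by_tensor(x, torch.tensor(5))
scale_by_tensor(x, torch.tensor(7))  # reuses the compiled graph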
@@ -488,12 +488,13 @@ class Qwen3Omni_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if self.attn_backend == _Backend.FLASH_ATTN:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens
 
     def forward(
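A worked example of the arithmetic in compute_attn_mask_seqlen, runnable on its own: for cumulative lengths [0, 3, 7, 12] the per-sequence lengths are [3, 4, 5] and the maximum is 5. The old code returned these as a Python int (.max().item()) and a Python list (.tolist()); the new code keeps both as tensors.

import torch

cu_seqlens = torch.tensor([0, 3, 7, 12])    # cumulative sequence lengths
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # per-sequence lengths
max_seqlen = seqlens.max()                  # longest sequence, 0-d tensor

print(seqlens)             # tensor([3, 4, 5])
print(max_seqlen)          # tensor(5)
print(max_seqlen.item())   # 5         (what the old FLASH_ATTN branch returned)
print(seqlens.tolist())    # [3, 4, 5] (what the old XFORMERS branch returned)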
@@ -1114,6 +1115,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        self.vllm_config = vllm_config  # needed for torch compile forward context
         thinker_config: Qwen3OmniMoeThinkerConfig = (
             vllm_config.model_config.hf_config.thinker_config
         )
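The added line stores vllm_config on the thinker because, per the inline comment, it is needed for the torch.compile forward context. A hedged sketch of how such a stored config is typically consumed; the exact call site is not part of this diff, and the set_forward_context usage below is an assumption rather than a quote of the Qwen3Omni code.

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context

class EncoderRunnerSketch:
    def __init__(self, vllm_config: VllmConfig, visual):
        self.vllm_config = vllm_config  # kept so compiled submodules can find it
        self.visual = visual

    def encode(self, pixel_values):
        # The forward context makes vllm_config (and attention metadata)
        # available to compiled code that cannot take them as arguments.
        with set_forward_context(None, self.vllm_config):
            return self.visual(pixel_values)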
@@ -231,8 +231,8 @@ class Qwen3_VisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -512,15 +512,16 @@ class Qwen3_VisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[int | None, list[int] | None]:
-        max_seqlen, seqlens = None, None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = torch.zeros([], device=cu_seqlens.device)
+        seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
         ):
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens
 
     def forward(
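Since both compute_attn_mask_seqlen variants now return tensors, any consumer that genuinely needs host-side Python values (for example a kernel that takes an int max_seqlen, or a mask builder that takes a list of lengths) has to convert at that boundary. A hedged, standalone sketch with an assumed helper name, not vLLM code:

import torch

def to_host_args(max_seqlen: torch.Tensor, seqlens: torch.Tensor):
    # Convert only at the kernel boundary, outside the traced region,
    # so the .item()/.tolist() host syncs stay out of the compiled graph.
    return int(max_seqlen.item()), seqlens.tolist()

max_seqlen, seqlens = torch.tensor(5), torch.tensor([3, 4, 5])
print(to_host_args(max_seqlen, seqlens))  # (5, [3, 4, 5])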