From e45271b09c714946ff0d52a1254dd90a8ea2e323 Mon Sep 17 00:00:00 2001
From: "Chendi.Xue"
Date: Fri, 3 Oct 2025 10:52:26 -0500
Subject: [PATCH] [BugFix][QWEN-VL]fix wrong apply_rotary_emb_torch selection
 introduced by #24642 (#26123)

Signed-off-by: Chendi Xue
Signed-off-by: Chendi.Xue
Co-authored-by: Roger Wang
Signed-off-by: yewentao256
---
 vllm/model_executor/layers/rotary_embedding/common.py | 11 ++++++++---
 vllm/model_executor/models/qwen2_vl.py                |  3 ++-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index 4f02c996bda14..0d11d1ffea9f5 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -4,7 +4,7 @@
 import math
 from functools import cache
 from importlib.util import find_spec
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 
@@ -72,7 +72,9 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
 
 
 @cache
-def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
+def dispatch_rotary_emb_function(
+    default: Optional[Callable[..., torch.Tensor]] = None
+) -> Callable[..., torch.Tensor]:
     if current_platform.is_cuda():
         return apply_rotary_emb
 
@@ -85,7 +87,10 @@ def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
                 "flash_attn is not installed. Falling back to PyTorch "
                 "implementation for rotary embeddings.")
 
-    return apply_rotary_emb_torch
+    if default is not None:
+        return default
+    else:
+        return apply_rotary_emb_torch
 
 
 # yarn functions
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 6f15a7f4ef380..ab9bfe4d0f191 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -276,7 +276,8 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                 freqs: torch.Tensor) -> torch.Tensor:
-    rotary_emb_function = dispatch_rotary_emb_function()
+    rotary_emb_function = dispatch_rotary_emb_function(
+        default=apply_rotary_emb_torch)
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
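
Usage sketch (not part of the patch itself): the hunks above add an optional
`default` fallback to dispatch_rotary_emb_function, so Qwen2-VL can keep using
its model-local apply_rotary_emb_torch when no platform-specific rotary kernel
is picked, instead of the generic helper in common.py. The snippet below is a
minimal, hypothetical illustration of the new call shape; the tensor shapes
follow the (batch, seqlen, nheads, headdim) convention of the vision rotary
helpers and are assumptions, not values taken from the patch.

# Minimal sketch, assuming this patch is applied and that no CUDA/ROCm
# rotary kernel is selected on the current platform.
import torch

from vllm.model_executor.layers.rotary_embedding.common import (
    dispatch_rotary_emb_function)
from vllm.model_executor.models.qwen2_vl import apply_rotary_emb_torch

# With no platform kernel available, the dispatcher now returns the
# caller-supplied fallback rather than common.apply_rotary_emb_torch.
rotary_emb_fn = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)

x = torch.randn(1, 16, 2, 8)   # (batch, seqlen, nheads, headdim); assumed shape
freqs = torch.randn(16, 4)     # (seqlen, headdim // 2); assumed rotary frequencies
out = rotary_emb_fn(x.float(), freqs.cos(), freqs.sin()).type_as(x)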