[BugFix][QWEN-VL] fix wrong apply_rotary_emb_torch selection introduced by #24642 (#26123)

Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Chendi.Xue 2025-10-03 10:52:26 -05:00 committed by yewentao256
parent 84135b1489
commit e45271b09c
2 changed files with 10 additions and 4 deletions


@@ -4,7 +4,7 @@
 import math
 from functools import cache
 from importlib.util import find_spec
-from typing import Callable
+from typing import Callable, Optional

 import torch
@@ -72,7 +72,9 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,

 @cache
-def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
+def dispatch_rotary_emb_function(
+    default: Optional[Callable[..., torch.Tensor]] = None
+) -> Callable[..., torch.Tensor]:
     if current_platform.is_cuda():
         return apply_rotary_emb
@@ -85,7 +87,10 @@ def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
             "flash_attn is not installed. Falling back to PyTorch "
             "implementation for rotary embeddings.")
-    return apply_rotary_emb_torch
+    if default is not None:
+        return default
+    else:
+        return apply_rotary_emb_torch

 # yarn functions
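
For context, a minimal runnable sketch of why the new default parameter composes cleanly with the @cache decorator: functools.cache memoizes per argument tuple, so the zero-argument call and a call with a caller-supplied fallback are resolved and cached independently. The dispatch and torch_impl names below are illustrative stand-ins, not vLLM's actual definitions.

# Minimal sketch (illustrative names, not vLLM's real module layout)
# showing that @cache memoizes per argument, so dispatch() and
# dispatch(default=...) are cached as separate entries.
from functools import cache
from typing import Callable, Optional

import torch

def torch_impl(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a pure-PyTorch rotary embedding implementation.
    return x

@cache
def dispatch(
    default: Optional[Callable[..., torch.Tensor]] = None
) -> Callable[..., torch.Tensor]:
    # Platform-specific kernels would be selected here; this sketch
    # always falls through to the fallback branch, as on a platform
    # with neither a CUDA kernel nor flash_attn available.
    if default is not None:
        return default
    return torch_impl

assert dispatch() is torch_impl                    # module-level fallback
assert dispatch(default=torch_impl) is torch_impl  # caller-supplied fallback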


@@ -276,7 +276,8 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                 freqs: torch.Tensor) -> torch.Tensor:
-    rotary_emb_function = dispatch_rotary_emb_function()
+    rotary_emb_function = dispatch_rotary_emb_function(
+        default=apply_rotary_emb_torch)
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
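
The commit title points at the root cause: the hunk header above (def apply_rotary_emb_torch(x: torch.Tensor, ...) suggests this vision module defines its own apply_rotary_emb_torch, while the dispatcher's hard-coded fallback returned the same-named function from its own module. A hedged sketch of that pitfall, with stand-in names and assumed parameter signatures that are not copied from vLLM:

# Illustrative stand-ins only; the signatures below are assumptions.
# Two modules each define an apply_rotary_emb_torch, so returning
# "the" torch fallback is ambiguous unless the caller says which one
# it actually wants.

def common_apply_rotary_emb_torch(x, cos, sin, is_neox_style: bool):
    # Stand-in for the implementation local to the dispatcher's module.
    return x

def vision_apply_rotary_emb_torch(x, cos, sin, interleaved: bool = False):
    # Stand-in for the implementation the vision path expects.
    return x

def dispatch_old():
    # Pre-fix behavior: fallback hard-coded to the dispatcher's
    # module-local function.
    return common_apply_rotary_emb_torch

def dispatch_new(default=None):
    # Post-fix behavior: the caller can inject its own fallback.
    return default if default is not None else common_apply_rotary_emb_torch

# The vision caller used to receive the wrong same-named function...
assert dispatch_old() is not vision_apply_rotary_emb_torch
# ...and now receives exactly the one it passed in.
assert dispatch_new(default=vision_apply_rotary_emb_torch) \
    is vision_apply_rotary_emb_torch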