diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index 4f02c996bda14..0d11d1ffea9f5 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -4,7 +4,7 @@
 import math
 from functools import cache
 from importlib.util import find_spec
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 
@@ -72,7 +72,9 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
 
 
 @cache
-def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
+def dispatch_rotary_emb_function(
+    default: Optional[Callable[..., torch.Tensor]] = None
+) -> Callable[..., torch.Tensor]:
     if current_platform.is_cuda():
         return apply_rotary_emb
 
@@ -85,7 +87,10 @@ def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
                 "flash_attn is not installed. Falling back to PyTorch "
                 "implementation for rotary embeddings.")
 
-    return apply_rotary_emb_torch
+    if default is not None:
+        return default
+    else:
+        return apply_rotary_emb_torch
 
 
 # yarn functions
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 6f15a7f4ef380..ab9bfe4d0f191 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -276,7 +276,8 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                 freqs: torch.Tensor) -> torch.Tensor:
-    rotary_emb_function = dispatch_rotary_emb_function()
+    rotary_emb_function = dispatch_rotary_emb_function(
+        default=apply_rotary_emb_torch)
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
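
Note (not part of the diff): because dispatch_rotary_emb_function is wrapped in functools.cache, the new default callable becomes part of the cache key, so each distinct fallback a caller passes gets its own cached dispatch result, while the no-argument call keeps returning the shared apply_rotary_emb_torch. A minimal sketch of that caller-supplied-fallback pattern, assuming only functools and torch; pick_impl, fallback_a, and fallback_b are illustrative stand-ins, not vLLM APIs:

from functools import cache
from typing import Callable, Optional

import torch


def fallback_a(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the shared PyTorch rotary implementation.
    return x


def fallback_b(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a model-local implementation (e.g. qwen2_vl's own
    # apply_rotary_emb_torch).
    return x * 1.0


@cache
def pick_impl(
    default: Optional[Callable[..., torch.Tensor]] = None
) -> Callable[..., torch.Tensor]:
    # Mirrors the diff: prefer the caller-supplied fallback when given,
    # otherwise fall back to the shared implementation.
    return default if default is not None else fallback_a


assert pick_impl(default=fallback_b) is fallback_b  # caller's fallback wins
assert pick_impl() is fallback_a                    # separate cache entry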