diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index afa69324c4e2e..7e83ea9a1355b 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -38,7 +38,10 @@ class RotaryEmbeddingBase(CustomOp):
         #     and current_platform.is_cuda()
         #     and has_flashinfer()
         #     and self.head_size in [64, 128, 256, 512])
-        self.use_flashinfer = False
+
+        # Check if use_flashinfer is already set
+        if not hasattr(self, "use_flashinfer"):
+            self.use_flashinfer = False
 
         cache = self._compute_cos_sin_cache()
         if not self.use_flashinfer:
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index e72834e473c15..8402b65efcc04 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -6,6 +6,7 @@ import math
 import torch
 
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer
 
 from .base import RotaryEmbeddingBase
 from .common import (
@@ -56,6 +57,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
             / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
             * attn_factor
         )
+        self.use_flashinfer = (
+            self.enabled()
+            and dtype in (torch.float16, torch.bfloat16)
+            and current_platform.is_cuda()
+            and has_flashinfer()
+            and head_size in [64, 128, 256, 512]
+        )
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
@@ -162,4 +170,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
         key: torch.Tensor | None = None,
         offsets: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
+        if self.use_flashinfer:
+            torch.ops.vllm.flashinfer_rotary_embedding(
+                torch.add(positions, offsets) if offsets is not None else positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+            return query, key
+        else:
+            return self.forward_native(positions, query, key, offsets)
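
Note on the initialization ordering the patch relies on: DeepseekScalingRotaryEmbedding now assigns self.use_flashinfer before delegating to super().__init__(), so the base class must only fall back to False when the attribute is absent (the hasattr() guard added in base.py). Below is a minimal standalone sketch of that pattern, not part of the patch; the class names are illustrative placeholders, not vLLM code.

class _Base:
    def __init__(self) -> None:
        # Mirrors the hasattr() guard in base.py: keep any value a subclass
        # set before delegating to this constructor, otherwise default to False.
        if not hasattr(self, "use_flashinfer"):
            self.use_flashinfer = False


class _DeepseekLike(_Base):
    def __init__(self, flashinfer_available: bool) -> None:
        # Mirrors deepseek_scaling_rope.py: decide the flag first, then call
        # the parent constructor, which must not overwrite it.
        self.use_flashinfer = flashinfer_available
        super().__init__()


assert _Base().use_flashinfer is False
assert _DeepseekLike(flashinfer_available=True).use_flashinfer is True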