[Perf] enable flashinfer rotary_embedding custom ops in DeepSeek rotary (#30729)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
jiahanc 2025-12-18 11:31:18 -08:00 committed by GitHub
parent 889f8bb250
commit 53ad423f26
2 changed files with 24 additions and 2 deletions


@@ -38,7 +38,10 @@ class RotaryEmbeddingBase(CustomOp):
         #                        and current_platform.is_cuda()
         #                        and has_flashinfer()
         #                        and self.head_size in [64, 128, 256, 512])
-        self.use_flashinfer = False
+        # Check if use_flashinfer is already set
+        if not hasattr(self, "use_flashinfer"):
+            self.use_flashinfer = False
+
         cache = self._compute_cos_sin_cache()
         if not self.use_flashinfer:
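
Why the hasattr guard matters: the DeepSeek subclass (below) now decides self.use_flashinfer before calling super().__init__(), and the base class must not clobber that decision with its False default. A minimal standalone sketch of that initialization ordering, using toy classes rather than the vLLM source:

    class Base:
        def __init__(self) -> None:
            # Default only applies when no subclass has already made the decision.
            if not hasattr(self, "use_flashinfer"):
                self.use_flashinfer = False

    class DeepseekLike(Base):
        def __init__(self, flashinfer_ok: bool) -> None:
            # Set before super().__init__(), so the base default is skipped.
            self.use_flashinfer = flashinfer_ok
            super().__init__()

    assert Base().use_flashinfer is False
    assert DeepseekLike(True).use_flashinfer is True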


@@ -6,6 +6,7 @@ import math
 import torch
 
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer
 
 from .base import RotaryEmbeddingBase
 from .common import (
@@ -56,6 +57,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
             / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
             * attn_factor
         )
+        self.use_flashinfer = (
+            self.enabled()
+            and dtype in (torch.float16, torch.bfloat16)
+            and current_platform.is_cuda()
+            and has_flashinfer()
+            and head_size in [64, 128, 256, 512]
+        )
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
@@ -162,4 +170,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
         key: torch.Tensor | None = None,
         offsets: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
+        if self.use_flashinfer:
+            torch.ops.vllm.flashinfer_rotary_embedding(
+                torch.add(positions, offsets) if offsets is not None else positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+            return query, key
+        else:
+            return self.forward_native(positions, query, key, offsets)
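
As a rough illustration of the dispatch added to forward_cuda: when the flashinfer path is taken, offsets (if any) are folded into positions, the custom op rotates query and key in place, and the same tensors are returned; otherwise the pure-PyTorch forward_native fallback runs. The sketch below mirrors that control flow with a hypothetical apply_rope_inplace callable standing in for the torch.ops.vllm.flashinfer_rotary_embedding custom op:

    from collections.abc import Callable

    import torch

    def dispatch_rope(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        use_flashinfer: bool,
        apply_rope_inplace: Callable[..., None],  # stand-in for the flashinfer custom op
        forward_native: Callable[..., tuple[torch.Tensor, torch.Tensor | None]],
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if use_flashinfer:
            # Fold offsets into positions, then let the kernel rotate query/key in place.
            pos = torch.add(positions, offsets) if offsets is not None else positions
            apply_rope_inplace(pos, query, key)
            return query, key
        # Fallback: the pure-PyTorch implementation computes and returns new tensors.
        return forward_native(positions, query, key, offsets)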