mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-29 00:54:00 +08:00
[Perf] enable flashinfer rotary_embedding custom ops in DeepSeek rotary (#30729)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
This commit is contained in:
parent
889f8bb250
commit
53ad423f26
@ -38,7 +38,10 @@ class RotaryEmbeddingBase(CustomOp):
|
||||
# and current_platform.is_cuda()
|
||||
# and has_flashinfer()
|
||||
# and self.head_size in [64, 128, 256, 512])
|
||||
self.use_flashinfer = False
|
||||
|
||||
# Check if use_flashinfer is already set
|
||||
if not hasattr(self, "use_flashinfer"):
|
||||
self.use_flashinfer = False
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
if not self.use_flashinfer:
|
||||
|
||||
@ -6,6 +6,7 @@ import math
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
from .base import RotaryEmbeddingBase
|
||||
from .common import (
|
||||
@ -56,6 +57,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
|
||||
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
|
||||
* attn_factor
|
||||
)
|
||||
self.use_flashinfer = (
|
||||
self.enabled()
|
||||
and dtype in (torch.float16, torch.bfloat16)
|
||||
and current_platform.is_cuda()
|
||||
and has_flashinfer()
|
||||
and head_size in [64, 128, 256, 512]
|
||||
)
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
@ -162,4 +170,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
if self.use_flashinfer:
|
||||
torch.ops.vllm.flashinfer_rotary_embedding(
|
||||
torch.add(positions, offsets) if offsets is not None else positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
else:
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user