[Perf] Use FlashInfer RoPE for RotaryEmbedding.forward_cuda when available (#21126)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Michael Goin, 2025-09-19 16:06:49 -04:00 (committed by GitHub)
commit 48ecb4438b
parent e57fc15971
5 changed files with 78 additions and 14 deletions

View File

@@ -6,6 +6,8 @@ from typing import Optional
 import torch
 
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer
 
 from .common import apply_rotary_emb_torch
@@ -30,9 +32,17 @@ class RotaryEmbedding(CustomOp):
         self.base = base
         self.is_neox_style = is_neox_style
         self.dtype = dtype
 
+        # Flashinfer only supports head_size=64, 128, 256, 512.
+        # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
+        self.use_flashinfer = (self.enabled()
+                               and dtype in (torch.float16, torch.bfloat16)
+                               and current_platform.is_cuda()
+                               and has_flashinfer()
+                               and self.head_size in [64, 128, 256, 512])
+
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(dtype)
+        if not self.use_flashinfer:
+            cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
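The gate above only routes to FlashInfer for fp16/bf16 on CUDA, when FlashInfer is installed, and for the head sizes its RoPE kernel supports. A standalone sketch of the same decision, with the platform and availability checks passed in as plain booleans (the real code also requires the custom op to be enabled via `self.enabled()`, which is omitted here):

```python
import torch


def can_use_flashinfer_rope(head_size: int, dtype: torch.dtype,
                            is_cuda: bool, flashinfer_available: bool) -> bool:
    # Mirrors the use_flashinfer gate: fp16/bf16 only, CUDA only, and only the
    # head sizes FlashInfer's RoPE kernel supports (64, 128, 256, 512).
    return (is_cuda
            and flashinfer_available
            and dtype in (torch.float16, torch.bfloat16)
            and head_size in (64, 128, 256, 512))


print(can_use_flashinfer_rope(128, torch.bfloat16, True, True))  # True
print(can_use_flashinfer_rope(96, torch.bfloat16, True, True))   # False: unsupported head size
print(can_use_flashinfer_rope(128, torch.float32, True, True))   # False: fp32 falls back to the custom op
```

Note also that when the FlashInfer path is active, the `if not self.use_flashinfer` change above keeps the cos/sin cache in float32 instead of casting it to the model dtype, apparently because the FlashInfer kernel consumes the float32 cache directly.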
@@ -57,6 +67,14 @@ class RotaryEmbedding(CustomOp):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
+    def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
+        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
+        # is expensive, so avoid calling it if possible
+        if self.cos_sin_cache.device != query.device or \
+                self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                       dtype=query.dtype)
+
     def forward_native(
         self,
         positions: torch.Tensor,
@@ -94,15 +112,16 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.use_flashinfer:
+            torch.ops.vllm.flashinfer_rotary_embedding(positions, query, key,
+                                                       self.head_size,
+                                                       self.cos_sin_cache,
+                                                       self.is_neox_style)
+            return query, key
         from vllm import _custom_ops as ops
-        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
-        # is expensive, so avoid calling it if possible
-        if self.cos_sin_cache.device != query.device or \
-                self.cos_sin_cache.dtype != query.dtype:
-            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                       dtype=query.dtype)
+        self._match_cos_sin_cache_dtype(query)
         # ops.rotary_embedding() is an in-place operation
         # that updates the query and key tensors.
         ops.rotary_embedding(positions, query, key, self.head_size,
@@ -117,8 +136,7 @@ class RotaryEmbedding(CustomOp):
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         from vllm._ipex_ops import ipex_ops as ops
 
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
-                                                   dtype=query.dtype)
+        self._match_cos_sin_cache_dtype(query)
         # ops.rotary_embedding() is an in-place operation
         # that updates the query and key tensors.
         if key is None:
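The `_match_cos_sin_cache_dtype` helper introduced above exists to skip `nn.Module.__setattr__` when the buffer already matches the query's device and dtype. A rough standalone illustration of the difference, not part of the commit (timings are machine-dependent; the buffer name and shapes are arbitrary):

```python
import timeit

import torch
from torch import nn

m = nn.Module()
m.register_buffer("cos_sin_cache", torch.randn(4096, 128), persistent=False)
q = torch.randn(8, 128)  # already the same device/dtype as the cache


def always_assign():
    # Goes through nn.Module.__setattr__ on every call, even though .to() is a
    # no-op when device and dtype already match.
    m.cos_sin_cache = m.cos_sin_cache.to(q.device, dtype=q.dtype)


def assign_if_needed():
    # The guarded version: only pay for __setattr__ when a move/cast is needed.
    if m.cos_sin_cache.device != q.device or m.cos_sin_cache.dtype != q.dtype:
        m.cos_sin_cache = m.cos_sin_cache.to(q.device, dtype=q.dtype)


print("always assign:   ", timeit.timeit(always_assign, number=10_000))
print("assign if needed:", timeit.timeit(assign_if_needed, number=10_000))
```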

View File

@@ -6,6 +6,7 @@ import math
 import torch
 
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 if current_platform.is_cuda():
     from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
@@ -103,3 +104,48 @@ def yarn_get_mscale(scale: float = 1) -> float:
     if scale <= 1:
         return 1.0
     return 0.1 * math.log(scale) + 1.0
+
+
+def _flashinfer_rotary_embedding(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+) -> None:
+    """Custom op wrapper for flashinfer's rotary embedding.
+
+    This is an in-place operation that modifies query and key tensors directly.
+    """
+    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+    apply_rope_with_cos_sin_cache_inplace(
+        positions=positions,
+        query=query,
+        key=key,
+        head_size=head_size,
+        cos_sin_cache=cos_sin_cache,
+        is_neox=is_neox,
+    )
+
+
+def _flashinfer_rotary_embedding_fake(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+) -> None:
+    return
+
+
+# Register flashinfer rotary embedding custom op
+direct_register_custom_op(
+    op_name="flashinfer_rotary_embedding",
+    op_func=_flashinfer_rotary_embedding,
+    mutates_args=["query", "key"],  # These tensors are modified in-place
+    fake_impl=_flashinfer_rotary_embedding_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
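For reference, `direct_register_custom_op` is vLLM's thin wrapper around PyTorch custom-op registration: declaring `mutates_args` tells the compiler which tensors the kernel updates in place, and the no-op fake impl lets `torch.compile` trace the op without running FlashInfer. The same pattern can be sketched with plain `torch.library` (the op name and body below are made up for illustration; this is not the FlashInfer kernel):

```python
import torch


@torch.library.custom_op("rope_sketch::rotate_inplace", mutates_args=("query", "key"))
def rotate_inplace(query: torch.Tensor, key: torch.Tensor) -> None:
    # Stand-in for apply_rope_with_cos_sin_cache_inplace: mutates its inputs.
    query.neg_()
    key.neg_()


@rotate_inplace.register_fake
def _(query: torch.Tensor, key: torch.Tensor) -> None:
    # Fake (meta) implementation: does no work, only used during tracing.
    return


q, k = torch.ones(2, 4), torch.ones(2, 4)
torch.ops.rope_sketch.rotate_inplace(q, k)
print(q[0, 0].item())  # -1.0: the inputs were updated in place
```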

View File

@@ -97,15 +97,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         assert key is not None
+        self._match_cos_sin_cache_dtype(query)
         query_rot = query[..., :self.rotary_dim]
         key_rot = key[..., :self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim:]
             key_pass = key[..., self.rotary_dim:]
 
-        if self.cos_sin_cache.device != positions.device:
-            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-                positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)

View File

@@ -59,7 +59,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         key: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert key is not None
-        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
+        self._match_cos_sin_cache_dtype(query)
         query_ = torch.view_as_complex(query.float().reshape(
             *query.shape[:-1], -1, 2))
         key_ = torch.view_as_complex(key.float().reshape(

View File

@@ -245,6 +245,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
+        self._match_cos_sin_cache_dtype(query)
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -293,6 +294,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
+        self._match_cos_sin_cache_dtype(query)
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
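Taken together, callers do not change anything to pick up the new path: `forward_cuda` selects the FlashInfer kernel whenever the gate is satisfied and falls back to `ops.rotary_embedding` otherwise. A minimal end-to-end sketch, assuming a CUDA build of vLLM with FlashInfer installed; the constructor keyword names, `rotary_dim`/`max_position_embeddings` values, and the flat `[num_tokens, num_heads * head_size]` query/key layout are assumptions about how this class is normally used, not something this diff specifies:

```python
import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

# head_size=128 and bf16 satisfy the use_flashinfer gate added above.
rope = RotaryEmbedding(head_size=128, rotary_dim=128,
                       max_position_embeddings=4096, base=10000,
                       is_neox_style=True, dtype=torch.bfloat16).cuda()

num_tokens, num_heads = 8, 4
positions = torch.arange(num_tokens, device="cuda")
query = torch.randn(num_tokens, num_heads * 128, device="cuda", dtype=torch.bfloat16)
key = torch.randn(num_tokens, num_heads * 128, device="cuda", dtype=torch.bfloat16)

# forward_cuda dispatches to torch.ops.vllm.flashinfer_rotary_embedding when
# use_flashinfer is True, otherwise to the existing ops.rotary_embedding custom
# op; in both cases query/key are rotated in place and returned.
query, key = rope(positions, query, key)
```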