mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 23:05:59 +08:00
[Perf] Use FlashInfer RoPE for RotaryEmbedding.forward_cuda when available (#21126)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
e57fc15971
commit
48ecb4438b
@ -6,6 +6,8 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
from .common import apply_rotary_emb_torch
|
||||
|
||||
@ -30,9 +32,17 @@ class RotaryEmbedding(CustomOp):
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
# Flashinfer only supports head_size=64, 128, 256, 512.
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
|
||||
self.use_flashinfer = (self.enabled()
|
||||
and dtype in (torch.float16, torch.bfloat16)
|
||||
and current_platform.is_cuda()
|
||||
and has_flashinfer()
|
||||
and self.head_size in [64, 128, 256, 512])
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(dtype)
|
||||
if not self.use_flashinfer:
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
@ -57,6 +67,14 @@ class RotaryEmbedding(CustomOp):
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
|
||||
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
|
||||
# is expensive, so avoid calling it if possible
|
||||
if self.cos_sin_cache.device != query.device or \
|
||||
self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@ -94,15 +112,16 @@ class RotaryEmbedding(CustomOp):
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if self.use_flashinfer:
|
||||
torch.ops.vllm.flashinfer_rotary_embedding(positions, query, key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
|
||||
# is expensive, so avoid calling it if possible
|
||||
if self.cos_sin_cache.device != query.device or \
|
||||
self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
# ops.rotary_embedding() is an in-place operation
|
||||
# that updates the query and key tensors.
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
@ -117,8 +136,7 @@ class RotaryEmbedding(CustomOp):
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
|
||||
dtype=query.dtype)
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
# ops.rotary_embedding() is an in-place operation
|
||||
# that updates the query and key tensors.
|
||||
if key is None:
|
||||
|
||||
@ -6,6 +6,7 @@ import math
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
@ -103,3 +104,48 @@ def yarn_get_mscale(scale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def _flashinfer_rotary_embedding(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
) -> None:
|
||||
"""Custom op wrapper for flashinfer's rotary embedding.
|
||||
|
||||
This is an in-place operation that modifies query and key tensors directly.
|
||||
"""
|
||||
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
|
||||
|
||||
apply_rope_with_cos_sin_cache_inplace(
|
||||
positions=positions,
|
||||
query=query,
|
||||
key=key,
|
||||
head_size=head_size,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
is_neox=is_neox,
|
||||
)
|
||||
|
||||
|
||||
def _flashinfer_rotary_embedding_fake(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
# Register flashinfer rotary embedding custom op
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_rotary_embedding",
|
||||
op_func=_flashinfer_rotary_embedding,
|
||||
mutates_args=["query", "key"], # These tensors are modified in-place
|
||||
fake_impl=_flashinfer_rotary_embedding_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -97,15 +97,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
assert key is not None
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
if self.rotary_dim < self.head_size:
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
|
||||
if self.cos_sin_cache.device != positions.device:
|
||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
|
||||
positions.device)
|
||||
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
||||
if offsets is not None else positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
@ -59,7 +59,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert key is not None
|
||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
query_ = torch.view_as_complex(query.float().reshape(
|
||||
*query.shape[:-1], -1, 2))
|
||||
key_ = torch.view_as_complex(key.float().reshape(
|
||||
|
||||
@ -245,6 +245,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
num_tokens = positions.shape[-1]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
@ -293,6 +294,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
num_tokens = positions.shape[-1]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user