[Perf] Optimize rotary_emb implementation to use Triton operator for improved inference performance (#16457)

Signed-off-by: cynthieye <yexin93@qq.com>
Co-authored-by: MagnetoWang <magnetowang@outlook.com>
Authored by yexin(叶鑫) on 2025-04-25 14:52:28 +08:00, committed by GitHub
parent 881f735827
commit b22980a1dc
2 changed files with 24 additions and 12 deletions


@@ -38,7 +38,7 @@ else()
     FetchContent_Declare(
             vllm-flash-attn
             GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-            GIT_TAG e93779c59ba4905e56e5c39dc2c1904ada71fa21
+            GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491
             GIT_PROGRESS TRUE
             # Don't share the vllm-flash-attn build between build types
             BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn


@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)


-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
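
The rotation arithmetic itself is unchanged by this commit and is elided from the hunk above. For reference, here is a minimal pure-PyTorch sketch of the renamed _apply_rotary_emb_torch helper, reconstructed under the assumption that it follows the standard RoPE formulation with Neox-style (half-split) and GPT-J-style (interleaved) channel layouts:

import torch

def _apply_rotary_emb_torch(x: torch.Tensor, cos: torch.Tensor,
                            sin: torch.Tensor,
                            is_neox_style: bool) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_size]
    # cos/sin: [num_tokens, head_size // 2], broadcast over the head dimension.
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        # Neox style: rotate the first half against the second half of each head.
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        # GPT-J style: rotate interleaved even/odd channels.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)
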
@@ -75,6 +67,24 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)


+def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
+                      is_neox_style: bool) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+    """
+    if current_platform.is_cuda_alike():
+        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
+                                not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
@@ -141,14 +151,16 @@ class RotaryEmbedding(CustomOp):
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
+                                            self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
+                                          self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
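
Since the native path above keeps calling the pure-torch helper, the two implementations are expected to stay numerically interchangeable. A hypothetical parity check between the Triton path and the torch path (assuming a CUDA GPU and a vLLM build that bundles vllm.vllm_flash_attn; the module path is an assumption, not part of this diff):

import torch
from vllm.model_executor.layers.rotary_embedding import (  # assumed module path
    _apply_rotary_emb, _apply_rotary_emb_torch)

torch.manual_seed(0)
num_tokens, num_heads, head_size = 16, 8, 128
x = torch.randn(num_tokens, num_heads, head_size,
                dtype=torch.float16, device="cuda")
angles = torch.rand(num_tokens, head_size // 2, device="cuda")
cos = angles.cos().to(torch.float16)
sin = angles.sin().to(torch.float16)

for neox in (True, False):
    triton_out = _apply_rotary_emb(x, cos, sin, neox)       # Triton kernel path
    torch_out = _apply_rotary_emb_torch(x, cos, sin, neox)  # pure-torch reference
    torch.testing.assert_close(triton_out, torch_out, atol=1e-3, rtol=1e-3)
print("Triton and torch rotary paths match")
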