From b22980a1dc894d98de64119af4a580881f16b2f9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?yexin=28=E5=8F=B6=E9=91=AB=29?=
Date: Fri, 25 Apr 2025 14:52:28 +0800
Subject: [PATCH] [Perf]Optimize rotary_emb implementation to use Triton
 operator for improved inference performance (#16457)

Signed-off-by: cynthieye
Co-authored-by: MagnetoWang
---
 cmake/external_projects/vllm_flash_attn.cmake |  2 +-
 .../model_executor/layers/rotary_embedding.py | 34 +++++++++++++------
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index c60a547455193..b04e4c2d06edc 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG e93779c59ba4905e56e5c39dc2c1904ada71fa21
+        GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index e6f2461eb674c..6f4d796fb0406 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
@@ -75,6 +67,24 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
+                      is_neox_style: bool) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+    """
+    if current_platform.is_cuda_alike():
+        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
+                                not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
@@ -141,14 +151,16 @@ class RotaryEmbedding(CustomOp):
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
+                                            self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
+                                          self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
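
A minimal usage sketch of the dispatch helper introduced by this patch, not part of the patch itself. It assumes a vLLM build that includes this change and that the private helpers keep the signatures shown in the diff; shapes follow the docstring (x is [num_tokens, num_heads, head_size], cos and sin are [num_tokens, head_size // 2]).

import torch

# Hypothetical smoke test; assumes vLLM with this patch is importable.
from vllm.model_executor.layers.rotary_embedding import (
    _apply_rotary_emb, _apply_rotary_emb_torch)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

num_tokens, num_heads, head_size = 4, 8, 64
x = torch.randn(num_tokens, num_heads, head_size, device=device, dtype=dtype)
cos = torch.randn(num_tokens, head_size // 2, device=device, dtype=dtype)
sin = torch.randn(num_tokens, head_size // 2, device=device, dtype=dtype)

# _apply_rotary_emb now dispatches: the vllm-flash-attn Triton kernel on
# CUDA-alike platforms, the pure-PyTorch _apply_rotary_emb_torch elsewhere.
out = _apply_rotary_emb(x, cos, sin, is_neox_style=True)
ref = _apply_rotary_emb_torch(x, cos, sin, is_neox_style=True)

print(out.shape)                            # torch.Size([4, 8, 64])
print(torch.allclose(out, ref, atol=1e-2))  # both paths should agree

Note that RotaryEmbedding.forward_native keeps calling the torch reference path directly, so the native forward stays a pure-PyTorch baseline while callers of _apply_rotary_emb pick up the Triton kernel where available.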