[Perf] Optimize rotary_emb implementation to use Triton operator for improved inference performance (#16457)
Signed-off-by: cynthieye <yexin93@qq.com>
Co-authored-by: MagnetoWang <magnetowang@outlook.com>
parent 881f735827
commit b22980a1dc
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG e93779c59ba4905e56e5c39dc2c1904ada71fa21
+        GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)


-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
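The hunk above is truncated mid-function by the diff context. For reference, here is a minimal, self-contained sketch of the eager rotation that _apply_rotary_emb_torch performs. The Neox-vs-GPT-J split of the channels is not visible in this hunk, so those details are assumptions based on the visible head and tail of the body, and the function name below is illustrative only.

import torch


def apply_rotary_emb_torch_sketch(x: torch.Tensor, cos: torch.Tensor,
                                  sin: torch.Tensor,
                                  is_neox_style: bool) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_size]; cos/sin: [num_tokens, head_size // 2]
    cos = cos.unsqueeze(-2).to(x.dtype)  # broadcast over the head dimension
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        # Neox style rotates the two contiguous halves of each head
        # (assumption, not shown in the hunk above).
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        # GPT-J style rotates interleaved even/odd channels (assumption).
        x1, x2 = x[..., ::2], x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    # Matches the tail visible as context in the next hunk.
    return torch.stack((o1, o2), dim=-1).flatten(-2)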
@@ -75,6 +67,24 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)


+def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
+                      is_neox_style: bool) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+    """
+    if current_platform.is_cuda_alike():
+        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
+                                not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
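A hedged usage sketch (not part of the commit): on a CUDA-alike device the new _apply_rotary_emb dispatches to the Triton-backed apply_rotary_emb shipped with vllm-flash-attn, while _apply_rotary_emb_torch stays available as the eager fallback. The parity check below assumes the patched module is vllm.model_executor.layers.rotary_embedding (the file name is not shown in this extract), that vLLM is installed with GPU support, and uses loose fp16 tolerances.

import torch
from vllm.model_executor.layers.rotary_embedding import (
    _apply_rotary_emb, _apply_rotary_emb_torch)

num_tokens, num_heads, head_size = 16, 8, 128
angles = torch.rand(num_tokens, head_size // 2, device="cuda") * 6.283
x = torch.randn(num_tokens, num_heads, head_size,
                device="cuda", dtype=torch.float16)
cos, sin = angles.cos().half(), angles.sin().half()

out_triton = _apply_rotary_emb(x, cos, sin, True)        # flash-attn Triton path
out_eager = _apply_rotary_emb_torch(x, cos, sin, True)   # pure-torch fallback
torch.testing.assert_close(out_triton, out_eager, atol=1e-2, rtol=1e-2)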
@@ -141,14 +151,16 @@ class RotaryEmbedding(CustomOp):
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
+                                            self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
+                                          self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

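forward_native keeps the renamed pure-torch helper so the reference path never pulls in the CUDA/Triton kernel. A hedged sketch of exercising it on CPU; the RotaryEmbedding constructor and forward_native signatures below are assumptions, since neither appears in this diff.

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

head_size, num_heads, num_tokens = 64, 8, 4
# Assumed constructor signature; not shown in this diff.
rope = RotaryEmbedding(head_size=head_size, rotary_dim=head_size,
                       max_position_embeddings=2048, base=10000,
                       is_neox_style=True, dtype=torch.float32)
positions = torch.arange(num_tokens)
query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_heads * head_size)
# After this change, forward_native calls _apply_rotary_emb_torch internally,
# so this path stays pure PyTorch.
q_out, k_out = rope.forward_native(positions, query, key)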