[Perf] Optimize rotary_emb implementation to use Triton operator for improved inference performance (#16457)

Signed-off-by: cynthieye <yexin93@qq.com>
Co-authored-by: MagnetoWang <magnetowang@outlook.com>
commit b22980a1dc
parent 881f735827
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG e93779c59ba4905e56e5c39dc2c1904ada71fa21
+        GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
vllm/model_executor/layers/rotary_embedding.py
@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
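The hunk above cuts off at the start of the rotation branch. For reference, here is a self-contained sketch of the complete pure-PyTorch helper, with the o1/o2 rotation math filled in from the surrounding context of this file; treat the reconstruction of the unmodified middle section as an assumption:

import torch

def _apply_rotary_emb_torch(
    x: torch.Tensor,       # [num_tokens, num_heads, head_size]
    cos: torch.Tensor,     # [num_tokens, head_size // 2]
    sin: torch.Tensor,     # [num_tokens, head_size // 2]
    is_neox_style: bool,
) -> torch.Tensor:
    # Broadcast cos/sin across the head dimension.
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        # Neox style: rotate the two contiguous halves of each head.
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        # GPT-J style: rotate interleaved even/odd channels.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)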
@@ -75,6 +67,24 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
+                      is_neox_style: bool) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+    """
+    if current_platform.is_cuda_alike():
+        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
+                                not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
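The new `_apply_rotary_emb` is the only place that knows about the Triton-backed kernel: on CUDA-alike platforms it adds a batch dimension and calls vllm-flash-attn's `apply_rotary_emb` (whose `interleaved` flag is the inverse of `is_neox_style`), otherwise it takes the PyTorch fallback. A quick way to sanity-check the two paths against each other — a sketch, assuming a CUDA build of vLLM where `vllm.vllm_flash_attn` is importable and importing the private helpers is acceptable; the tolerances are illustrative:

import torch
from vllm.model_executor.layers.rotary_embedding import (
    _apply_rotary_emb, _apply_rotary_emb_torch)

num_tokens, num_heads, head_size = 16, 8, 128
x = torch.randn(num_tokens, num_heads, head_size,
                dtype=torch.float16, device="cuda")
angles = torch.rand(num_tokens, head_size // 2, device="cuda")
cos, sin = angles.cos(), angles.sin()

# On a CUDA-alike platform the first call exercises the Triton kernel;
# the second always takes the pure-PyTorch fallback.
out_fast = _apply_rotary_emb(x, cos, sin, is_neox_style=True)
out_ref = _apply_rotary_emb_torch(x, cos, sin, is_neox_style=True)
torch.testing.assert_close(out_fast, out_ref, atol=1e-3, rtol=1e-3)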
@@ -141,14 +151,16 @@ class RotaryEmbedding(CustomOp):
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
+                                            self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
+                                          self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
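With this rename, `forward_native` is pinned to the PyTorch fallback regardless of platform, which keeps it usable as a reference path (e.g. on CPU) while the dispatcher handles the fast path elsewhere. A minimal sketch of exercising it; the constructor arguments follow the `RotaryEmbedding.__init__` signature at this revision, so treat the exact argument list as an assumption on other branches:

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

head_size = 128
rope = RotaryEmbedding(head_size=head_size, rotary_dim=head_size,
                       max_position_embeddings=4096, base=10000,
                       is_neox_style=True, dtype=torch.float32)

num_tokens, num_q_heads, num_kv_heads = 4, 8, 2
positions = torch.arange(num_tokens)
query = torch.randn(num_tokens, num_q_heads * head_size)
key = torch.randn(num_tokens, num_kv_heads * head_size)

# forward_native reshapes to [num_tokens, -1, head_size], rotates the
# first rotary_dim channels of each head, and restores the input shape.
query, key = rope.forward_native(positions, query, key)
assert query.shape == (num_tokens, num_q_heads * head_size)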