mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 03:56:07 +08:00
[Misc] Clean up RoPE forward_native (#8076)
This commit is contained in:
parent
1afc931987
commit
4624d98dbd
@ -28,7 +28,6 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
|
|
||||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||||
@ -48,21 +47,29 @@ def _apply_rotary_emb(
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
|
is_neox_style: bool,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: [num_tokens, num_heads, head_size]
|
x: [num_tokens, num_heads, head_size]
|
||||||
cos: [num_tokens, head_size // 2]
|
cos: [num_tokens, head_size // 2]
|
||||||
sin: [num_tokens, head_size // 2]
|
sin: [num_tokens, head_size // 2]
|
||||||
|
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||||
|
positional embeddings.
|
||||||
"""
|
"""
|
||||||
orig_dtype = x.dtype
|
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||||
x = x.float()
|
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
if is_neox_style:
|
||||||
cos = cos.unsqueeze(-2)
|
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||||
sin = sin.unsqueeze(-2)
|
else:
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
o1 = x1 * cos - x2 * sin
|
o1 = x1 * cos - x2 * sin
|
||||||
o2 = x2 * cos + x1 * sin
|
o2 = x2 * cos + x1 * sin
|
||||||
return torch.cat((o1, o2), dim=-1).to(orig_dtype)
|
if is_neox_style:
|
||||||
|
return torch.cat((o1, o2), dim=-1)
|
||||||
|
else:
|
||||||
|
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(CustomOp):
|
class RotaryEmbedding(CustomOp):
|
||||||
@ -87,10 +94,9 @@ class RotaryEmbedding(CustomOp):
|
|||||||
|
|
||||||
cache = self._compute_cos_sin_cache()
|
cache = self._compute_cos_sin_cache()
|
||||||
cache = cache.to(dtype)
|
cache = cache.to(dtype)
|
||||||
|
self.cos_sin_cache: torch.Tensor
|
||||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||||
|
|
||||||
self.use_native2 = current_platform.is_tpu() and is_neox_style
|
|
||||||
|
|
||||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||||
"""Compute the inverse frequency."""
|
"""Compute the inverse frequency."""
|
||||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||||
@ -119,59 +125,7 @@ class RotaryEmbedding(CustomOp):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""A PyTorch-native implementation equivalent to forward().
|
"""A PyTorch-native implementation of forward()."""
|
||||||
|
|
||||||
This method mimics the implementation of the custom CUDA kernel
|
|
||||||
used in `forward_cuda()`.
|
|
||||||
"""
|
|
||||||
query = query.view(*query.shape[:-1], -1, self.head_size)
|
|
||||||
key = key.view(*key.shape[:-1], -1, self.head_size)
|
|
||||||
|
|
||||||
query_rot = query[..., :self.rotary_dim]
|
|
||||||
key_rot = key[..., :self.rotary_dim]
|
|
||||||
if self.rotary_dim < self.head_size:
|
|
||||||
query_pass = query[..., self.rotary_dim:]
|
|
||||||
key_pass = key[..., self.rotary_dim:]
|
|
||||||
|
|
||||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
|
|
||||||
positions.device, dtype=query.dtype)
|
|
||||||
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
|
||||||
if offsets is not None else positions]
|
|
||||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
|
||||||
if self.is_neox_style:
|
|
||||||
# NOTE(woosuk): Here we assume that the positions tensor has the
|
|
||||||
# shape [batch_size, seq_len].
|
|
||||||
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
|
|
||||||
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
|
|
||||||
else:
|
|
||||||
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
|
||||||
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
|
||||||
|
|
||||||
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
|
|
||||||
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
|
|
||||||
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
|
|
||||||
|
|
||||||
if self.rotary_dim < self.head_size:
|
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1)
|
|
||||||
else:
|
|
||||||
query = query_rot
|
|
||||||
key = key_rot
|
|
||||||
query = query.flatten(-2)
|
|
||||||
key = key.flatten(-2)
|
|
||||||
return query, key
|
|
||||||
|
|
||||||
def forward_native2(
|
|
||||||
self,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
offsets: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""Another PyTorch-native implementation of forward().
|
|
||||||
|
|
||||||
This method might perform better than `forward_native()` when compiled.
|
|
||||||
"""
|
|
||||||
if offsets is not None:
|
if offsets is not None:
|
||||||
positions = positions + offsets
|
positions = positions + offsets
|
||||||
positions = positions.flatten()
|
positions = positions.flatten()
|
||||||
@ -183,14 +137,14 @@ class RotaryEmbedding(CustomOp):
|
|||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., :self.rotary_dim]
|
query_rot = query[..., :self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim:]
|
query_pass = query[..., self.rotary_dim:]
|
||||||
query_rot = _apply_rotary_emb(query_rot, cos, sin)
|
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key_shape = key.shape
|
key_shape = key.shape
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., :self.rotary_dim]
|
key_rot = key[..., :self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim:]
|
key_pass = key[..., self.rotary_dim:]
|
||||||
key_rot = _apply_rotary_emb(key_rot, cos, sin)
|
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
@ -203,7 +157,7 @@ class RotaryEmbedding(CustomOp):
|
|||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
|
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||||
dtype=query.dtype)
|
dtype=query.dtype)
|
||||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||||
# are in-place operations that update the query and key tensors.
|
# are in-place operations that update the query and key tensors.
|
||||||
@ -240,17 +194,6 @@ class RotaryEmbedding(CustomOp):
|
|||||||
self.cos_sin_cache, self.is_neox_style)
|
self.cos_sin_cache, self.is_neox_style)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
def forward_tpu(
|
|
||||||
self,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
offsets: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
forward_fn = (self.forward_native2
|
|
||||||
if self.use_native2 else self.forward_native)
|
|
||||||
return forward_fn(positions, query, key, offsets)
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
||||||
s += f", max_position_embeddings={self.max_position_embeddings}"
|
s += f", max_position_embeddings={self.max_position_embeddings}"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user