[CustomOp][Refactor] Extract common methods for ApplyRotaryEmb CustomOp (#31021)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
parent b5545d9d5c
commit 23a1946e3b
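Before the diff, a minimal self-contained sketch of the pattern this commit applies may help: the shape/dtype bookkeeping duplicated in `forward_native` and `forward_hip` moves into shared `_pre_process`/`_post_process` helpers. Only the helper bodies mirror the diff; the class name, constructor, `forward` wrapper, and identity "kernel" below are illustrative assumptions, not the actual vLLM implementation.

```python
import torch


class RotaryEmbPrePostSketch:
    """Illustrative stand-in for the extracted _pre_process/_post_process helpers."""

    def __init__(self, enable_fp32_compute: bool = True) -> None:
        self.enable_fp32_compute = enable_fp32_compute

    def _pre_process(
        self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size, torch.dtype]:
        # Remember the incoming shape; add a batch dim if x is 3-D
        # ([seq_len, num_heads, head_size]) so the kernel sees 4-D input.
        origin_shape = x.shape
        if len(origin_shape) == 3:
            x = x.unsqueeze(0)
        # Remember the incoming dtype; optionally upcast to fp32 for compute.
        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x, cos, sin = x.float(), cos.float(), sin.float()
        return x, cos, sin, origin_shape, origin_dtype

    def _post_process(
        self, output: torch.Tensor, origin_shape: torch.Size, origin_dtype: torch.dtype
    ) -> torch.Tensor:
        # Undo what _pre_process did: drop the added batch dim, restore dtype.
        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
        output = x  # a real implementation would apply the rotary-embedding kernel here
        return self._post_process(output, origin_shape, origin_dtype)
```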
@@ -178,6 +178,37 @@ class ApplyRotaryEmb(CustomOp):
             output = output.to(origin_dtype)
         return output
 
+    def _pre_process(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size, torch.dtype]:
+        origin_shape = x.shape
+        if len(origin_shape) == 3:
+            # x: [seq_len, num_heads, head_size]
+            x = x.unsqueeze(0)
+
+        origin_dtype = x.dtype
+        if self.enable_fp32_compute:
+            x = x.float()
+            cos = cos.float()
+            sin = sin.float()
+
+        return x, cos, sin, origin_shape, origin_dtype
+
+    def _post_process(
+        self,
+        output: torch.Tensor,
+        origin_shape: torch.Size,
+        origin_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        if len(origin_shape) == 3:
+            output = output.squeeze(0)
+        if self.enable_fp32_compute:
+            output = output.to(origin_dtype)
+        return output
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -197,16 +228,7 @@ class ApplyRotaryEmb(CustomOp):
     ) -> torch.Tensor:
         from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
 
-        origin_dtype = x.dtype
-        if self.enable_fp32_compute:
-            x = x.float()
-            cos = cos.float()
-            sin = sin.float()
-
-        origin_shape = x.shape
-        if len(origin_shape) == 3:
-            # x: [seq_len, num_heads, head_size]
-            x = x.unsqueeze(0)
+        x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
 
         """
         Arguments of apply_rotary_emb() in vllm_flash_attn:
@@ -218,10 +240,7 @@ class ApplyRotaryEmb(CustomOp):
         interleaved = not self.is_neox_style
         output = apply_rotary_emb(x, cos, sin, interleaved)
 
-        if len(origin_shape) == 3:
-            output = output.squeeze(0)
-        if self.enable_fp32_compute:
-            output = output.to(origin_dtype)
+        output = self._post_process(output, origin_shape, origin_dtype)
         return output
 
     def forward_hip(
@@ -231,16 +250,7 @@ class ApplyRotaryEmb(CustomOp):
         sin: torch.Tensor,
     ) -> torch.Tensor:
         if self.apply_rotary_emb_flash_attn is not None:
-            origin_dtype = x.dtype
-            if self.enable_fp32_compute:
-                x = x.float()
-                cos = cos.float()
-                sin = sin.float()
-
-            origin_shape = x.shape
-            if len(origin_shape) == 3:
-                # x: [seq_len, num_heads, head_size]
-                x = x.unsqueeze(0)
+            x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
 
             """
             Arguments of apply_rotary() in flash_attn:
@@ -254,10 +264,7 @@ class ApplyRotaryEmb(CustomOp):
                 x, cos, sin, interleaved=interleaved
             ).type_as(x)
 
-            if len(origin_shape) == 3:
-                output = output.squeeze(0)
-            if self.enable_fp32_compute:
-                output = output.to(origin_dtype)
+            output = self._post_process(output, origin_shape, origin_dtype)
         else:
             # Falling back to PyTorch native implementation.
             output = self.forward_native(x, cos, sin)
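
As a quick sanity check of the invariant the extracted helpers are expected to preserve (with an identity kernel, shape and dtype round-trip unchanged), the `RotaryEmbPrePostSketch` class from the sketch above can be exercised like this; tensor sizes are arbitrary and purely illustrative.

```python
# Requires the imports and class definition from the sketch above.
x = torch.randn(16, 8, 64, dtype=torch.float16)   # [seq_len, num_heads, head_size]
cos = torch.randn(16, 32, dtype=torch.float16)
sin = torch.randn(16, 32, dtype=torch.float16)

op = RotaryEmbPrePostSketch(enable_fp32_compute=True)
out = op.forward(x, cos, sin)

assert out.shape == x.shape   # batch dim added by _pre_process is squeezed back out
assert out.dtype == x.dtype   # fp32 compute result is cast back to the original dtype
```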