mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 16:04:27 +08:00
[CustomOp][Refactor] Extract common methods for ApplyRotaryEmb CustomOp (#31021)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
parent
b5545d9d5c
commit
23a1946e3b
@ -178,6 +178,37 @@ class ApplyRotaryEmb(CustomOp):
|
|||||||
output = output.to(origin_dtype)
|
output = output.to(origin_dtype)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _pre_process(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size, torch.dtype]:
|
||||||
|
origin_shape = x.shape
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
# x: [seq_len, num_heads, head_size]
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
|
||||||
|
origin_dtype = x.dtype
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
x = x.float()
|
||||||
|
cos = cos.float()
|
||||||
|
sin = sin.float()
|
||||||
|
|
||||||
|
return x, cos, sin, origin_shape, origin_dtype
|
||||||
|
|
||||||
|
def _post_process(
|
||||||
|
self,
|
||||||
|
output: torch.Tensor,
|
||||||
|
origin_shape: torch.Size,
|
||||||
|
origin_dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
output = output.squeeze(0)
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
output = output.to(origin_dtype)
|
||||||
|
return output
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -197,16 +228,7 @@ class ApplyRotaryEmb(CustomOp):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||||
|
|
||||||
origin_dtype = x.dtype
|
x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
|
||||||
if self.enable_fp32_compute:
|
|
||||||
x = x.float()
|
|
||||||
cos = cos.float()
|
|
||||||
sin = sin.float()
|
|
||||||
|
|
||||||
origin_shape = x.shape
|
|
||||||
if len(origin_shape) == 3:
|
|
||||||
# x: [seq_len, num_heads, head_size]
|
|
||||||
x = x.unsqueeze(0)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Arguments of apply_rotary_emb() in vllm_flash_attn:
|
Arguments of apply_rotary_emb() in vllm_flash_attn:
|
||||||
@ -218,10 +240,7 @@ class ApplyRotaryEmb(CustomOp):
|
|||||||
interleaved = not self.is_neox_style
|
interleaved = not self.is_neox_style
|
||||||
output = apply_rotary_emb(x, cos, sin, interleaved)
|
output = apply_rotary_emb(x, cos, sin, interleaved)
|
||||||
|
|
||||||
if len(origin_shape) == 3:
|
output = self._post_process(output, origin_shape, origin_dtype)
|
||||||
output = output.squeeze(0)
|
|
||||||
if self.enable_fp32_compute:
|
|
||||||
output = output.to(origin_dtype)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward_hip(
|
def forward_hip(
|
||||||
@ -231,16 +250,7 @@ class ApplyRotaryEmb(CustomOp):
|
|||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.apply_rotary_emb_flash_attn is not None:
|
if self.apply_rotary_emb_flash_attn is not None:
|
||||||
origin_dtype = x.dtype
|
x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
|
||||||
if self.enable_fp32_compute:
|
|
||||||
x = x.float()
|
|
||||||
cos = cos.float()
|
|
||||||
sin = sin.float()
|
|
||||||
|
|
||||||
origin_shape = x.shape
|
|
||||||
if len(origin_shape) == 3:
|
|
||||||
# x: [seq_len, num_heads, head_size]
|
|
||||||
x = x.unsqueeze(0)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Arguments of apply_rotary() in flash_attn:
|
Arguments of apply_rotary() in flash_attn:
|
||||||
@ -254,10 +264,7 @@ class ApplyRotaryEmb(CustomOp):
|
|||||||
x, cos, sin, interleaved=interleaved
|
x, cos, sin, interleaved=interleaved
|
||||||
).type_as(x)
|
).type_as(x)
|
||||||
|
|
||||||
if len(origin_shape) == 3:
|
output = self._post_process(output, origin_shape, origin_dtype)
|
||||||
output = output.squeeze(0)
|
|
||||||
if self.enable_fp32_compute:
|
|
||||||
output = output.to(origin_dtype)
|
|
||||||
else:
|
else:
|
||||||
# Falling back to PyTorch native implementation.
|
# Falling back to PyTorch native implementation.
|
||||||
output = self.forward_native(x, cos, sin)
|
output = self.forward_native(x, cos, sin)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user