[CustomOp][Refactor] Extract common methods for ApplyRotaryEmb CustomOp (#31021)

Signed-off-by: shen-shanshan <467638484@qq.com>
Shanshan Shen 2025-12-19 22:16:09 +08:00 committed by GitHub
parent b5545d9d5c
commit 23a1946e3b


@@ -178,6 +178,37 @@ class ApplyRotaryEmb(CustomOp):
            output = output.to(origin_dtype)
        return output

    def _pre_process(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size, torch.dtype]:
        origin_shape = x.shape
        if len(origin_shape) == 3:
            # x: [seq_len, num_heads, head_size]
            x = x.unsqueeze(0)
        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()
        return x, cos, sin, origin_shape, origin_dtype

    def _post_process(
        self,
        output: torch.Tensor,
        origin_shape: torch.Size,
        origin_dtype: torch.dtype,
    ) -> torch.Tensor:
        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_native(
        self,
        x: torch.Tensor,
@@ -197,16 +228,7 @@ class ApplyRotaryEmb(CustomOp):
    ) -> torch.Tensor:
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()
        origin_shape = x.shape
        if len(origin_shape) == 3:
            # x: [seq_len, num_heads, head_size]
            x = x.unsqueeze(0)
        x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
        """
        Arguments of apply_rotary_emb() in vllm_flash_attn:
@@ -218,10 +240,7 @@ class ApplyRotaryEmb(CustomOp):
        interleaved = not self.is_neox_style
        output = apply_rotary_emb(x, cos, sin, interleaved)
        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        output = self._post_process(output, origin_shape, origin_dtype)
        return output

    def forward_hip(
@@ -231,16 +250,7 @@ class ApplyRotaryEmb(CustomOp):
        sin: torch.Tensor,
    ) -> torch.Tensor:
        if self.apply_rotary_emb_flash_attn is not None:
            origin_dtype = x.dtype
            if self.enable_fp32_compute:
                x = x.float()
                cos = cos.float()
                sin = sin.float()
            origin_shape = x.shape
            if len(origin_shape) == 3:
                # x: [seq_len, num_heads, head_size]
                x = x.unsqueeze(0)
            x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
            """
            Arguments of apply_rotary() in flash_attn:
@@ -254,10 +264,7 @@ class ApplyRotaryEmb(CustomOp):
                x, cos, sin, interleaved=interleaved
            ).type_as(x)
            if len(origin_shape) == 3:
                output = output.squeeze(0)
            if self.enable_fp32_compute:
                output = output.to(origin_dtype)
            output = self._post_process(output, origin_shape, origin_dtype)
        else:
            # Falling back to PyTorch native implementation.
            output = self.forward_native(x, cos, sin)
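
For reference, below is a minimal, self-contained sketch of the pattern this refactor introduces: shared pre-/post-processing helpers (shape normalization plus optional fp32 upcast) wrapped around a backend-specific rotary kernel, so each forward_* path only deals with its kernel call. The RotaryEmbSketch class, its forward() method, and the identity placeholder for the kernel are illustrative assumptions, not vLLM's actual ApplyRotaryEmb implementation.

import torch


class RotaryEmbSketch:
    """Illustrative stand-in for the extracted helper pattern (not vLLM code)."""

    def __init__(self, enable_fp32_compute: bool = False) -> None:
        self.enable_fp32_compute = enable_fp32_compute

    def _pre_process(self, x, cos, sin):
        # Normalize a 3-D [seq_len, num_heads, head_size] input to 4-D and
        # optionally upcast to fp32, remembering the original shape/dtype.
        origin_shape = x.shape
        if len(origin_shape) == 3:
            x = x.unsqueeze(0)
        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x, cos, sin = x.float(), cos.float(), sin.float()
        return x, cos, sin, origin_shape, origin_dtype

    def _post_process(self, output, origin_shape, origin_dtype):
        # Undo the normalization performed in _pre_process.
        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward(self, x, cos, sin):
        x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
        # Placeholder for a backend-specific kernel (e.g. apply_rotary_emb);
        # identity here just keeps the sketch runnable.
        output = x
        return self._post_process(output, origin_shape, origin_dtype)


# Usage: a 3-D fp16 input round-trips through the unsqueeze/squeeze and dtype casts.
op = RotaryEmbSketch(enable_fp32_compute=True)
x = torch.randn(8, 4, 64, dtype=torch.float16)  # [seq_len, num_heads, head_size]
cos = torch.randn(8, 32, dtype=torch.float16)
sin = torch.randn(8, 32, dtype=torch.float16)
out = op.forward(x, cos, sin)
assert out.shape == x.shape and out.dtype == x.dtype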