Make the _apply_rotary_emb compatible with dynamo (#17435)

This commit is contained in:
Lu Fang 2025-04-30 00:52:48 -07:00 committed by GitHub
parent 54072f315f
commit ece5a8b0b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -32,6 +32,9 @@ from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
@@ -78,7 +81,6 @@ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
positional embeddings.
"""
if current_platform.is_cuda_alike():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0)
else: