[BugFix] Add fallback path in apply_rotary_pos_emb_flashatt for non-CUDA platforms (#28447)

Signed-off-by: Lin, Fanli <fanli.lin@intel.com>
Fanli Lin 2025-11-12 11:13:21 +08:00 committed by GitHub
parent 4ccffe561f
commit b9ce9a3013

@@ -346,6 +346,13 @@ def apply_rotary_pos_emb_flashatt(
         from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
     elif current_platform.is_rocm():
         from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
+    else:
+        # For other platforms, use PyTorch fallback
+        from vllm.model_executor.layers.rotary_embedding.common import (
+            apply_rotary_emb_torch,
+        )
+
+        apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
 
     q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
     k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
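
For context, the PyTorch fallback ultimately performs a NeoX-style rotation of the query/key channels. Below is a minimal, self-contained sketch of that rotation, assuming (seq_len, num_heads, head_dim) layouts and half-width cos/sin tables; the helper name neox_rotary_torch is illustrative and is not the exact body of vllm's apply_rotary_emb_torch.

# A minimal sketch of the NeoX-style rotation the fallback delegates to.
# Assumptions: x is (seq_len, num_heads, head_dim); cos/sin are
# (seq_len, head_dim // 2). neox_rotary_torch is a hypothetical name, not
# the exact body of vllm's apply_rotary_emb_torch.
from functools import partial

import torch


def neox_rotary_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool = True,
) -> torch.Tensor:
    # Broadcast cos/sin over the head axis: (seq_len, 1, head_dim // 2).
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        # NeoX style pairs the first and second halves of the head dim.
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    # GPT-J style pairs interleaved even/odd channels instead.
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten(-2)


# Mirroring the patch: bind the style once so the fallback exposes the same
# (tensor, cos, sin) call signature as the CUDA/ROCm kernels.
apply_rotary_emb = partial(neox_rotary_torch, is_neox_style=True)

seq_len, num_heads, head_dim = 16, 4, 64
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
inv_freq = 10000.0 ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
q_embed = apply_rotary_emb(q.float(), angles.cos(), angles.sin()).type_as(q)
print(q_embed.shape, q_embed.dtype)  # torch.Size([16, 4, 64]) torch.bfloat16

Because the fallback runs in float32 and casts back via type_as, exactly as the q_embed/k_embed lines in the diff do, it stays numerically close to the fused kernels at the cost of an extra round trip through full precision.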