mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:15:20 +08:00
[BugFix] Add fallback path in apply_rotary_pos_emb_flashattn for non-cuda platforms (#28447)
Signed-off-by: Lin, Fanli <fanli.lin@intel.com>
This commit is contained in:
parent
4ccffe561f
commit
b9ce9a3013
@ -346,6 +346,13 @@ def apply_rotary_pos_emb_flashatt(
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
elif current_platform.is_rocm():
|
||||
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
|
||||
else:
|
||||
# For other platforms, use PyTorch fallback
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
apply_rotary_emb_torch,
|
||||
)
|
||||
|
||||
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
|
||||
|
||||
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
||||
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user