[Bugfix][ROCm] Fix ViT rotary embeddings for torch.compile compatibility on ROCm (#27748)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent 7956b0c0bc
commit b13a447546
@@ -77,7 +77,11 @@ def dispatch_rotary_emb_function(
     if current_platform.is_cuda():
         return apply_rotary_emb
 
-    if current_platform.is_rocm():
+    # If torch compile is not enabled,
+    # use the rotary embedding function from the flash_attn package;
+    # otherwise the naive PyTorch embedding implementation
+    # is faster when torch compile is enabled.
+    if current_platform.is_rocm() and not torch.compiler.is_compiling():
         if find_spec("flash_attn") is not None:
             from flash_attn.ops.triton.rotary import apply_rotary
 
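The new guard hinges on torch.compiler.is_compiling(), which reports whether TorchDynamo is currently tracing the function. A minimal sketch (not part of the patch) of why the check works: the branch taken during tracing is baked into the compiled graph, so compiled code never reaches the flash_attn Triton path, while eager calls still do.

import torch

def pick_branch(x: torch.Tensor) -> torch.Tensor:
    # True only while torch.compile (TorchDynamo) is tracing this function,
    # so the compiled artifact permanently captures the "compiling" branch.
    if torch.compiler.is_compiling():
        return x + 1.0
    return x - 1.0

x = torch.ones(2)
print(pick_branch(x))                 # eager: tensor([0., 0.])
print(torch.compile(pick_branch)(x))  # compiled: tensor([2., 2.])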
@@ -87,10 +91,9 @@ def dispatch_rotary_emb_function(
                 "flash_attn is not installed. Falling back to PyTorch "
                 "implementation for rotary embeddings."
             )
 
     if default is not None:
         return default
-    else:
-        return apply_rotary_emb_torch
+    return apply_rotary_emb_torch
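Taken together, the two hunks above give the dispatcher the shape sketched below. This is a condensed, self-contained approximation, not the exact vLLM source: current_platform, apply_rotary_emb, and apply_rotary_emb_torch are stubbed here, the logger warning is omitted, and the real signature may differ.

from importlib.util import find_spec
from typing import Callable, Optional

import torch

def apply_rotary_emb(*a, **kw): ...        # stub: flash_attn CUDA kernel
def apply_rotary_emb_torch(*a, **kw): ...  # stub: native PyTorch fallback

class current_platform:  # stub for vllm.platforms.current_platform
    is_cuda = staticmethod(lambda: torch.version.cuda is not None)
    is_rocm = staticmethod(lambda: torch.version.hip is not None)

def dispatch_rotary_emb_function(default: Optional[Callable] = None) -> Callable:
    if current_platform.is_cuda():
        return apply_rotary_emb

    # On ROCm, flash_attn's Triton rotary kernel is used only in eager mode;
    # while torch.compile is tracing, the function falls through to the
    # PyTorch implementation, which the patch comment says is faster there.
    if current_platform.is_rocm() and not torch.compiler.is_compiling():
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary
            return apply_rotary

    if default is not None:
        return default
    return apply_rotary_emb_torch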
@@ -370,7 +370,7 @@ class Glm4vVisionAttention(nn.Module):
             cu_seqlens_k=cu_seqlens,
             max_seqlen_q=max_seqlen,
             max_seqlen_k=max_seqlen,
-            dropout_p=0,
+            dropout_p=0.0,
             causal=False,
         )
 