From b13a44754674a0056d7c8113deb33ea858f6ef1c Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Tue, 4 Nov 2025 09:12:19 +0800
Subject: [PATCH] [Bugfix][ROCm] Fix ViT rotary embeddings for torch.compile compatibility on ROCm (#27748)

Signed-off-by: vllmellm
---
 vllm/model_executor/layers/rotary_embedding/common.py | 11 +++++++----
 vllm/model_executor/models/glm4_1v.py                 |  2 +-
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index 9e6ec9fdd523..196533b61795 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -77,7 +77,11 @@ def dispatch_rotary_emb_function(
     if current_platform.is_cuda():
         return apply_rotary_emb
 
-    if current_platform.is_rocm():
+    # If torch.compile is not enabled, use the rotary embedding function
+    # from the flash_attn package; otherwise fall back to the naive
+    # PyTorch implementation, which is faster when torch.compile is
+    # enabled.
+    if current_platform.is_rocm() and not torch.compiler.is_compiling():
         if find_spec("flash_attn") is not None:
             from flash_attn.ops.triton.rotary import apply_rotary
 
@@ -87,11 +91,10 @@ def dispatch_rotary_emb_function(
                 "flash_attn is not installed. Falling back to PyTorch "
                 "implementation for rotary embeddings."
             )
-
     if default is not None:
         return default
-    else:
-        return apply_rotary_emb_torch
+
+    return apply_rotary_emb_torch
 
 
 # yarn functions
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 3e243385fd04..121e84469c52 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -370,7 +370,7 @@ class Glm4vVisionAttention(nn.Module):
             cu_seqlens_k=cu_seqlens,
             max_seqlen_q=max_seqlen,
             max_seqlen_k=max_seqlen,
-            dropout_p=0,
+            dropout_p=0.0,
             causal=False,
         )
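
Note (not part of the patch): the sketch below illustrates the dispatch behaviour this fix relies on, under the assumption that torch.compiler.is_compiling() returns True while torch.compile is tracing, so the external flash_attn Triton kernel is skipped and a pure-PyTorch path is compiled instead. The names pick_rotary_fn and rotary_torch are hypothetical stand-ins, not vLLM's dispatch_rotary_emb_function / apply_rotary_emb_torch, and the two callables here do not share a signature; the sketch only shows the compile-time check, not the real call sites.

    from importlib.util import find_spec

    import torch


    def rotary_torch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Plain-PyTorch rotate-half rotary embedding; safe for torch.compile to trace.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


    def pick_rotary_fn():
        # Mirrors the patched condition: only reach for the flash_attn Triton
        # kernel when we are NOT inside a torch.compile trace and the package
        # is importable; otherwise return the traceable PyTorch fallback.
        if not torch.compiler.is_compiling() and find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary
            return apply_rotary
        return rotary_torch  # hypothetical stand-in for apply_rotary_emb_torch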