From 0825197bee8dea547f2ab25f48afd8aea0cd2578 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Thu, 23 Oct 2025 13:43:53 -0400 Subject: [PATCH] [Bugfix][ROCm][DeepSeek] Fix for forward_hip in rope for DeepSeek (#27373) Signed-off-by: Gregory Shtrasberg --- vllm/model_executor/layers/rotary_embedding/base.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 17cd39bb8cd63..711902f0cc67e 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -165,11 +165,8 @@ class RotaryEmbedding(CustomOp): self.rotary_dim, self.is_neox_style, ) - else: - # ops.rotary_embedding() is an in-place operation - # that updates the query and key tensors. - self.forward_cuda(positions, query, key) - return query, key + return query, key + return self.forward_cuda(positions, query, key) def forward_xpu( self,