From 4cf9429897c1a6c720a0f099a7a46e9d51af9342 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Fri, 19 Dec 2025 18:31:31 -0500 Subject: [PATCH] [Bug] Fix `error 'Dynamo failed to run FX node with fake tensors` for Deepseek V3.2 (#31046) Signed-off-by: yewentao256 --- vllm/model_executor/models/deepseek_v2.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 6670143cda250..22d43a4bae18a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -878,8 +878,11 @@ class Indexer(nn.Module): ) q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) - q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1) - k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1) + # `rotary_emb` is shape-preserving; `q_pe` is already + # [num_tokens, n_head, rope_dim]. + q = torch.cat([q_pe, q_nope], dim=-1) + # `k_pe` is [num_tokens, 1, rope_dim] (MQA). + k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) # we only quant q here since k quant is fused with cache insertion q = q.view(-1, self.head_dim)