From 76e6a951925bf37c49f88ad155dc9fcec01a3faf Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Tue, 23 Dec 2025 21:41:09 -0500
Subject: [PATCH] [Bug] Fix `Number of dimensions of tensors must match.` for
 Deepseek V3.2 (#31160)

Signed-off-by: yewentao256
---
 vllm/model_executor/models/deepseek_v2.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 22d43a4bae18a..4899f5476f955 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -878,11 +878,14 @@ class Indexer(nn.Module):
         )
 
         q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
-        # `rotary_emb` is shape-preserving; `q_pe` is already
-        # [num_tokens, n_head, rope_dim].
+        # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
+        # so we need to reshape back to token-flattened shapes
+        q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
+        k_pe = k_pe.reshape(-1, 1, self.rope_dim)
+
         q = torch.cat([q_pe, q_nope], dim=-1)
         # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
-        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
+        k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
 
         # we only quant q here since k quant is fused with cache insertion
         q = q.view(-1, self.head_dim)
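
For reference, below is a minimal, self-contained sketch (not the vLLM code) of the failure mode this patch addresses and of the reshape/squeeze(-2) pattern it applies. The tensor names and sizes (num_tokens, n_head, rope_dim, nope_dim) are illustrative assumptions; only the reshape and squeeze(-2) calls mirror the diff above.

# Illustrative sketch of the "Number of dimensions of tensors must match" error
# and the fix. All shapes below are made up for demonstration.
import torch

num_tokens, n_head, rope_dim, nope_dim = 4, 8, 64, 128

q_nope = torch.randn(num_tokens, n_head, nope_dim)
k_nope = torch.randn(num_tokens, nope_dim)

# Suppose the compiled RoPE path hands back q_pe/k_pe with an extra leading
# dimension instead of the token-flattened [num_tokens, ...] layout.
q_pe = torch.randn(1, num_tokens, n_head, rope_dim)  # 4-D instead of 3-D
k_pe = torch.randn(1, num_tokens, 1, rope_dim)       # 4-D instead of 3-D

# Before the fix: ndim mismatch between q_pe (4-D) and q_nope (3-D).
try:
    torch.cat([q_pe, q_nope], dim=-1)
except RuntimeError as e:
    print("cat failed:", e)

# After the fix: flatten back to per-token shapes, then drop the singleton
# MQA head axis by its position from the end (-2) rather than by index 1.
q_pe = q_pe.reshape(-1, n_head, rope_dim)   # [num_tokens, n_head, rope_dim]
k_pe = k_pe.reshape(-1, 1, rope_dim)        # [num_tokens, 1, rope_dim]

q = torch.cat([q_pe, q_nope], dim=-1)               # [num_tokens, n_head, rope_dim + nope_dim]
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)   # [num_tokens, rope_dim + nope_dim]
print(q.shape, k.shape)

The reason for preferring squeeze(-2) over squeeze(1) in this pattern is that it names the singleton head axis by its offset from the end, so it keeps removing the right axis even when an extra leading dimension is present.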