diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 6670143cda250..22d43a4bae18a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -878,8 +878,11 @@ class Indexer(nn.Module): ) q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) - q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1) - k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1) + # `rotary_emb` is shape-preserving; `q_pe` is already + # [num_tokens, n_head, rope_dim]. + q = torch.cat([q_pe, q_nope], dim=-1) + # `k_pe` is [num_tokens, 1, rope_dim] (MQA). + k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) # we only quant q here since k quant is fused with cache insertion q = q.view(-1, self.head_dim)