Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-25 09:16:32 +08:00)
[Bug] Fix `Number of dimensions of tensors must match.` for Deepseek V3.2 (#31160)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 8b59753cdb
commit 76e6a95192
@@ -878,11 +878,14 @@ class Indexer(nn.Module):
         )

         q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
-        # `rotary_emb` is shape-preserving; `q_pe` is already
-        # [num_tokens, n_head, rope_dim].
+        # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
+        # so we need to reshape back to token-flattened shapes
+        q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
+        k_pe = k_pe.reshape(-1, 1, self.rope_dim)

         q = torch.cat([q_pe, q_nope], dim=-1)
         # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
-        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
+        k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)

         # we only quant q here since k quant is fused with cache insertion
         q = q.view(-1, self.head_dim)
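For context, here is a minimal standalone sketch of the failure mode this commit addresses (illustrative only, not vLLM code; the tensor sizes and names below are invented): if RoPE hands back `q_pe`/`k_pe` with an extra leading dimension, the later `torch.cat` against the 3-D `q_nope`/`k_nope` fails with a dimension-mismatch error, whereas reshaping back to token-flattened shapes keeps the concatenation well-formed.

import torch

# Illustrative sizes only; the real Deepseek V3.2 indexer dimensions differ.
num_tokens, n_head, rope_dim, nope_dim = 4, 8, 64, 64

q_nope = torch.randn(num_tokens, n_head, nope_dim)
k_nope = torch.randn(num_tokens, nope_dim)

# Pretend the RoPE op returned tensors with an extra leading dimension
# (e.g. introduced under compilation): 4-D instead of the expected 3-D.
q_pe = torch.randn(1, num_tokens, n_head, rope_dim)
k_pe = torch.randn(1, num_tokens, 1, rope_dim)

try:
    torch.cat([q_pe, q_nope], dim=-1)  # 4-D vs 3-D
except RuntimeError as err:
    print(err)  # dimension-mismatch error; exact wording depends on the torch version

# The fix: reshape back to token-flattened shapes before concatenating.
q_pe = q_pe.reshape(-1, n_head, rope_dim)  # [num_tokens, n_head, rope_dim]
k_pe = k_pe.reshape(-1, 1, rope_dim)       # [num_tokens, 1, rope_dim]

q = torch.cat([q_pe, q_nope], dim=-1)              # [num_tokens, n_head, rope_dim + nope_dim]
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)  # [num_tokens, rope_dim + nope_dim]
print(q.shape, k.shape)

The switch from `squeeze(1)` to `squeeze(-2)` plausibly serves the same purpose: it removes the single MQA head axis by counting from the end of the shape, so it stays correct even if an extra leading dimension is present.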