diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 8184b073275c6..109e8496fc31e 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1308,7 +1308,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ) kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank].unsqueeze(1) + [..., :self.kv_lora_rank] k_pe = workspace[:toks]\ [..., self.kv_lora_rank:].unsqueeze(1) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index c98262eea1e91..0b55854de94af 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -874,7 +874,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ) kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank].unsqueeze(1) + [..., :self.kv_lora_rank] k_pe = workspace[:toks]\ [..., self.kv_lora_rank:].unsqueeze(1)