From cd3ea013d6346b5d3fb5826940a4baf7f2b0325e Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Sat, 27 Sep 2025 17:49:34 -0700 Subject: [PATCH] cutlass_mla: pad q_nope/q_pe to MAX_HEADS and make sliced out/lse contiguous Signed-off-by: Alexander Matveev --- vllm/v1/attention/backends/mla/cutlass_mla.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index d44e20f2cb6be..27d16635cf78b 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -167,6 +167,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): MAX_HEADS = 128 assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}" + if H < MAX_HEADS: + q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope)) + q_nope_padded[:, :H] = q_nope + q_nope = q_nope_padded + + q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe)) + q_pe_padded[:, :H] = q_pe + q_pe = q_pe_padded assert len(page_table.shape) == 2 B_block_table, block_num = page_table.shape @@ -209,8 +217,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): if H < MAX_HEADS: # Extract the subsets of the outputs - lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse - out = out[:, :H] + lse = lse[:, :H].contiguous( + ) if self.need_to_return_lse_for_decode else lse + out = out[:, :H].contiguous() return out, lse