diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 6017445402ec..78af8d28f889 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -210,9 +210,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): sm_scale, num_kv_splits, ) - returned_lse = lse[:, :H].contiguous( - ) if self.need_to_return_lse_for_decode else lse - return out[:, :H].contiguous(), returned_lse + + if H < MAX_HEADS: + # Extract the subsets of the outputs + returned_lse = lse[:, :H].contiguous( + ) if self.need_to_return_lse_for_decode else lse + out = out[:, :H] + + return out, returned_lse def _sm100_forward_decode( self, @@ -228,11 +233,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): self._workspace.ensure_size(attn_metadata, self._num_kv_splits) # Run MLA - # Clone q_nope and q_pe to make sure strides computation is correct. - # TODO: Check if we really need it - q_nope = q_nope.clone() - q_pe = q_pe.clone() - o, lse = self._sm100_cutlass_mla_decode( q_nope, q_pe,