diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 21be17a750df4..ae534f3207b51 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -206,12 +206,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): ) if H < MAX_HEADS: - # Extract the subsets of the outputs - returned_lse = lse[:, :H].contiguous( - ) if self.need_to_return_lse_for_decode else lse out = out[:, :H] + if self.need_to_return_lse_for_decode: + lse = lse[:, :H].contiguous() - return out, returned_lse + return out, lse def _forward_decode( self,