[Performance] Remove redundant clone() calls in cutlass_mla (#24891)

This commit is contained in:
Alexander Matveev 2025-09-15 16:21:53 -04:00 committed by GitHub
parent 73df49ef3a
commit aae725af7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -210,9 +210,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
sm_scale,
num_kv_splits,
)
returned_lse = lse[:, :H].contiguous(
) if self.need_to_return_lse_for_decode else lse
return out[:, :H].contiguous(), returned_lse
if H < MAX_HEADS:
# Extract the subsets of the outputs
returned_lse = lse[:, :H].contiguous(
) if self.need_to_return_lse_for_decode else lse
out = out[:, :H]
return out, returned_lse
def _sm100_forward_decode(
self,
@ -228,11 +233,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
# Run MLA
# Clone q_nope and q_pe to make sure strides computation is correct.
# TODO: Check if we really need it
q_nope = q_nope.clone()
q_pe = q_pe.clone()
o, lse = self._sm100_cutlass_mla_decode(
q_nope,
q_pe,