mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 17:45:19 +08:00
[Performance] Remove redundant clone() calls in cutlass_mla (#24891)
This commit is contained in:
parent
73df49ef3a
commit
aae725af7c
@@ -210,9 +210,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
sm_scale,
|
||||
num_kv_splits,
|
||||
)
|
||||
returned_lse = lse[:, :H].contiguous(
|
||||
) if self.need_to_return_lse_for_decode else lse
|
||||
return out[:, :H].contiguous(), returned_lse
|
||||
|
||||
if H < MAX_HEADS:
|
||||
# Extract the subsets of the outputs
|
||||
returned_lse = lse[:, :H].contiguous(
|
||||
) if self.need_to_return_lse_for_decode else lse
|
||||
out = out[:, :H]
|
||||
|
||||
return out, returned_lse
|
||||
|
||||
def _sm100_forward_decode(
|
||||
self,
|
||||
@@ -228,11 +233,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
|
||||
|
||||
# Run MLA
|
||||
# Clone q_nope and q_pe to make sure strides computation is correct.
|
||||
# TODO: Check if we really need it
|
||||
q_nope = q_nope.clone()
|
||||
q_pe = q_pe.clone()
|
||||
|
||||
o, lse = self._sm100_cutlass_mla_decode(
|
||||
q_nope,
|
||||
q_pe,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user