mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 16:27:27 +08:00
[Performance] Remove redundant clone() calls in cutlass_mla (#24891)
This commit is contained in:
parent
73df49ef3a
commit
aae725af7c
@ -210,9 +210,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
sm_scale,
|
sm_scale,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
)
|
)
|
||||||
returned_lse = lse[:, :H].contiguous(
|
|
||||||
) if self.need_to_return_lse_for_decode else lse
|
if H < MAX_HEADS:
|
||||||
return out[:, :H].contiguous(), returned_lse
|
# Extract the subsets of the outputs
|
||||||
|
returned_lse = lse[:, :H].contiguous(
|
||||||
|
) if self.need_to_return_lse_for_decode else lse
|
||||||
|
out = out[:, :H]
|
||||||
|
|
||||||
|
return out, returned_lse
|
||||||
|
|
||||||
def _sm100_forward_decode(
|
def _sm100_forward_decode(
|
||||||
self,
|
self,
|
||||||
@ -228,11 +233,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
|
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
|
||||||
|
|
||||||
# Run MLA
|
# Run MLA
|
||||||
# Clone q_nope and q_pe to make sure strides computation is correct.
|
|
||||||
# TODO: Check if we really need it
|
|
||||||
q_nope = q_nope.clone()
|
|
||||||
q_pe = q_pe.clone()
|
|
||||||
|
|
||||||
o, lse = self._sm100_cutlass_mla_decode(
|
o, lse = self._sm100_cutlass_mla_decode(
|
||||||
q_nope,
|
q_nope,
|
||||||
q_pe,
|
q_pe,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user