[Perf] Improve fp8 quant in mla; replace ReduceSum with ReduceScatterSum (#29795)

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Author: Lain
Date: 2025-12-08 15:02:34 -08:00 (committed by GitHub)
Parent: 6af70e11a0
Commit: 1fb632fdb6
2 changed files with 22 additions and 13 deletions
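
Note on the ReduceScatterSum change: a reduce-scatter with SUM gives each rank the same rows it would get from an all-reduce(SUM) followed by slicing out its own shard, while moving roughly 1/world_size of the data. A single-process sketch of that equivalence in plain PyTorch (no distributed backend; the ranks are simulated):

import torch

world_size = 4
rows_per_rank = 2
# One input tensor per simulated rank. all_reduce(SUM) would hand every
# rank the full `summed` tensor, of which it only needs its own rows.
inputs = [torch.randn(world_size * rows_per_rank, 8) for _ in range(world_size)]
summed = torch.stack(inputs).sum(dim=0)  # what all_reduce(SUM) yields everywhere

for rank in range(world_size):
    start = rank * rows_per_rank
    # What reduce_scatter(SUM) delivers to this rank: everyone's
    # contribution summed, restricted to this rank's rows.
    shard = sum(x[start : start + rows_per_rank] for x in inputs)
    assert torch.allclose(shard, summed[start : start + rows_per_rank])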

File 1 of 2:

@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
             output_shape, dtype=input_tensor.dtype, device=input_tensor.device
         )
-        if sizes is not None:
+        if sizes is not None and sizes.count(sizes[0]) != len(sizes):
             pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
         else:
             pynccl_comm.reduce_scatter(output, input_tensor)
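
The new guard above only takes the variable-size collective when the per-rank sizes actually differ; an all-equal sizes list is served by the cheaper fixed-size reduce_scatter. A minimal sketch of the dispatch logic, with FakeComm as a hypothetical stand-in for the pynccl communicator (it only records which path was taken):

import torch

class FakeComm:
    def __init__(self) -> None:
        self.calls: list[str] = []

    def reduce_scatter(self, output: torch.Tensor, input_: torch.Tensor) -> None:
        self.calls.append("reduce_scatter")  # uniform split: fixed-size path

    def reduce_scatterv(
        self, output: torch.Tensor, input_: torch.Tensor, sizes: list[int]
    ) -> None:
        self.calls.append("reduce_scatterv")  # ragged split: variable-size path

def dispatch(comm, output, input_tensor, sizes):
    # Same condition as the diff: fall back to reduce_scatter whenever
    # `sizes` is absent or all entries are equal.
    if sizes is not None and sizes.count(sizes[0]) != len(sizes):
        comm.reduce_scatterv(output, input_tensor, sizes=sizes)
    else:
        comm.reduce_scatter(output, input_tensor)

comm = FakeComm()
x = torch.randn(8, 4)
out = torch.empty(4, 4)
dispatch(comm, out, x, sizes=[4, 4])  # uniform -> reduce_scatter
dispatch(comm, out, x, sizes=[6, 2])  # ragged  -> reduce_scatterv
assert comm.calls == ["reduce_scatter", "reduce_scatterv"]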

File 2 of 2:

@@ -2037,21 +2037,30 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         if fp8_attention:
             ql_nope_shape = decode_ql_nope.shape
-            decode_ql_nope, _ = ops.scaled_fp8_quant(
-                decode_ql_nope.reshape(
-                    [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
-                ),
-                layer._q_scale,
-            )
-            decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
             q_pe_shape = decode_q_pe.shape
-            decode_q_pe, _ = ops.scaled_fp8_quant(
-                decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
+            assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
+            assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
+            decode_q_shape = (
+                ql_nope_shape[0],
+                ql_nope_shape[1],
+                ql_nope_shape[2] + q_pe_shape[2],
+            )
+            # Using empty and copy since torch.cat introduces significant overhead.
+            decode_q0 = torch.empty(
+                decode_q_shape,
+                device=decode_ql_nope.device,
+                dtype=decode_ql_nope.dtype,
+            )
+            decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
+            decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
+            decode_q, _ = ops.scaled_fp8_quant(
+                decode_q0.view(decode_q_shape[0], -1),
                 layer._q_scale,
             )
-            decode_q_pe = decode_q_pe.reshape(q_pe_shape)
-            decode_q = (decode_ql_nope, decode_q_pe)
+            decode_q = decode_q.view(decode_q_shape)
         else:
             decode_q = (decode_ql_nope, decode_q_pe)
         if self.dcp_world_size > 1:
             assert not fp8_attention, "DCP not support fp8 kvcache now."
             # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
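
The rewritten fp8 branch preallocates one (B, N, L + P) buffer, copies both query halves into it, and quantizes once, replacing two scaled_fp8_quant launches and avoiding torch.cat. A CPU-only sketch of why the fused layout is equivalent, with fake_quant as a hypothetical stand-in for ops.scaled_fp8_quant (both halves share layer._q_scale, so quantizing the concatenation matches concatenating the separately quantized pieces):

import torch

B, N, L, P = 4, 8, 512, 64  # batch, heads, nope dim, rope dim
ql_nope = torch.randn(B, N, L)
q_pe = torch.randn(B, N, P)
scale = torch.tensor(2.0)

def fake_quant(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    return x / s  # placeholder for the real fp8 quantization kernel

# Old path: two separate quant calls, one per piece.
old = (
    fake_quant(ql_nope.reshape(B, N * L), scale).reshape(B, N, L),
    fake_quant(q_pe.reshape(B, N * P), scale).reshape(B, N, P),
)

# New path: preallocate, copy both pieces in, quantize once.
fused = torch.empty(B, N, L + P)
fused[..., :L].copy_(ql_nope)
fused[..., L:].copy_(q_pe)
new = fake_quant(fused.view(B, -1), scale).view(B, N, L + P)

# empty+copy_ is elementwise-identical to torch.cat along the last dim,
# and the fused quant matches the concatenation of the two old results.
assert torch.equal(fused, torch.cat([ql_nope, q_pe], dim=-1))
assert torch.equal(new, torch.cat(old, dim=-1))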