From 1fb632fdb6b6391cd1fe3146d05cc2858c2446ab Mon Sep 17 00:00:00 2001
From: Lain
Date: Mon, 8 Dec 2025 15:02:34 -0800
Subject: [PATCH] [Perf] Improve fp8 quant in mla; replace ReduceSum with
 ReduceScatterSum (#29795)

Signed-off-by: Siyuan Fu
---
 .../device_communicators/cuda_communicator.py |  2 +-
 vllm/v1/attention/backends/mla/common.py      | 33 ++++++++++++-------
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 2e878eef908ac..cd9c267beb5b5 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
             output_shape, dtype=input_tensor.dtype, device=input_tensor.device
         )
 
-        if sizes is not None:
+        if sizes is not None and sizes.count(sizes[0]) != len(sizes):
             pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
         else:
             pynccl_comm.reduce_scatter(output, input_tensor)
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 309ddee4fc2f0..0a5257a1d87d8 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -2037,21 +2037,30 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
 
             if fp8_attention:
                 ql_nope_shape = decode_ql_nope.shape
-                decode_ql_nope, _ = ops.scaled_fp8_quant(
-                    decode_ql_nope.reshape(
-                        [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
-                    ),
-                    layer._q_scale,
-                )
-                decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
                 q_pe_shape = decode_q_pe.shape
-                decode_q_pe, _ = ops.scaled_fp8_quant(
-                    decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
+                assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
+                assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
+                decode_q_shape = (
+                    ql_nope_shape[0],
+                    ql_nope_shape[1],
+                    ql_nope_shape[2] + q_pe_shape[2],
+                )
+                # Using empty and copy since torch.cat introduces significant overhead.
+                decode_q0 = torch.empty(
+                    decode_q_shape,
+                    device=decode_ql_nope.device,
+                    dtype=decode_ql_nope.dtype,
+                )
+                decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
+                decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
+
+                decode_q, _ = ops.scaled_fp8_quant(
+                    decode_q0.view(decode_q_shape[0], -1),
                     layer._q_scale,
                 )
-                decode_q_pe = decode_q_pe.reshape(q_pe_shape)
-
-            decode_q = (decode_ql_nope, decode_q_pe)
+                decode_q = decode_q.view(decode_q_shape)
+            else:
+                decode_q = (decode_ql_nope, decode_q_pe)
             if self.dcp_world_size > 1:
                 assert not fp8_attention, "DCP not support fp8 kvcache now."
                 # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
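
Note (illustrative, not part of the patch): a minimal standalone sketch of the preallocate-and-copy pattern the mla/common.py hunk adopts, quantizing the joint (B, N, L + P) decode query once instead of quantizing the nope and rope halves separately. The shapes, the q_scale value, and the quant_fp8_stand_in helper are assumptions standing in for vllm's ops.scaled_fp8_quant and layer._q_scale.

    import torch

    def quant_fp8_stand_in(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Stand-in for ops.scaled_fp8_quant: divide by the per-tensor scale,
        # clamp to the e4m3 representable range, and cast.
        finfo = torch.finfo(torch.float8_e4m3fn)
        return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

    B, N, L, P = 4, 16, 512, 64     # example batch, heads, latent dim, rope dim
    ql_nope = torch.randn(B, N, L, dtype=torch.bfloat16)
    q_pe = torch.randn(B, N, P, dtype=torch.bfloat16)
    q_scale = torch.tensor(1.0)     # placeholder for layer._q_scale

    # One empty() allocation plus two copy_() calls instead of torch.cat,
    # then a single quant over the flattened (B, N * (L + P)) view.
    q = torch.empty(B, N, L + P, dtype=ql_nope.dtype)
    q[..., :L].copy_(ql_nope)
    q[..., L:].copy_(q_pe)
    q_fp8 = quant_fp8_stand_in(q.view(B, -1), q_scale).view(B, N, L + P)

Compared with the old code path, this replaces two reshape-and-quant round trips with one contiguous buffer and a single quant call.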
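
Note (illustrative, not part of the patch): the cuda_communicator.py hunk only changes when the variable-length reduce_scatterv path is taken; uniform splits now go through the plain fixed-size reduce_scatter. A tiny sketch of that predicate, assuming sizes is an ordinary Python list (the helper name is made up):

    def needs_reduce_scatterv(sizes: list[int] | None) -> bool:
        # Mirrors the new condition: take the "v" variant only when per-rank
        # sizes are given and are not all equal.
        return sizes is not None and sizes.count(sizes[0]) != len(sizes)

    assert needs_reduce_scatterv(None) is False          # no sizes: even split
    assert needs_reduce_scatterv([8, 8, 8, 8]) is False  # uniform sizes: even split
    assert needs_reduce_scatterv([7, 9, 8, 8]) is True   # ragged sizes: variable-length path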