Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Perf] Improve fp8 quant in mla; replace ReduceSum with ReduceScatterSum (#29795)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
commit 1fb632fdb6
parent 6af70e11a0
@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
             output_shape, dtype=input_tensor.dtype, device=input_tensor.device
         )
 
-        if sizes is not None:
+        if sizes is not None and sizes.count(sizes[0]) != len(sizes):
             pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
         else:
             pynccl_comm.reduce_scatter(output, input_tensor)
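The updated condition above routes equal-per-rank splits to NCCL's plain reduce_scatter (the fused ReduceScatterSum path) and only falls back to the variable-size reduce_scatterv when the splits are actually ragged. A minimal sketch of that dispatch decision in plain Python (not the vLLM code; the helper name is hypothetical):

def pick_reduce_scatter_op(sizes):
    # Mirrors the new condition: reduce_scatterv only when the per-rank
    # sizes differ; otherwise take the cheaper uniform reduce_scatter.
    if sizes is not None and sizes.count(sizes[0]) != len(sizes):
        return "reduce_scatterv"
    return "reduce_scatter"

assert pick_reduce_scatter_op(None) == "reduce_scatter"
assert pick_reduce_scatter_op([4, 4, 4, 4]) == "reduce_scatter"  # new fast path
assert pick_reduce_scatter_op([4, 4, 4, 5]) == "reduce_scatterv"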
@@ -2037,21 +2037,30 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         if fp8_attention:
             ql_nope_shape = decode_ql_nope.shape
-            decode_ql_nope, _ = ops.scaled_fp8_quant(
-                decode_ql_nope.reshape(
-                    [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
-                ),
-                layer._q_scale,
-            )
-            decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
             q_pe_shape = decode_q_pe.shape
-            decode_q_pe, _ = ops.scaled_fp8_quant(
-                decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
-                layer._q_scale,
-            )
-            decode_q_pe = decode_q_pe.reshape(q_pe_shape)
-
-        decode_q = (decode_ql_nope, decode_q_pe)
+            assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
+            assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
+            decode_q_shape = (
+                ql_nope_shape[0],
+                ql_nope_shape[1],
+                ql_nope_shape[2] + q_pe_shape[2],
+            )
+            # Using empty and copy since torch.cat introduces significant overhead.
+            decode_q0 = torch.empty(
+                decode_q_shape,
+                device=decode_ql_nope.device,
+                dtype=decode_ql_nope.dtype,
+            )
+            decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
+            decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
+
+            decode_q, _ = ops.scaled_fp8_quant(
+                decode_q0.view(decode_q_shape[0], -1),
+                layer._q_scale,
+            )
+            decode_q = decode_q.view(decode_q_shape)
+        else:
+            decode_q = (decode_ql_nope, decode_q_pe)
 
         if self.dcp_world_size > 1:
             assert not fp8_attention, "DCP not support fp8 kvcache now."
             # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
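This hunk replaces two separate ops.scaled_fp8_quant calls (one per query half) with a single quantization over a preallocated concatenation buffer; the in-diff comment notes that torch.cat itself added measurable overhead on this hot path. A minimal sketch of the empty-plus-copy_ concatenation, with illustrative shapes (B batch, N heads, L no-PE dim, P RoPE dim); this is not the vLLM code:

import torch

B, N, L, P = 8, 16, 512, 64
decode_ql_nope = torch.randn(B, N, L)
decode_q_pe = torch.randn(B, N, P)

# Preallocate the (B, N, L + P) buffer and copy each half in,
# avoiding the extra allocation/bookkeeping done by torch.cat.
decode_q0 = torch.empty(
    B, N, L + P, dtype=decode_ql_nope.dtype, device=decode_ql_nope.device
)
decode_q0[..., :L].copy_(decode_ql_nope)
decode_q0[..., L:].copy_(decode_q_pe)

# Same result as torch.cat; the real kernel then runs one fp8
# quantization over decode_q0 instead of one per half.
assert torch.equal(decode_q0, torch.cat([decode_ql_nope, decode_q_pe], dim=-1))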