mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 10:09:36 +08:00
[Kernel] [FP8] Improve FP8 linear layer performance (#4691)
This PR improves the FP8 performance of linear layers, which had been lacking before (#4118 (comment) and #4118 (comment)).

We noticed that CUBLASLt can find a better algorithm if the first dimension of the matrix is greater than 16. So this PR enlarges matrices appropriately during quantization. This improves FP8 performance and removes the performance regression vs. FP16, in many cases exceeding FP16 performance.

Here are benchmarks on llama3 70b (ITL numbers for 1000 input and 50 output tokens at fixed qps and at TP 4); all FP8 measurements are for dynamic quantization:

qps = 1: 24 ms (FP8, this PR), 32 ms (FP8, previous main), 26 ms (FP16)
qps = 2: 26 ms (FP8, this PR), 34 ms (FP8, previous main), 28 ms (FP16)
qps = 4: 33 ms (FP8, this PR), 44 ms (FP8, previous main), 36 ms (FP16)
qps = 6: 46 ms (FP8, this PR), 56 ms (FP8, previous main), 54 ms (FP16)
qps = 8: 85 ms (FP8, this PR), 85 ms (FP8, previous main), 138 ms (FP16)
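The core idea, shown as a minimal stand-alone sketch: pad the activation's first dimension to more than 16 rows before the GEMM, then slice the result back to the real batch size. pad_batch_dim and the plain matmul below are illustrative stand-ins, not code from this PR; in the actual change the padding happens inside scaled_fp8_quant (with torch.empty rather than zeros), as shown in the diff below.

import torch

MIN_M = 17  # CUBLASLt tends to pick better algorithms once the first dim exceeds 16

def pad_batch_dim(x: torch.Tensor, min_m: int = MIN_M) -> torch.Tensor:
    # Zero-pad the first (batch) dimension of x to at least min_m rows.
    if x.shape[0] >= min_m:
        return x
    padded = x.new_zeros((min_m, *x.shape[1:]))
    padded[:x.shape[0]] = x
    return padded

x = torch.randn(4, 1024)                # a small batch would hit the slow path
w = torch.randn(1024, 1024)
y = pad_batch_dim(x) @ w                # stand-in for the fused FP8 GEMM
y = torch.narrow(y, 0, 0, x.shape[0])   # slice back to the real batch size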
parent ebce310b74
commit 379da6dcb5
@@ -189,8 +189,34 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
+    batch_dim_padding: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+
+    This function supports both static and dynamic quantization: If you
+    provide the scale, it will use static scaling and if you omit it,
+    the scale will be determined dynamically. The function also allows
+    optional padding of the output tensor for downstream kernels that
+    will benefit from padding.
+
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        batch_dim_padding: If specified, pad the first dimension
+            of the output to at least this value.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            scaling factor.
+    """
+    if batch_dim_padding:
+        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
+        output = torch.empty(shape,
+                             device=input.device,
+                             dtype=torch.float8_e4m3fn)
+    else:
+        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
         scale = torch.zeros(1, device=input.device, dtype=torch.float32)
         vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
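For reference, a usage sketch of the new signature above; the import path and the FP8-capable CUDA device are assumptions, not part of the diff:

import torch
from vllm import _custom_ops as ops  # import path assumed; the fp8.py hunk below calls this module as `ops`

x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)

# Dynamic quantization: omit the scale and it is computed from x.
qx, x_scale = ops.scaled_fp8_quant(x)

# Static quantization with padding: pass a precomputed scale and ask for the
# first dimension of the output to be padded to at least 17 rows.
act_scale = torch.tensor([0.5], device="cuda", dtype=torch.float32)
qx_pad, _ = ops.scaled_fp8_quant(x, act_scale, batch_dim_padding=17)
assert qx_pad.shape[0] >= 17 and qx_pad.dtype == torch.float8_e4m3fn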
@@ -231,9 +231,14 @@ class Fp8LinearMethod(LinearMethodBase):
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         # If dynamic, layer.act_scale is None and x_scale computed from x.
         # If static, layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
+        qinput, x_scale = ops.scaled_fp8_quant(x,
+                                               layer.act_scale,
+                                               batch_dim_padding=17)
 
-        # Fused GEMM_DQ
+        # Fused GEMM_DQ -- note we padded the input above because
+        # torch._scaled_mm is more performant for matrices with
+        # batch dimension > 16. Note that this could change
+        # in the future.
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
@@ -243,7 +248,7 @@ class Fp8LinearMethod(LinearMethodBase):
             bias=bias,
         )
 
-        return output
+        return torch.narrow(output, 0, 0, x.shape[0])
 
 
 def all_close_1d(x: torch.Tensor) -> bool:
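The torch.narrow at the end trims the result of the padded GEMM back to the caller's batch size: output has at least 17 rows because qinput was padded, and narrow returns a view of the first x.shape[0] rows along dim 0. A tiny self-contained illustration (the shapes here are arbitrary):

import torch

padded_output = torch.randn(17, 8)   # GEMM output computed on the padded batch
real_batch = 4                       # the caller's actual x.shape[0]
trimmed = torch.narrow(padded_output, 0, 0, real_batch)  # view of the first 4 rows
assert trimmed.shape == (4, 8)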