From 8d75fe48ca5f46b7af0f5201d8500b9604eed769 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Fri, 7 Jun 2024 04:42:35 -0400
Subject: [PATCH] [Kernel] Switch fp8 layers to use the CUTLASS kernels (#5183)

Switch from torch._scaled_mm to vLLM's CUTLASS FP8 kernels when supported,
as we are seeing a 5-15% improvement in e2e performance on
neuralmagic/Meta-Llama-3-8B-Instruct-FP8.

See https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/
for some quick e2e benchmarks and #5144 for comparisons across different
GEMM sizes.
---
 vllm/_custom_ops.py                           |  4 +-
 .../model_executor/layers/quantization/fp8.py | 64 ++++++++++++++-----
 2 files changed, 51 insertions(+), 17 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 462ba8a753105..cae6822166b66 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -179,7 +179,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 # cutlass
 def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
-                         a_scales: torch.Tensor, b_scales: torch.Tensor,
+                         scale_a: torch.Tensor, scale_b: torch.Tensor,
                          out_dtype: Type[torch.dtype]) -> torch.Tensor:
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
@@ -188,7 +188,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
     n = b.shape[1]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
 
-    vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales)
+    vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
 
     return out
 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index bf3a59e3d709b..136a64623d7fb 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -17,6 +17,24 @@
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
 
+
+def cutlass_fp8_supported() -> bool:
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    cuda_major, cuda_minor = torch.version.cuda.split(".")[:2]
+    version = int(cuda_major) * 10 + int(cuda_minor)
+
+    # CUTLASS FP8 kernels need at least
+    #   CUDA 12.0 on SM90 systems (Hopper)
+    #   CUDA 12.4 on SM89 systems (Lovelace)
+    gpu_is_supported = False
+    if capability >= 90:
+        gpu_is_supported = version >= 120
+    elif capability >= 89:
+        gpu_is_supported = version >= 124
+
+    return gpu_is_supported
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
@@ -92,6 +110,7 @@ class Fp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
 
     def _create_scale_param(
         self,
@@ -233,25 +252,40 @@ class Fp8LinearMethod(LinearMethodBase):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         #   If dynamic, layer.act_scale is None and x_scale computed from x.
         #   If static, layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = ops.scaled_fp8_quant(x,
-                                               layer.act_scale,
-                                               batch_dim_padding=17)
-
-        # Fused GEMM_DQ -- note we padded the input above because
-        # torch._scaled_mm is more performant for matrices with
-        # batch dimension > 16. Note that this could change
-        # in the future.
-        output, _ = torch._scaled_mm(
-            qinput,
-            layer.weight,
-            out_dtype=x.dtype,
-            scale_a=x_scale,
-            scale_b=layer.weight_scale,
-            bias=bias,
-        )
+        if bias is None and self.cutlass_fp8_supported:
+            qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
+
+            # Fused GEMM_DQ
+            output = ops.cutlass_scaled_mm_dq(
+                qinput,
+                layer.weight,
+                out_dtype=x.dtype,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+            )
+
+        else:
+            qinput, x_scale = ops.scaled_fp8_quant(x,
+                                                   layer.act_scale,
+                                                   batch_dim_padding=17)
+
+            # Fused GEMM_DQ -- note we padded the input above because
+            # torch._scaled_mm is more performant for matrices with
+            # batch dimension > 16. Note that this could change
+            # in the future.
+            output, _ = torch._scaled_mm(
+                qinput,
+                layer.weight,
+                out_dtype=x.dtype,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+                bias=bias,
+            )
 
         return torch.narrow(output, 0, 0, x.shape[0])
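
Reviewer note: the snippet below is an illustrative, standalone sketch of the two
GEMM paths this patch dispatches between, so they can be compared outside of
Fp8LinearMethod. It is not part of the patch. It assumes an SM89/SM90 GPU, a CUDA
toolkit new enough for the CUTLASS FP8 kernels, and a vLLM build that ships them;
the helper name, shapes, and scales are made up for the example.

# Illustrative sketch only -- not part of the patch.
import torch
from vllm import _custom_ops as ops

def compare_fp8_gemm_paths(m: int = 32, k: int = 4096, n: int = 4096):
    x = torch.randn(m, k, device="cuda", dtype=torch.float16)
    # Weight laid out the way Fp8LinearMethod stores it: an fp8 tensor
    # transposed to shape (k, n), i.e. column-major for the GEMM.
    weight = torch.randn(n, k, device="cuda", dtype=torch.float16)
    weight_fp8 = weight.to(torch.float8_e4m3fn).t()
    weight_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    # Dynamic activation quantization: the scale is computed from x.
    qinput, x_scale = ops.scaled_fp8_quant(x)

    # New path: CUTLASS fused GEMM_DQ (taken when bias is None and supported).
    out_cutlass = ops.cutlass_scaled_mm_dq(
        qinput,
        weight_fp8,
        out_dtype=x.dtype,
        scale_a=x_scale,
        scale_b=weight_scale,
    )

    # Fallback path: torch._scaled_mm with batch-dim padding, as before.
    qinput_padded, x_scale_padded = ops.scaled_fp8_quant(x, batch_dim_padding=17)
    out_torch, _ = torch._scaled_mm(
        qinput_padded,
        weight_fp8,
        out_dtype=x.dtype,
        scale_a=x_scale_padded,
        scale_b=weight_scale,
    )
    out_torch = torch.narrow(out_torch, 0, 0, m)

    return out_cutlass, out_torch

Comparing out_cutlass and out_torch (for example with torch.testing.assert_close
at a loose tolerance) is a quick sanity check that the two paths agree numerically
before looking at the end-to-end throughput numbers linked above.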