From 4d0ec37267afaf988e32174ebc31f24268076491 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Date: Thu, 27 Mar 2025 22:58:16 -0400
Subject: [PATCH] [Quantization][FP8] Adding support for fp8 gemm layer input
 in fp8 (#14578)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Gregory Shtrasberg
Co-authored-by: Luka Govedič
---
 .../schemes/compressed_tensors_w8a8_fp8.py      |  2 ++
 .../layers/quantization/fbgemm_fp8.py           |  2 ++
 vllm/model_executor/layers/quantization/fp8.py  | 17 ++++++++++++
 .../quark/schemes/quark_w8a8_fp8.py             |  2 ++
 .../layers/quantization/utils/w8a8_utils.py     | 27 ++++++++++++-------
 5 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 27a74d677da7b..e99a452963f48 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -23,6 +23,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
 
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
+        self.out_dtype = torch.get_default_dtype()
         self.is_static_input_scheme = is_static_input_scheme
         self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
@@ -143,5 +144,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
+                                     out_dtype=self.out_dtype,
                                      input_scale=layer.input_scale,
                                      bias=bias)
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index 1cc431c5cc7be..7dddc40f3446d 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -73,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
         self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+        self.out_dtype = torch.get_default_dtype()
 
     def create_weights(
         self,
@@ -161,6 +162,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
+                                     out_dtype=self.out_dtype,
                                      input_scale=None,
                                      input_scale_ub=layer.input_scale_ub,
                                      bias=bias)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index f3907b4784b54..11bfdb4180531 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -116,6 +116,21 @@ class Fp8Config(QuantizationConfig):
             return Fp8KVCacheMethod(self)
         return None
 
+    def get_cache_scale(self, name: str) -> Optional[str]:
+        """
+        Check whether the param name matches the format for k/v cache scales
+        in compressed-tensors. If this is the case, return its equivalent
+        param name expected by vLLM
+
+        :param name: param name
+        :return: matching param name for KV cache scale in vLLM
+        """
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        return None
+
 
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
@@ -138,6 +153,7 @@ class Fp8LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
+        self.out_dtype = torch.get_default_dtype()
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
@@ -386,6 +402,7 @@ class Fp8LinearMethod(LinearMethodBase):
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
+                                     out_dtype=self.out_dtype,
                                      input_scale=layer.input_scale,
                                      bias=bias)
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
index 3e4251e46931c..c161849c8c5a2 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@@ -22,6 +22,7 @@ class QuarkW8A8Fp8(QuarkScheme):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
         self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+        self.out_dtype = torch.get_default_dtype()
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -134,5 +135,6 @@ class QuarkW8A8Fp8(QuarkScheme):
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
+                                     out_dtype=self.out_dtype,
                                      input_scale=layer.input_scale,
                                      bias=bias)
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index c2bd4bce560e7..b8e6384d7359f 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -163,6 +163,7 @@ class Fp8LinearOp:
         input: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        out_dtype: Optional[torch.dtype] = None,
         input_scale: Optional[torch.Tensor] = None,
         input_scale_ub: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
@@ -182,8 +183,13 @@
         if use_per_token_if_dynamic is None:
             use_per_token_if_dynamic = self.use_per_token_if_dynamic
 
+        if out_dtype is None:
+            out_dtype = input.dtype
+
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
         if self.cutlass_fp8_supported:
+            assert input.dtype != current_platform.fp8_dtype(
+            ), "FP8 input to cutlass is not currently implemented"
             qinput, x_scale = ops.scaled_fp8_quant(
                 input_2d,
                 input_scale,
@@ -193,7 +199,7 @@
             # Fused GEMM_DQ
             output = ops.cutlass_scaled_mm(qinput,
                                            weight,
-                                           out_dtype=input.dtype,
+                                           out_dtype=out_dtype,
                                            scale_a=x_scale,
                                            scale_b=weight_scale,
                                            bias=bias)
@@ -202,12 +208,15 @@
         # torch.scaled_mm supports per tensor weights + activations only
         # so fallback to naive if per channel or per token
         else:
-            # Maybe apply padding to output, see comment in __init__
-            qinput, x_scale = ops.scaled_fp8_quant(
-                input_2d,
-                input_scale,
-                num_token_padding=self.output_padding,
-                use_per_token_if_dynamic=use_per_token_if_dynamic)
+            if input.dtype != current_platform.fp8_dtype():
+                # Maybe apply padding to output, see comment in __init__
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=self.output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic)
+            else:
+                qinput, x_scale = input_2d, input_scale
 
             per_tensor_weights = (weight_scale.numel() == 1)
             per_tensor_activations = (x_scale.numel() == 1)
@@ -216,7 +225,7 @@
                 # Fused GEMM_DQ
                 output = torch._scaled_mm(qinput,
                                           weight,
-                                          out_dtype=input.dtype,
+                                          out_dtype=out_dtype,
                                           scale_a=x_scale,
                                           scale_b=weight_scale,
                                           bias=bias)
@@ -240,7 +249,7 @@
                 # Fused GEMM_DQ Rowwise GEMM
                 output = torch._scaled_mm(qinput,
                                           weight,
-                                          out_dtype=input.dtype,
+                                          out_dtype=out_dtype,
                                           scale_a=x_scale,
                                           scale_b=weight_scale.t(),
                                           bias=bias)
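
Notes:

The get_cache_scale() helper added to Fp8Config above is pure string
rewriting: it recognizes the compressed-tensors checkpoint names for k/v
cache scales and returns the parameter name vLLM's attention layers expect.
A minimal standalone sketch of the mapping follows; the
"model.layers.0.self_attn" prefix is only an illustrative checkpoint name,
not something taken from this patch:

    from typing import Optional

    def get_cache_scale(name: str) -> Optional[str]:
        # Same rewrite rules as the Fp8Config.get_cache_scale() method above.
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        return None

    assert (get_cache_scale("model.layers.0.self_attn.k_proj.output_scale")
            == "model.layers.0.self_attn.attn.k_scale")
    assert (get_cache_scale("model.layers.0.self_attn.v_proj.output_scale")
            == "model.layers.0.self_attn.attn.v_scale")
    assert get_cache_scale("model.layers.0.mlp.down_proj.weight") is None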
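
The behavioral core of the patch is in Fp8LinearOp.apply(): the GEMM output
dtype used to be inferred as input.dtype, which stops working once the
incoming activations may already be fp8. Each linear method therefore
records torch.get_default_dtype() at construction time and threads it down
as out_dtype, and apply() now skips quantization when handed fp8 data,
falling back to input.dtype only when no out_dtype is supplied. A rough
CPU-only sketch of that control flow, under two stated assumptions:
FP8_DTYPE stands in for current_platform.fp8_dtype() (ROCm builds use a
different fp8 format), and naive_scaled_fp8_quant is a toy per-tensor
substitute for ops.scaled_fp8_quant:

    from typing import Optional, Tuple

    import torch

    FP8_DTYPE = torch.float8_e4m3fn  # assumption: CUDA-style fp8

    def naive_scaled_fp8_quant(
            x: torch.Tensor, scale: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Toy per-tensor stand-in for ops.scaled_fp8_quant.
        if scale is None:  # dynamic quantization: derive the scale
            scale = x.abs().max().float() / torch.finfo(FP8_DTYPE).max
        return (x.float() / scale).to(FP8_DTYPE), scale

    def prepare_gemm_input(x: torch.Tensor,
                           x_scale: Optional[torch.Tensor] = None,
                           out_dtype: Optional[torch.dtype] = None):
        # Mirrors the patched flow: default the output dtype first, then
        # quantize only when the activations are not already fp8.
        if out_dtype is None:
            out_dtype = x.dtype
        if x.dtype != FP8_DTYPE:
            qinput, scale = naive_scaled_fp8_quant(x, x_scale)
        else:
            qinput, scale = x, x_scale
        return qinput, scale, out_dtype

    # bf16 activations: quantized, and the GEMM keeps producing bf16.
    x = torch.randn(4, 8, dtype=torch.bfloat16)
    q, s, d = prepare_gemm_input(x)
    assert q.dtype == FP8_DTYPE and d == torch.bfloat16

    # fp8 activations: passed through untouched; the caller must pin
    # out_dtype explicitly, since q.dtype would be the wrong answer.
    q2, s2, d2 = prepare_gemm_input(q, s, out_dtype=torch.bfloat16)
    assert q2 is q and d2 == torch.bfloat16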
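
Two design points follow from the diff itself: the CUTLASS path still
asserts that it never receives pre-quantized fp8 activations (only the
torch._scaled_mm fallback implements the passthrough), and the schemes
capture torch.get_default_dtype() in __init__ rather than at call time, so
a layer keeps emitting the model's activation dtype even when its input is
an fp8 tensor.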