[Quantization][FP8] Adding support for fp8 gemm layer input in fp8 (#14578)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Gregory Shtrasberg 2025-03-27 22:58:16 -04:00 committed by GitHub
parent e7f720ea56
commit 4d0ec37267
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 41 additions and 9 deletions

View File

@ -23,6 +23,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
@ -143,5 +144,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)

View File

@ -73,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
self.out_dtype = torch.get_default_dtype()
def create_weights(
self,
@ -161,6 +162,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias)

View File

@ -116,6 +116,21 @@ class Fp8Config(QuantizationConfig):
return Fp8KVCacheMethod(self)
return None
def get_cache_scale(self, name: str) -> Optional[str]:
    """Map a compressed-tensors k/v cache scale param name to vLLM's.

    Checks whether ``name`` follows the compressed-tensors naming scheme
    for key/value cache output scales and, if so, returns the equivalent
    parameter name that vLLM expects.

    :param name: checkpoint parameter name
    :return: the vLLM KV-cache scale name, or ``None`` if no match
    """
    if name.endswith(".output_scale"):
        # Translate "<prefix>.{k,v}_proj.output_scale" ->
        # "<prefix>.attn.{k,v}_scale"; k_proj is checked first,
        # matching the original precedence.
        for proj_tag, attn_name in ((".k_proj", ".attn.k_scale"),
                                    (".v_proj", ".attn.v_scale")):
            if proj_tag in name:
                return name.replace(proj_tag + ".output_scale", attn_name)
    return None
class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
@ -138,6 +153,7 @@ class Fp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
@ -386,6 +402,7 @@ class Fp8LinearMethod(LinearMethodBase):
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)

View File

@ -22,6 +22,7 @@ class QuarkW8A8Fp8(QuarkScheme):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
self.out_dtype = torch.get_default_dtype()
@classmethod
def get_min_capability(cls) -> int:
@ -134,5 +135,6 @@ class QuarkW8A8Fp8(QuarkScheme):
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)

View File

@ -163,6 +163,7 @@ class Fp8LinearOp:
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: Optional[torch.dtype] = None,
input_scale: Optional[torch.Tensor] = None,
input_scale_ub: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
@ -182,8 +183,13 @@ class Fp8LinearOp:
if use_per_token_if_dynamic is None:
use_per_token_if_dynamic = self.use_per_token_if_dynamic
if out_dtype is None:
out_dtype = input.dtype
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if self.cutlass_fp8_supported:
assert input.dtype != current_platform.fp8_dtype(
), "FP8 input to cutlass is not currently implemented"
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
@ -193,7 +199,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=input.dtype,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
@ -202,12 +208,15 @@ class Fp8LinearOp:
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else:
# Maybe apply padding to output, see comment in __init__
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic)
if input.dtype != current_platform.fp8_dtype():
# Maybe apply padding to output, see comment in __init__
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic)
else:
qinput, x_scale = input_2d, input_scale
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
@ -216,7 +225,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
@ -240,7 +249,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias)