mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 02:27:03 +08:00
[Quantization][FP8] Adding support for fp8 gemm layer input in fp8 (#14578)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
e7f720ea56
commit
4d0ec37267
@ -23,6 +23,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
|
||||
|
||||
@ -143,5 +144,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
|
||||
@ -73,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: FBGEMMFp8Config):
    """Set up the linear method from an FBGEMM FP8 quantization config.

    :param quant_config: quantization config stored on the instance as
        ``quant_config``.
    """
    self.quant_config = quant_config

    # FP8 GEMM helper, constructed with use_per_token_if_dynamic=True.
    linear_op = Fp8LinearOp(use_per_token_if_dynamic=True)
    self.fp8_linear = linear_op

    # Capture torch's default dtype once at construction time; it is kept
    # as the output dtype for this method.
    default_dtype = torch.get_default_dtype()
    self.out_dtype = default_dtype
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -161,6 +162,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=None,
|
||||
input_scale_ub=layer.input_scale_ub,
|
||||
bias=bias)
|
||||
|
||||
@ -116,6 +116,21 @@ class Fp8Config(QuantizationConfig):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
    """
    Check whether the param name is a k/v projection output scale. If this
    is the case, return the equivalent KV cache scale param name expected
    by vLLM.

    Only names that end exactly in ``.k_proj.output_scale`` or
    ``.v_proj.output_scale`` are remapped, and only that trailing suffix
    is rewritten. Any other name yields ``None`` rather than being handed
    back unchanged, so callers never mistake an unrelated parameter for a
    KV cache scale.

    :param name: param name
    :return: matching param name for KV cache scale in vLLM, or ``None``
        if ``name`` is not a k/v projection output scale
    """
    # (checkpoint suffix, vLLM attention-scale suffix)
    suffix_map = (
        (".k_proj.output_scale", ".attn.k_scale"),
        (".v_proj.output_scale", ".attn.v_scale"),
    )
    for old_suffix, new_suffix in suffix_map:
        if name.endswith(old_suffix):
            # Slice instead of str.replace so that an earlier occurrence
            # of the same text inside the path is left untouched.
            return name[:-len(old_suffix)] + new_suffix
    return None
|
||||
|
||||
|
||||
class Fp8LinearMethod(LinearMethodBase):
|
||||
"""Linear method for FP8.
|
||||
@ -138,6 +153,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
@ -386,6 +402,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
self.qscheme = qscheme
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -134,5 +135,6 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
|
||||
@ -163,6 +163,7 @@ class Fp8LinearOp:
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
input_scale_ub: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
@ -182,8 +183,13 @@ class Fp8LinearOp:
|
||||
if use_per_token_if_dynamic is None:
|
||||
use_per_token_if_dynamic = self.use_per_token_if_dynamic
|
||||
|
||||
if out_dtype is None:
|
||||
out_dtype = input.dtype
|
||||
|
||||
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
|
||||
if self.cutlass_fp8_supported:
|
||||
assert input.dtype != current_platform.fp8_dtype(
|
||||
), "FP8 input to cutlass is not currently implemented"
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
input_2d,
|
||||
input_scale,
|
||||
@ -193,7 +199,7 @@ class Fp8LinearOp:
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias)
|
||||
@ -202,12 +208,15 @@ class Fp8LinearOp:
|
||||
# torch.scaled_mm supports per tensor weights + activations only
|
||||
# so fallback to naive if per channel or per token
|
||||
else:
|
||||
# Maybe apply padding to output, see comment in __init__
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
input_2d,
|
||||
input_scale,
|
||||
num_token_padding=self.output_padding,
|
||||
use_per_token_if_dynamic=use_per_token_if_dynamic)
|
||||
if input.dtype != current_platform.fp8_dtype():
|
||||
# Maybe apply padding to output, see comment in __init__
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
input_2d,
|
||||
input_scale,
|
||||
num_token_padding=self.output_padding,
|
||||
use_per_token_if_dynamic=use_per_token_if_dynamic)
|
||||
else:
|
||||
qinput, x_scale = input_2d, input_scale
|
||||
|
||||
per_tensor_weights = (weight_scale.numel() == 1)
|
||||
per_tensor_activations = (x_scale.numel() == 1)
|
||||
@ -216,7 +225,7 @@ class Fp8LinearOp:
|
||||
# Fused GEMM_DQ
|
||||
output = torch._scaled_mm(qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias)
|
||||
@ -240,7 +249,7 @@ class Fp8LinearOp:
|
||||
# Fused GEMM_DQ Rowwise GEMM
|
||||
output = torch._scaled_mm(qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale.t(),
|
||||
bias=bias)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user