[Bugfix] Fix Quark PTPC (per-token, per-channel) FP8 quantization (#20251)

Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: Haoyang Li <307790822@qq.com>
This commit is contained in:
li haoyang 2025-06-30 21:24:50 +08:00 committed by GitHub
parent 3ee56e26be
commit 1c50e100a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 16 deletions

View File

@ -312,11 +312,7 @@ class QuarkConfig(QuantizationConfig):
is_fp8_w8a8_supported = self._check_scheme_supported( is_fp8_w8a8_supported = self._check_scheme_supported(
QuarkW8A8Fp8.get_min_capability(), error=False) QuarkW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported: if is_fp8_w8a8_supported:
weight_qscheme = cast(str, weight_config.get("qscheme")) return QuarkW8A8Fp8(weight_config, input_config)
input_static = (input_config is not None and
not cast(bool, input_config.get("is_dynamic")))
return QuarkW8A8Fp8(qscheme=weight_qscheme,
is_static_input_scheme=input_static)
elif self._is_static_tensor_w8a8(weight_config, input_config): elif self._is_static_tensor_w8a8(weight_config, input_config):
weight_qscheme = cast(str, weight_config.get("qscheme")) weight_qscheme = cast(str, weight_config.get("qscheme"))
return QuarkW8A8Int8(qscheme=weight_qscheme, return QuarkW8A8Int8(qscheme=weight_qscheme,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional from typing import Any, Callable, Optional, cast
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
@ -19,10 +19,19 @@ __all__ = ["QuarkW8A8Fp8"]
class QuarkW8A8Fp8(QuarkScheme): class QuarkW8A8Fp8(QuarkScheme):
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): def __init__(self, weight_config: dict[str, Any],
self.qscheme = qscheme input_config: Optional[dict[str, Any]]):
self.is_static_input_scheme = is_static_input_scheme self.weight_qscheme = cast(str, weight_config.get("qscheme"))
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) self.is_static_input_scheme: bool = False
self.input_qscheme: Optional[str] = None
if input_config is not None:
self.is_static_input_scheme = not cast(
bool, input_config.get("is_dynamic"))
self.input_qscheme = cast(str, input_config.get("qscheme"))
self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
and self.input_qscheme == "per_channel")
self.fp8_linear = Fp8LinearOp(
use_per_token_if_dynamic=self.use_per_token_if_dynamic)
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
@classmethod @classmethod
@ -34,7 +43,7 @@ class QuarkW8A8Fp8(QuarkScheme):
# If per tensor, when we have a fused module (e.g. QKV) with per # If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel), # tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor # requantize so we can always run per tensor
if self.qscheme == "per_tensor": if self.weight_qscheme == "per_tensor":
if current_platform.is_rocm(): if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None) input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
@ -58,7 +67,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose. # If channelwise, scales are already lined up, so just transpose.
elif self.qscheme == "per_channel": elif self.weight_qscheme == "per_channel":
weight = layer.weight weight = layer.weight
if current_platform.is_fp8_fnuz(): if current_platform.is_fp8_fnuz():
@ -73,13 +82,15 @@ class QuarkW8A8Fp8(QuarkScheme):
requires_grad=False) requires_grad=False)
else: else:
weight_scale = layer.weight_scale.data weight_scale = layer.weight_scale.data
if self.use_per_token_if_dynamic:
weight_scale = weight_scale.view(-1, 1)
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter # required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else: else:
raise ValueError(f"Unknown quantization scheme {self.qscheme}") raise ValueError(
f"Unknown quantization scheme {self.weight_qscheme}")
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme: if self.is_static_input_scheme:
@ -109,14 +120,14 @@ class QuarkW8A8Fp8(QuarkScheme):
# WEIGHT SCALE # WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return # TODO: update create_xxx_parameter functions to return
# the newly added parameters # the newly added parameters
if self.qscheme == "per_channel": if self.weight_qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter( weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes)), data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32), dtype=torch.float32),
output_dim=0, output_dim=0,
weight_loader=weight_loader) weight_loader=weight_loader)
else: else:
assert self.qscheme == "per_tensor" assert self.weight_qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty( weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32), len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader) weight_loader=weight_loader)