[Bugfix] fix quark ptpc (#20251)

Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: Haoyang Li <307790822@qq.com>
Author: li haoyang
Date: 2025-06-30 21:24:50 +08:00 (committed by GitHub)
Parent: 3ee56e26be
Commit: 1c50e100a9
2 changed files with 23 additions and 16 deletions

@@ -312,11 +312,7 @@ class QuarkConfig(QuantizationConfig):
             is_fp8_w8a8_supported = self._check_scheme_supported(
                 QuarkW8A8Fp8.get_min_capability(), error=False)
             if is_fp8_w8a8_supported:
-                weight_qscheme = cast(str, weight_config.get("qscheme"))
-                input_static = (input_config is not None and
-                                not cast(bool, input_config.get("is_dynamic")))
-                return QuarkW8A8Fp8(qscheme=weight_qscheme,
-                                    is_static_input_scheme=input_static)
+                return QuarkW8A8Fp8(weight_config, input_config)
         elif self._is_static_tensor_w8a8(weight_config, input_config):
             weight_qscheme = cast(str, weight_config.get("qscheme"))
             return QuarkW8A8Int8(qscheme=weight_qscheme,
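
For readers skimming the diff: the call site no longer pre-extracts fields from the Quark config; the scheme now receives the raw dicts and parses them itself. A minimal sketch of what that looks like; the key names ("qscheme", "is_dynamic") come from the hunk above, the concrete values are illustrative assumptions:

# Illustrative Quark config fragments (keys from the diff, values assumed).
weight_config = {"qscheme": "per_channel", "is_dynamic": False}
input_config = {"qscheme": "per_channel", "is_dynamic": True}

# Before: the caller pre-extracted the fields it thought mattered, so the
# input qscheme never reached the scheme:
#     QuarkW8A8Fp8(qscheme="per_channel", is_static_input_scheme=False)
# After: the scheme receives both dicts and decides for itself:
#     QuarkW8A8Fp8(weight_config, input_config)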

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, cast
 
 import torch
 from torch.nn import Parameter
@@ -19,10 +19,19 @@ __all__ = ["QuarkW8A8Fp8"]
 
 class QuarkW8A8Fp8(QuarkScheme):
 
-    def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
-        self.qscheme = qscheme
-        self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
+    def __init__(self, weight_config: dict[str, Any],
+                 input_config: Optional[dict[str, Any]]):
+        self.weight_qscheme = cast(str, weight_config.get("qscheme"))
+        self.is_static_input_scheme: bool = False
+        self.input_qscheme: Optional[str] = None
+        if input_config is not None:
+            self.is_static_input_scheme = not cast(
+                bool, input_config.get("is_dynamic"))
+            self.input_qscheme = cast(str, input_config.get("qscheme"))
+        self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
+            and self.input_qscheme == "per_channel")
+        self.fp8_linear = Fp8LinearOp(
+            use_per_token_if_dynamic=self.use_per_token_if_dynamic)
         self.out_dtype = torch.get_default_dtype()
 
     @classmethod
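
The new constructor boils down to one decision: enable the per-token dynamic path of Fp8LinearOp only when activations are quantized dynamically with a per-channel qscheme. A standalone restatement of that condition, for illustration only (the helper name is made up; the logic mirrors the __init__ above):

from typing import Any, Optional

def wants_per_token_dynamic(input_config: Optional[dict[str, Any]]) -> bool:
    """Hypothetical helper restating the decision made in __init__ above."""
    if input_config is None:
        # No activation quantization config: stay on the static/per-tensor path.
        return False
    is_static_input = not bool(input_config.get("is_dynamic"))
    # Per-token dynamic quantization only for dynamic, per-channel activations.
    return (not is_static_input) and input_config.get("qscheme") == "per_channel"

# Dynamic per-channel activations -> per-token path (the "ptpc" case).
assert wants_per_token_dynamic({"qscheme": "per_channel", "is_dynamic": True})
# Static input scale -> per-tensor path, as before.
assert not wants_per_token_dynamic({"qscheme": "per_tensor", "is_dynamic": False})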
@@ -34,7 +43,7 @@ class QuarkW8A8Fp8(QuarkScheme):
         # If per tensor, when we have a fused module (e.g. QKV) with per
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
-        if self.qscheme == "per_tensor":
+        if self.weight_qscheme == "per_tensor":
             if current_platform.is_rocm():
                 input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
@@ -58,7 +67,7 @@ class QuarkW8A8Fp8(QuarkScheme):
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
         # If channelwise, scales are already lined up, so just transpose.
-        elif self.qscheme == "per_channel":
+        elif self.weight_qscheme == "per_channel":
             weight = layer.weight
 
             if current_platform.is_fp8_fnuz():
@@ -73,13 +82,15 @@ class QuarkW8A8Fp8(QuarkScheme):
                                                   requires_grad=False)
             else:
                 weight_scale = layer.weight_scale.data
+                if self.use_per_token_if_dynamic:
+                    weight_scale = weight_scale.view(-1, 1)
             layer.weight = Parameter(weight.t(), requires_grad=False)
             # required by torch.compile to be torch.nn.Parameter
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
         else:
-            raise ValueError(f"Unknown quantization scheme {self.qscheme}")
+            raise ValueError(
+                f"Unknown quantization scheme {self.weight_qscheme}")
         # INPUT SCALE
         if self.is_static_input_scheme:
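
The only functional change in this hunk is the reshape of the per-channel weight scale when the dynamic per-token path is active. A small shape sketch of what view(-1, 1) does; the broadcasting rationale in the comment is an inference, not stated in the diff:

import torch

n_out = 4  # illustrative number of output channels

# Per-channel weight scales are stored as a flat vector of length n_out...
weight_scale = torch.rand(n_out, dtype=torch.float32)

# ...and view(-1, 1) turns them into a column vector, so each scale lines up
# with one output channel and can broadcast against per-token activation
# scales of shape [num_tokens, 1] (presumed motivation).
weight_scale_col = weight_scale.view(-1, 1)

assert weight_scale.shape == (n_out,)
assert weight_scale_col.shape == (n_out, 1)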
@@ -109,14 +120,14 @@ class QuarkW8A8Fp8(QuarkScheme):
         # WEIGHT SCALE
         # TODO: update create_xxx_parameter functions to return
         # the newly added parameters
-        if self.qscheme == "per_channel":
+        if self.weight_qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
                 data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
         else:
-            assert self.qscheme == "per_tensor"
+            assert self.weight_qscheme == "per_tensor"
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
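
In this last hunk only the attribute name changes, but the two branches still allocate differently shaped scale parameters. A shape-only sketch under an assumed fused-QKV partitioning (the partition sizes are illustrative):

import torch

# Assumed fused-QKV output partition sizes, purely for illustration.
output_partition_sizes = [1024, 256, 256]

# per_channel: one fp32 scale per output channel across all partitions.
per_channel_scale = torch.empty(sum(output_partition_sizes), dtype=torch.float32)

# per_tensor: one fp32 scale per partition; these are requantized down to a
# single scale in process_weights_after_loading (see the comment in the
# earlier hunk).
per_tensor_scale = torch.empty(len(output_partition_sizes), dtype=torch.float32)

assert per_channel_scale.shape == (1536,)
assert per_tensor_scale.shape == (3,)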