From 1c50e100a9c5dc439aceb9c4437b262d564baa53 Mon Sep 17 00:00:00 2001 From: li haoyang Date: Mon, 30 Jun 2025 21:24:50 +0800 Subject: [PATCH] [Bugfix] fix quark ptpc (#20251) Signed-off-by: Haoyang Li Co-authored-by: Haoyang Li <307790822@qq.com> --- .../layers/quantization/quark/quark.py | 6 +--- .../quark/schemes/quark_w8a8_fp8.py | 33 ++++++++++++------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 6ae5f5c9ad46b..05dff4bae3957 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -312,11 +312,7 @@ class QuarkConfig(QuantizationConfig): is_fp8_w8a8_supported = self._check_scheme_supported( QuarkW8A8Fp8.get_min_capability(), error=False) if is_fp8_w8a8_supported: - weight_qscheme = cast(str, weight_config.get("qscheme")) - input_static = (input_config is not None and - not cast(bool, input_config.get("is_dynamic"))) - return QuarkW8A8Fp8(qscheme=weight_qscheme, - is_static_input_scheme=input_static) + return QuarkW8A8Fp8(weight_config, input_config) elif self._is_static_tensor_w8a8(weight_config, input_config): weight_qscheme = cast(str, weight_config.get("qscheme")) return QuarkW8A8Int8(qscheme=weight_qscheme, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 47e0a492b23b9..c7bc98184d0eb 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from typing import Any, Callable, Optional, cast import torch from torch.nn import Parameter @@ -19,10 +19,19 @@ __all__ = ["QuarkW8A8Fp8"] class QuarkW8A8Fp8(QuarkScheme): - def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): - self.qscheme = qscheme - self.is_static_input_scheme = is_static_input_scheme - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) + def __init__(self, weight_config: dict[str, Any], + input_config: Optional[dict[str, Any]]): + self.weight_qscheme = cast(str, weight_config.get("qscheme")) + self.is_static_input_scheme: bool = False + self.input_qscheme: Optional[str] = None + if input_config is not None: + self.is_static_input_scheme = not cast( + bool, input_config.get("is_dynamic")) + self.input_qscheme = cast(str, input_config.get("qscheme")) + self.use_per_token_if_dynamic = (not self.is_static_input_scheme \ + and self.input_qscheme == "per_channel") + self.fp8_linear = Fp8LinearOp( + use_per_token_if_dynamic=self.use_per_token_if_dynamic) self.out_dtype = torch.get_default_dtype() @classmethod @@ -34,7 +43,7 @@ class QuarkW8A8Fp8(QuarkScheme): # If per tensor, when we have a fused module (e.g. QKV) with per # tensor scales (thus N scales being passed to the kernel), # requantize so we can always run per tensor - if self.qscheme == "per_tensor": + if self.weight_qscheme == "per_tensor": if current_platform.is_rocm(): input_scale = getattr(layer, 'input_scale', None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( @@ -58,7 +67,7 @@ class QuarkW8A8Fp8(QuarkScheme): layer.weight_scale = Parameter(max_w_scale, requires_grad=False) # If channelwise, scales are already lined up, so just transpose. - elif self.qscheme == "per_channel": + elif self.weight_qscheme == "per_channel": weight = layer.weight if current_platform.is_fp8_fnuz(): @@ -73,13 +82,15 @@ class QuarkW8A8Fp8(QuarkScheme): requires_grad=False) else: weight_scale = layer.weight_scale.data - + if self.use_per_token_if_dynamic: + weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter layer.weight_scale = Parameter(weight_scale, requires_grad=False) else: - raise ValueError(f"Unknown quantization scheme {self.qscheme}") + raise ValueError( + f"Unknown quantization scheme {self.weight_qscheme}") # INPUT SCALE if self.is_static_input_scheme: @@ -109,14 +120,14 @@ class QuarkW8A8Fp8(QuarkScheme): # WEIGHT SCALE # TODO: update create_xxx_parameter functions to return # the newly added parameters - if self.qscheme == "per_channel": + if self.weight_qscheme == "per_channel": weight_scale = ChannelQuantScaleParameter( data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32), output_dim=0, weight_loader=weight_loader) else: - assert self.qscheme == "per_tensor" + assert self.weight_qscheme == "per_tensor" weight_scale = PerTensorScaleParameter(data=torch.empty( len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader)