[Bugfix] fix quark ptpc (#20251)
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: Haoyang Li <307790822@qq.com>

parent 3ee56e26be
commit 1c50e100a9
@@ -312,11 +312,7 @@ class QuarkConfig(QuantizationConfig):
             is_fp8_w8a8_supported = self._check_scheme_supported(
                 QuarkW8A8Fp8.get_min_capability(), error=False)
             if is_fp8_w8a8_supported:
-                weight_qscheme = cast(str, weight_config.get("qscheme"))
-                input_static = (input_config is not None and
-                                not cast(bool, input_config.get("is_dynamic")))
-                return QuarkW8A8Fp8(qscheme=weight_qscheme,
-                                    is_static_input_scheme=input_static)
+                return QuarkW8A8Fp8(weight_config, input_config)
         elif self._is_static_tensor_w8a8(weight_config, input_config):
             weight_qscheme = cast(str, weight_config.get("qscheme"))
             return QuarkW8A8Int8(qscheme=weight_qscheme,
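With this hunk, QuarkW8A8Fp8 receives the raw weight and input quantization config dicts instead of a pre-extracted qscheme string and static-input flag, so the scheme can inspect the activation qscheme itself. Below is a minimal sketch of what the two dicts might carry for a PTPC (per-token activation, per-channel weight) checkpoint; only the keys "qscheme" and "is_dynamic" appear in the diff, and the concrete values are illustrative assumptions.

# Illustrative only: keys taken from the diff, values assumed for a PTPC setup.
weight_config = {"qscheme": "per_channel", "is_dynamic": False}
input_config = {"qscheme": "per_channel", "is_dynamic": True}

# New call shape: the scheme gets both dicts and derives its own flags.
# scheme = QuarkW8A8Fp8(weight_config, input_config)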
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, cast
 
 import torch
 from torch.nn import Parameter
@@ -19,10 +19,19 @@ __all__ = ["QuarkW8A8Fp8"]
 
 class QuarkW8A8Fp8(QuarkScheme):
 
-    def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
-        self.qscheme = qscheme
-        self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
+    def __init__(self, weight_config: dict[str, Any],
+                 input_config: Optional[dict[str, Any]]):
+        self.weight_qscheme = cast(str, weight_config.get("qscheme"))
+        self.is_static_input_scheme: bool = False
+        self.input_qscheme: Optional[str] = None
+        if input_config is not None:
+            self.is_static_input_scheme = not cast(
+                bool, input_config.get("is_dynamic"))
+            self.input_qscheme = cast(str, input_config.get("qscheme"))
+        self.use_per_token_if_dynamic = (not self.is_static_input_scheme
+                                         and self.input_qscheme == "per_channel")
+        self.fp8_linear = Fp8LinearOp(
+            use_per_token_if_dynamic=self.use_per_token_if_dynamic)
         self.out_dtype = torch.get_default_dtype()
 
     @classmethod
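The core of the fix is how use_per_token_if_dynamic is derived: it is enabled only when the activations are dynamically quantized with a per-channel (i.e. per-token) qscheme. A standalone restatement of that logic with plain dicts, as a sketch; ptpc_flag is a hypothetical helper written only for illustration.

from typing import Any, Optional

def ptpc_flag(weight_config: dict[str, Any],
              input_config: Optional[dict[str, Any]]) -> bool:
    """Mirror of the constructor's flag derivation (illustration only)."""
    is_static_input = False
    input_qscheme = None
    if input_config is not None:
        is_static_input = not bool(input_config.get("is_dynamic"))
        input_qscheme = input_config.get("qscheme")
    return (not is_static_input) and input_qscheme == "per_channel"

# Dynamic per-channel activations -> per-token path is enabled (the PTPC case).
assert ptpc_flag({"qscheme": "per_channel"},
                 {"qscheme": "per_channel", "is_dynamic": True}) is True
# Static per-tensor activations -> stays on the per-tensor path.
assert ptpc_flag({"qscheme": "per_tensor"},
                 {"qscheme": "per_tensor", "is_dynamic": False}) is False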
@@ -34,7 +43,7 @@ class QuarkW8A8Fp8(QuarkScheme):
         # If per tensor, when we have a fused module (e.g. QKV) with per
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
-        if self.qscheme == "per_tensor":
+        if self.weight_qscheme == "per_tensor":
             if current_platform.is_rocm():
                 input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
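The per_tensor branch above requantizes fused modules (e.g. QKV) that arrive with one scale per sub-module so the kernel can run with a single scale. A minimal sketch of that idea, assuming the shards are stacked along the output dimension; the function name is hypothetical and float tensors stand in for the fp8 weight, this is not vLLM's helper.

import torch

def requantize_to_max_scale(weight_q: torch.Tensor,
                            scales: torch.Tensor,
                            partition_sizes: list[int]):
    """Sketch: rescale each shard so one shared max scale is valid."""
    max_scale = scales.max()
    out = torch.empty_like(weight_q)
    start = 0
    for size, scale in zip(partition_sizes, scales):
        # Dequantize with the shard's own scale, requantize with the max scale.
        out[start:start + size] = weight_q[start:start + size] * (scale / max_scale)
        start += size
    return out, max_scale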
@@ -58,7 +67,7 @@ class QuarkW8A8Fp8(QuarkScheme):
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
         # If channelwise, scales are already lined up, so just transpose.
-        elif self.qscheme == "per_channel":
+        elif self.weight_qscheme == "per_channel":
             weight = layer.weight
 
             if current_platform.is_fp8_fnuz():
@@ -73,13 +82,15 @@ class QuarkW8A8Fp8(QuarkScheme):
                                                   requires_grad=False)
             else:
                 weight_scale = layer.weight_scale.data
-
+            if self.use_per_token_if_dynamic:
+                weight_scale = weight_scale.view(-1, 1)
             layer.weight = Parameter(weight.t(), requires_grad=False)
             # required by torch.compile to be torch.nn.Parameter
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
         else:
-            raise ValueError(f"Unknown quantization scheme {self.qscheme}")
+            raise ValueError(
+                f"Unknown quantization scheme {self.weight_qscheme}")
 
         # INPUT SCALE
         if self.is_static_input_scheme:
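The weight_scale.view(-1, 1) reshape gives the per-channel weight scale a column shape so it can broadcast against per-token activation scales in the dynamic path. A float-emulated sketch of the PTPC dequantization math (shapes only; the real path goes through Fp8LinearOp, and all tensors below are stand-ins):

import torch

M, K, N = 4, 8, 16
x_q = torch.randn(M, K)       # stand-in for quantized activations
w_q = torch.randn(N, K)       # stand-in for quantized weight
x_scale = torch.rand(M, 1)    # one scale per token (dynamic)
w_scale = torch.rand(N, 1)    # one scale per output channel, shape [N, 1]

# y[m, n] = sum_k x_q[m, k] * w_q[n, k] * x_scale[m] * w_scale[n]
y = (x_q @ w_q.t()) * x_scale * w_scale.t()
assert y.shape == (M, N)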
@@ -109,14 +120,14 @@ class QuarkW8A8Fp8(QuarkScheme):
         # WEIGHT SCALE
         # TODO: update create_xxx_parameter functions to return
         # the newly added parameters
-        if self.qscheme == "per_channel":
+        if self.weight_qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
                 data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
         else:
-            assert self.qscheme == "per_tensor"
+            assert self.weight_qscheme == "per_tensor"
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
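The two branches allocate differently shaped scale parameters for fused layers: per_channel needs one scale per output channel across all partitions, per_tensor one scale per fused sub-module. A tiny sketch for a hypothetical fused QKV projection (partition sizes are illustrative):

import torch

# Hypothetical fused QKV projection: three output partitions.
output_partition_sizes = [4096, 1024, 1024]

# per_channel: one scale per output channel across all partitions.
channel_scale = torch.empty(sum(output_partition_sizes), dtype=torch.float32)
assert channel_scale.shape == (6144,)

# per_tensor: one scale per fused sub-module (q, k, v).
tensor_scale = torch.empty(len(output_partition_sizes), dtype=torch.float32)
assert tensor_scale.shape == (3,)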