mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 02:07:12 +08:00
[Bugfix] fix quark ptpc (#20251)
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com> Co-authored-by: Haoyang Li <307790822@qq.com>
This commit is contained in:
parent
3ee56e26be
commit
1c50e100a9
@ -312,11 +312,7 @@ class QuarkConfig(QuantizationConfig):
|
|||||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||||
QuarkW8A8Fp8.get_min_capability(), error=False)
|
QuarkW8A8Fp8.get_min_capability(), error=False)
|
||||||
if is_fp8_w8a8_supported:
|
if is_fp8_w8a8_supported:
|
||||||
weight_qscheme = cast(str, weight_config.get("qscheme"))
|
return QuarkW8A8Fp8(weight_config, input_config)
|
||||||
input_static = (input_config is not None and
|
|
||||||
not cast(bool, input_config.get("is_dynamic")))
|
|
||||||
return QuarkW8A8Fp8(qscheme=weight_qscheme,
|
|
||||||
is_static_input_scheme=input_static)
|
|
||||||
elif self._is_static_tensor_w8a8(weight_config, input_config):
|
elif self._is_static_tensor_w8a8(weight_config, input_config):
|
||||||
weight_qscheme = cast(str, weight_config.get("qscheme"))
|
weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||||
return QuarkW8A8Int8(qscheme=weight_qscheme,
|
return QuarkW8A8Int8(qscheme=weight_qscheme,
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Callable, Optional
|
from typing import Any, Callable, Optional, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
@ -19,10 +19,19 @@ __all__ = ["QuarkW8A8Fp8"]
|
|||||||
|
|
||||||
class QuarkW8A8Fp8(QuarkScheme):
|
class QuarkW8A8Fp8(QuarkScheme):
|
||||||
|
|
||||||
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
|
def __init__(self, weight_config: dict[str, Any],
|
||||||
self.qscheme = qscheme
|
input_config: Optional[dict[str, Any]]):
|
||||||
self.is_static_input_scheme = is_static_input_scheme
|
self.weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
|
self.is_static_input_scheme: bool = False
|
||||||
|
self.input_qscheme: Optional[str] = None
|
||||||
|
if input_config is not None:
|
||||||
|
self.is_static_input_scheme = not cast(
|
||||||
|
bool, input_config.get("is_dynamic"))
|
||||||
|
self.input_qscheme = cast(str, input_config.get("qscheme"))
|
||||||
|
self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
|
||||||
|
and self.input_qscheme == "per_channel")
|
||||||
|
self.fp8_linear = Fp8LinearOp(
|
||||||
|
use_per_token_if_dynamic=self.use_per_token_if_dynamic)
|
||||||
self.out_dtype = torch.get_default_dtype()
|
self.out_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -34,7 +43,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
# If per tensor, when we have a fused module (e.g. QKV) with per
|
||||||
# tensor scales (thus N scales being passed to the kernel),
|
# tensor scales (thus N scales being passed to the kernel),
|
||||||
# requantize so we can always run per tensor
|
# requantize so we can always run per tensor
|
||||||
if self.qscheme == "per_tensor":
|
if self.weight_qscheme == "per_tensor":
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
input_scale = getattr(layer, 'input_scale', None)
|
input_scale = getattr(layer, 'input_scale', None)
|
||||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
@ -58,7 +67,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||||
|
|
||||||
# If channelwise, scales are already lined up, so just transpose.
|
# If channelwise, scales are already lined up, so just transpose.
|
||||||
elif self.qscheme == "per_channel":
|
elif self.weight_qscheme == "per_channel":
|
||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
|
|
||||||
if current_platform.is_fp8_fnuz():
|
if current_platform.is_fp8_fnuz():
|
||||||
@ -73,13 +82,15 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
else:
|
else:
|
||||||
weight_scale = layer.weight_scale.data
|
weight_scale = layer.weight_scale.data
|
||||||
|
if self.use_per_token_if_dynamic:
|
||||||
|
weight_scale = weight_scale.view(-1, 1)
|
||||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||||
# required by torch.compile to be torch.nn.Parameter
|
# required by torch.compile to be torch.nn.Parameter
|
||||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown quantization scheme {self.qscheme}")
|
raise ValueError(
|
||||||
|
f"Unknown quantization scheme {self.weight_qscheme}")
|
||||||
|
|
||||||
# INPUT SCALE
|
# INPUT SCALE
|
||||||
if self.is_static_input_scheme:
|
if self.is_static_input_scheme:
|
||||||
@ -109,14 +120,14 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
# WEIGHT SCALE
|
# WEIGHT SCALE
|
||||||
# TODO: update create_xxx_parameter functions to return
|
# TODO: update create_xxx_parameter functions to return
|
||||||
# the newly added parameters
|
# the newly added parameters
|
||||||
if self.qscheme == "per_channel":
|
if self.weight_qscheme == "per_channel":
|
||||||
weight_scale = ChannelQuantScaleParameter(
|
weight_scale = ChannelQuantScaleParameter(
|
||||||
data=torch.empty((sum(output_partition_sizes)),
|
data=torch.empty((sum(output_partition_sizes)),
|
||||||
dtype=torch.float32),
|
dtype=torch.float32),
|
||||||
output_dim=0,
|
output_dim=0,
|
||||||
weight_loader=weight_loader)
|
weight_loader=weight_loader)
|
||||||
else:
|
else:
|
||||||
assert self.qscheme == "per_tensor"
|
assert self.weight_qscheme == "per_tensor"
|
||||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||||
len(output_partition_sizes), dtype=torch.float32),
|
len(output_partition_sizes), dtype=torch.float32),
|
||||||
weight_loader=weight_loader)
|
weight_loader=weight_loader)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user