update quark fp8 path; format

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-30 14:24:19 +00:00
parent c05027f67a
commit c089ea5753
13 changed files with 172 additions and 122 deletions

View File

@ -6,8 +6,8 @@ from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter
from vllm.logger import init_logger
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
@ -15,10 +15,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_FP8_KERNELS,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
QUANT_STRATEGY_MAP,
FP8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
@ -53,6 +52,7 @@ strategy_to_parameter_type = {
logger = init_logger(__name__)
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
@ -92,17 +92,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
activation_group_shape=self.act_q_group_shape,
out_dtype=self.out_dtype,
)
kernel = choose_scaled_mm_linear_kernel(
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_FP8_KERNELS,
)
self.fp8_linear = kernel(
scaled_mm_linear_kernel_config, layer_param_names = layer_param_names
)
if kernel.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__)
self._kernel_backends_being_used.add(kernel.__name__)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info(
"Using %s for CompressedTensorsW8A8FP8", kernel_type.__name__
)
self._kernel_backends_being_used.add(kernel_type.__name__)
self.kernel = kernel_type(
scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
)
@classmethod
def get_min_capability(cls) -> int:
@ -217,4 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)
return self.kernel.apply_weights(layer, x, bias)

View File

@ -11,8 +11,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_INT8_KERNELS,
choose_scaled_mm_linear_kernel,
_POSSIBLE_INT8_KERNELS
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
Int8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
@ -20,7 +23,6 @@ from vllm.model_executor.parameter import (
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
logger = init_logger(__name__)
@ -58,8 +60,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_INT8_KERNELS
scaled_mm_linear_kernel_config, _POSSIBLE_INT8_KERNELS
)
if kernel_type.__name__ not in self._kernel_backends_being_used:
@ -94,8 +95,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
input_zero_point=None
input_scale=None
input_zero_point = None
input_scale = None
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
@ -113,11 +114,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
layer_param_names = [
"weight",
"weight_scale",
"input_scale",
"input_zero_point",
"azp_adj",
]
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config,
layer_param_names = layer_param_names
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
)
# Checkpoints are serialized in compressed-tensors format, which is

View File

@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Generic, Sequence, TypeVar
from typing import Generic, TypeVar
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
class ScaledMMLinearQuantStrategy(Enum):
@ -18,21 +19,24 @@ class ScaledMMLinearQuantStrategy(Enum):
CHANNEL = "channel"
BLOCK = "block"
# Translates compressed-tensors weight quantization strategies into the
# kernel-side ScaledMMLinearQuantStrategy values.
# Every source strategy needs its own key: a repeated key in a dict literal
# silently overwrites the earlier entry, which would remap CHANNEL to BLOCK
# and leave QuantizationStrategy.BLOCK without any entry.
QUANT_STRATEGY_MAP = {
    QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
    QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
    QuantizationStrategy.BLOCK: ScaledMMLinearQuantStrategy.BLOCK,
}
# Base configuration for scaled-mm linear kernels; the Int8 and FP8 config
# dataclasses in this module derive from it.
@dataclass
class ScaledMMLinearLayerConfig:
# Forwarded as the `static` argument of QuantFP8 by the FP8 kernel base class.
is_static_input_scheme: bool
# Configuration for the int8 scaled-mm kernels; used as the config type
# parameter of Int8ScaledMMLinearKernel.
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
# NOTE(review): presumably selects per-channel vs. per-tensor weight scales;
# confirm against the kernel implementations.
is_channelwise: bool
# NOTE(review): presumably True when input quantization is symmetric (no
# zero point); confirm against callers.
input_symmetric: bool
@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_strategy: ScaledMMLinearQuantStrategy
@ -40,22 +44,22 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
out_dtype: torch.dtype
Int8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
]
FP8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
]
Int8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]
ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT)
ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig)
ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT)
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
@classmethod
@ -68,9 +72,7 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
raise NotImplementedError
def __init__(
self, c: ConfigT, layer_param_names: Sequence[str]
) -> None:
def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None:
assert self.can_implement(c)
self.config = c
self.layer_param_names = layer_param_names
@ -87,16 +89,18 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
# return a covariant type in the subclass
@abstractmethod
def _get_layer_params(self, layer) -> ParamsT:
raise NotImplementedError
class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC):
class FP8ScaledMMLinearKernel(
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC
):
def __init__(
self, c: ConfigT, layer_param_names: Sequence[str]
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None:
self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme,
@ -104,7 +108,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
num_token_padding=self.get_ouput_padding(),
)
super().__init__(c, layer_param_names)
@abstractmethod
def get_ouput_padding(self) -> int | None:
raise NotImplementedError
@ -113,7 +117,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
@ -126,7 +130,9 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
)
class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC):
class Int8ScaledMMLinearKernel(
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC
):
def _get_layer_params(self, layer) -> Int8ParamsT:
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
return (

View File

@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
ChannelWiseTorchScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel,

View File

@ -14,7 +14,10 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@ -49,9 +52,7 @@ class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Transpose to [K, N] for convenience
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
self.layer_param_names
)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
weight = getattr(layer, w_q_name)
replace_parameter(
layer,

View File

@ -2,22 +2,24 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
from .utils import apply_weights_fp8
def cutlass_w8a8_scaled_mm_fp8(
*,
A: torch.Tensor,
@ -34,6 +36,7 @@ def cutlass_w8a8_scaled_mm_fp8(
)
return output.view(*output_shape)
class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
@ -47,9 +50,7 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
self.layer_param_names
)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
config = self.config
# WEIGHT
# Cutlass kernels need transposed weight.
@ -105,7 +106,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
@ -124,7 +124,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
torch.nn.Parameter(azp_adj, requires_grad=False),
)
def apply_weights(
self,
layer: torch.nn.Module,
@ -161,7 +160,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
return None
@ -191,5 +189,5 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
w_s,
x_s,
bias,
self.config.out_dtype
)
self.config.out_dtype,
)

View File

@ -1,17 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from .utils import apply_weights_fp8
@ -32,7 +29,6 @@ def flashinfer_w8a8_scaled_mm(
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
return None
@ -41,7 +37,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
return 100
@classmethod
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
@ -90,5 +86,5 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
w_s,
x_s,
bias,
self.config.out_dtype
)
self.config.out_dtype,
)

View File

@ -1,13 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@ -18,6 +15,7 @@ from .ScaledMMLinearKernel import (
)
from .utils import apply_weights_fp8
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A: torch.Tensor,
B: torch.Tensor,
@ -40,7 +38,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
current_platform.get_cu_count(),
bias,
)
# Fallabck
# Fallback
else:
output = torch._scaled_mm(
A,
@ -143,5 +141,5 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
w_s,
x_s,
bias,
self.config.out_dtype
self.config.out_dtype,
)

View File

@ -1,13 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from packaging import version
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
@ -15,8 +12,8 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from .utils import apply_weights_fp8
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
@ -142,6 +139,7 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
output_padding = 17 if pad_output else None
return output_padding
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
@ -173,9 +171,10 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
w_s,
x_s,
bias,
self.config.out_dtype
self.config.out_dtype,
)
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
@ -199,7 +198,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return (
False,
"RowWiseTorchScaledMMLinearKernel is only supported "
+ "in ROCm platforms.",
+ "on ROCm platforms.",
)
if not version.parse(torch.__version__) >= version.parse("2.7"):
@ -225,7 +224,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
w_s,
x_s,
bias,
self.config.out_dtype
self.config.out_dtype,
)
@ -265,5 +264,5 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
w_s,
x_s,
bias,
self.config.out_dtype
self.config.out_dtype,
)

View File

@ -1,20 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import torch
from vllm.platforms import current_platform
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]]
def apply_weights_fp8(
scaled_mm_func: FP8ScaledMMCallBack,
quant_fp8_func: FP8QuantCallback,
w:torch.Tensor,
x:torch.Tensor,
w_s:torch.Tensor,
x_s:torch.Tensor,
bias:torch.Tensor,
maybe_out_dtype: torch.dtype | None,
) -> torch.Tensor:
scaled_mm_func: FP8ScaledMMCallBack,
quant_fp8_func: FP8QuantCallback,
w: torch.Tensor,
x: torch.Tensor,
w_s: torch.Tensor,
x_s: torch.Tensor,
bias: torch.Tensor,
maybe_out_dtype: torch.dtype | None,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_s computed from x.
# If static, layer.input_scale is scalar and x_s is input_scale.

View File

@ -12,7 +12,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@ -42,9 +45,7 @@ class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# [out, in] (different than cutlass_scaled_mm)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
self.layer_param_names
)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
weight = getattr(layer, w_q_name)
replace_parameter(
layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)

View File

@ -7,10 +7,18 @@ from typing import Any, cast
import torch
from torch.nn import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_FP8_KERNELS,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale,
)
@ -23,8 +31,17 @@ from vllm.platforms import current_platform
__all__ = ["QuarkW8A8Fp8"]
logger = init_logger(__name__)
# Maps the weight qscheme string of a Quark FP8 scheme to the
# ScaledMMLinearQuantStrategy used to build the kernel config. It is indexed
# with self.weight_qscheme, so a qscheme not listed here raises KeyError.
QUANT_STRATEGY_MAP = {
"per_tensor": ScaledMMLinearQuantStrategy.TENSOR,
"per_channel": ScaledMMLinearQuantStrategy.CHANNEL,
}
class QuarkW8A8Fp8(QuarkScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(
self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
):
@ -41,10 +58,6 @@ class QuarkW8A8Fp8(QuarkScheme):
self.act_quant_group_shape = (
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_quant_group_shape,
)
self.out_dtype = torch.get_default_dtype()
@classmethod
@ -163,17 +176,32 @@ class QuarkW8A8Fp8(QuarkScheme):
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
layer_param_names = ["weight", "weight_scale", "input_scale"]
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
is_static_input_scheme=self.is_static_input_scheme,
weight_quant_strategy=weight_quant_strategy,
activation_group_shape=self.act_quant_group_shape,
out_dtype=self.out_dtype,
)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_FP8_KERNELS,
)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for QuarkW8A8FP8", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
layer_param_names = ["weight", "weight_scale", "input_scale"]
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)
return self.kernel.apply_weights(layer, x, bias)

View File

@ -7,9 +7,12 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_INT8_KERNELS,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
Int8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import (
BasevLLMParameter,
@ -56,7 +59,9 @@ class QuarkW8A8Int8(QuarkScheme):
input_symmetric=(self.input_symmetric is True),
)
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config, possible_kernels=_POSSIBLE_INT8_KERNELS
)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
@ -102,8 +107,8 @@ class QuarkW8A8Int8(QuarkScheme):
layer.register_parameter("weight_zero_point", weight_zero_point)
# INPUT SCALE
input_zero_point=None
input_scale=None
input_zero_point = None
input_scale = None
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
@ -117,12 +122,17 @@ class QuarkW8A8Int8(QuarkScheme):
layer.register_parameter("input_zero_point", input_zero_point)
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
layer_param_names = [
"weight",
"weight_scale",
"input_scale",
"input_zero_point",
"azp_adj",
]
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config,
layer_param_names = layer_param_names
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
)
# Checkpoints are serialized in quark format, which is