From e54e5720854debbb87e1165ca9fd6355c4f7c938 Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Thu, 30 Oct 2025 08:04:24 +0000
Subject: [PATCH] fix int8 path

Split ScaledMMLinearLayerConfig into Int8ScaledMMLinearLayerConfig and
FP8ScaledMMLinearLayerConfig so each scaled_mm kernel declares the config
it can implement, route CompressedTensorsW8A8Int8 through the int8 kernel
list, and replace the kernels' per-parameter name arguments with a single
layer_mapping_function callable that resolves a layer's quantization
parameters and their names.

Signed-off-by: vllmellm
---
 .../schemes/compressed_tensors_w8a8_fp8.py    |  8 ++---
 .../schemes/compressed_tensors_w8a8_int8.py   | 35 ++++++++++------
 .../kernels/scaled_mm/ScaledMMLinearKernel.py | 44 +++++++++----------
 .../kernels/scaled_mm/__init__.py             |  1 +
 .../quantization/kernels/scaled_mm/aiter.py   |  4 +-
 .../quantization/kernels/scaled_mm/cpu.py     |  4 +-
 .../quantization/kernels/scaled_mm/cutlass.py | 19 ++++++---
 .../kernels/scaled_mm/flash_infer.py          |  6 +--
 .../quantization/kernels/scaled_mm/rocm.py    |  6 +--
 .../quantization/kernels/scaled_mm/torch.py   | 10 ++---
 .../quantization/kernels/scaled_mm/triton.py  |  4 +-
 .../quantization/kernels/scaled_mm/xla.py     |  4 +-
 .../quark/schemes/quark_w8a8_int8.py          |  6 ++--
 13 files changed, 80 insertions(+), 71 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index c1108e96d2135..bd9a6bd0ef043 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
     choose_scaled_mm_linear_kernel,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
-    ScaledMMLinearLayerConfig,
+    FP8ScaledMMLinearLayerConfig,
     ScaledMMLinearQuantStrategy,
 )
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -91,10 +91,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         elif self.strategy == QuantizationStrategy.CHANNEL:
             weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
 
-        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
-            is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
-            is_static_input_scheme=self.is_static_input_scheme,
-            input_symmetric=True,
+        scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
+            is_static_input_scheme=self.is_static_input_scheme,
             weight_quant_strategy=weight_quant_strategy,
             activation_group_shape=self.act_q_group_shape,
             out_dtype=self.out_dtype,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
index 6fd0a6a1c822c..049f96f1faa35 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -11,8 +11,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
-    ScaledMMLinearLayerConfig,
+    _POSSIBLE_INT8_KERNELS,
     choose_scaled_mm_linear_kernel,
 )
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
+    Int8ScaledMMLinearLayerConfig,
+)
 from vllm.model_executor.parameter import (
     BasevLLMParameter,
@@ -50,13 +53,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
     ):
         layer.logical_widths = output_partition_sizes
 
-        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
+        scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
             is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
             is_static_input_scheme=self.is_static_input_scheme,
             input_symmetric=self.input_symmetric,
         )
 
-        kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
+        kernel_type = choose_scaled_mm_linear_kernel(
+            scaled_mm_linear_kernel_config,
+            _POSSIBLE_INT8_KERNELS,
+        )
 
         if kernel_type.__name__ not in self._kernel_backends_being_used:
             logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
@@ -90,12 +96,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         layer.register_parameter("weight_scale", weight_scale)
 
         # INPUT SCALE
+        input_zero_point = None
+        input_scale = None
         if self.is_static_input_scheme:
             input_scale = BasevLLMParameter(
                 data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
             )
-            layer.register_parameter("input_scale", input_scale)
-
             if not self.input_symmetric:
                 # Note: compressed-tensors stores the zp using the same dtype
                 # as the weights
@@ -103,15 +109,22 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
                 # AZP loaded as int8 but used as int32
                 input_zero_point = BasevLLMParameter(
                     data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
                 )
-                layer.register_parameter("input_zero_point", input_zero_point)
+        layer.register_parameter("input_zero_point", input_zero_point)
+        layer.register_parameter("input_scale", input_scale)
+        if not hasattr(layer, "azp_adj"):
+            layer.register_parameter("azp_adj", None)
+
+        param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
+
+        # Resolve the quant params from the layer at call time: (values, names).
+        layer_mapping_function = lambda layer: (
+            tuple(getattr(layer, param_name) for param_name in param_name_list),
+            param_name_list,
+        )
 
         self.kernel = kernel_type(
             c=scaled_mm_linear_kernel_config,
-            w_q_param_name="weight",
-            w_s_param_name="weight_scale",
-            i_s_param_name="input_scale",
-            i_zp_param_name="input_zero_point",
-            azp_adj_param_name="azp_adj",
+            layer_mapping_function=layer_mapping_function,
         )
 
         # Checkpoints are serialized in compressed-tensors format, which is
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
index 0445223526c9e..e12aa2c5c4d2c 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
@@ -5,7 +5,8 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
 from enum import Enum
+from typing import Generic, TypeVar
 
 import torch
 
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
@@ -25,16 +26,28 @@ class ScaledMMLinearQuantStrategy(Enum):
 
 @dataclass
 class ScaledMMLinearLayerConfig:
-    # TODO: remove is channelwise
+    pass
+
+
+@dataclass
+class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
     is_channelwise: bool
     is_static_input_scheme: bool
     input_symmetric: bool
-    out_dtype: torch.dtype | None
+
+
+@dataclass
+class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
+    is_static_input_scheme: bool
     weight_quant_strategy: ScaledMMLinearQuantStrategy
-    activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR
+    activation_group_shape: GroupShape
+    out_dtype: torch.dtype
 
 
-class ScaledMMLinearKernel(ABC):
+ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig)
+
+
+class ScaledMMLinearKernel(Generic[ConfigT], ABC):
     @classmethod
     @abstractmethod
     def get_min_capability(cls) -> int:
@@ -42,11 +55,11 @@ class ScaledMMLinearKernel(ABC):
 
     @classmethod
     @abstractmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
         raise NotImplementedError
 
     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: ConfigT, layer_mapping_function: Callable
     ) -> None:
         assert self.can_implement(c)
         self.config = c
@@ -63,21 +76,4 @@ class ScaledMMLinearKernel(ABC):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        raise NotImplementedError
-
-    # def _get_weight_params(
-    #     self, layer: torch.nn.Module
-    # ) -> tuple[
-    #     torch.Tensor,  # weight
-    #     torch.Tensor,  # weight_scale
-    #     torch.Tensor | None,  # input_scale,
-    #     torch.Tensor | None,  # input_zp
-    #     torch.Tensor | None,  # azp_adj
-    # ]:
-    #     return (
-    #         getattr(layer, self.w_q_name),
-    #         getattr(layer, self.w_s_name),
-    #         getattr(layer, self.i_s_name),
-    #         getattr(layer, self.i_zp_name),
-    #         getattr(layer, self.azp_adj_name),
-    #     )
+        raise NotImplementedError
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index 2ad21162995fe..85aaf51ae844c 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
     ScaledMMLinearKernel,
     ScaledMMLinearLayerConfig,
 )
+
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
     ChannelWiseTorchScaledMMLinearKernel,
     PerTensorTorchScaledMMLinearKernel,
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
index 7dc1a57f1ecd3..a39e96bca614b 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@@ -10,7 +10,7 @@ from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 
 from .cutlass import process_weights_after_loading
-from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel
 
 
 def rocm_aiter_gemm_w8a8_impl(
@@ -58,7 +58,7 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
         return 90
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_rocm():
             return (
                 False,
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
index feb1e0bee1aaf..9c8ece8559b48 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 
-from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel
 
 
 class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
@@ -23,7 +23,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
         return 75
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_cpu():
             return False, "CPUScaledMM requires running on CPU."
 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
index 6e88d65acd453..b81d670686930 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
@@ -15,7 +15,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.platforms import current_platform
 
-from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import (
+    FP8ScaledMMLinearLayerConfig,
+    Int8ScaledMMLinearLayerConfig,
+    ScaledMMLinearKernel,
+)
 
 
 def cutlass_w8a8_scaled_mm(
@@ -36,7 +40,7 @@ def cutlass_w8a8_scaled_mm(
 
 
 def process_weights_after_loading(
-    config: ScaledMMLinearLayerConfig,
+    config: Int8ScaledMMLinearLayerConfig,
     layer: torch.nn.Module,
     w_q_name: str,
     w_s_name: str,
@@ -98,9 +102,6 @@ def process_weights_after_loading(
                 layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
             )
 
-    else:
-        setattr(layer, i_s_name, None)
-        setattr(layer, i_zp_name, None)
 
     # azp_adj is the AZP adjustment term, used to account for weights.
     # It does not depend on scales or azp, so it is the same for
@@ -119,8 +120,6 @@ def process_weights_after_loading(
                 azp_adj_name,
                 torch.nn.Parameter(azp_adj, requires_grad=False),
             )
-        else:
-            setattr(layer, azp_adj_name, None)
 
 
 class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@@ -129,7 +128,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
         return 75
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_cuda():
             return False, "CutlassScaledMM requires running on CUDA."
 
@@ -177,7 +176,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
 
 class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
     ) -> None:
         self.quant_fp8 = QuantFP8(
             static=c.is_static_input_scheme,
@@ -192,7 +191,7 @@ class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
         return 89
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_cuda():
             return (
                 False,
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py
index 9940ef49bb3e0..9fcbb2ff8ec8a 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py
@@ -11,7 +11,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
 
 from .ScaledMMLinearKernel import (
+    FP8ScaledMMLinearLayerConfig,
     ScaledMMLinearKernel,
-    ScaledMMLinearLayerConfig,
     ScaledMMLinearQuantStrategy,
 )
 
@@ -32,7 +32,7 @@ def flashinfer_w8a8_scaled_mm(
 
 class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
     ) -> None:
         self.quant_fp8 = QuantFP8(
             static=c.is_static_input_scheme,
@@ -46,7 +46,7 @@ class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
         return 100
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
         per_tensor_weight_scales = (
             c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
index 74454743fb0db..17b932f2336db 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
@@ -13,7 +13,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
 
 from .ScaledMMLinearKernel import (
+    FP8ScaledMMLinearLayerConfig,
     ScaledMMLinearKernel,
-    ScaledMMLinearLayerConfig,
     ScaledMMLinearQuantStrategy,
 )
 
@@ -90,7 +90,7 @@ if current_platform.is_rocm():
 
 class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
     ) -> None:
         self.quant_fp8 = QuantFP8(
             static=c.is_static_input_scheme,
@@ -104,7 +104,7 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
         return 90
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         # TODO: check if this causes an issue on non-ROCM platforms
         from vllm.platforms.rocm import on_mi3xx
 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py
index 0b2c0a8b49fd1..7d82496dca023 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
 
 from .ScaledMMLinearKernel import (
+    FP8ScaledMMLinearLayerConfig,
     ScaledMMLinearKernel,
-    ScaledMMLinearLayerConfig,
     ScaledMMLinearQuantStrategy,
 )
 
@@ -136,7 +136,7 @@ def torch_channelwise_w8a8_scaled_mm(
 
 class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
     ) -> None:
         vllm_config = get_current_vllm_config().compilation_config
         pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
@@ -161,7 +161,7 @@ class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
 
 class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         assert c.activation_group_shape is not None
         per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
         per_tensor_weight_scales = (
@@ -218,7 +218,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         return 94
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         assert c.activation_group_shape is not None
 
         per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
@@ -290,7 +290,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         return 94
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         assert c.activation_group_shape is not None
 
         per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
index 3f4ec7f2a738b..0c8ee18457dda 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
@@ -7,7 +7,7 @@ import torch
 from vllm.platforms import current_platform
 
 from .cutlass import CutlassScaledMMLinearKernel
-from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
 
 
 class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@@ -16,7 +16,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
         return 75
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if current_platform.is_cpu():
             return (
                 False,
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
index ddac9f13cf4f3..6150270c8773f 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.platforms import current_platform
 
-from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel
 
 
 class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@@ -24,7 +24,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
         )
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_tpu():
             return False, "ScaledMMXLA requires running on TPU."
 
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
index 42d2ed2e85ed9..3d51ea2cd9580 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
@@ -7,9 +7,12 @@ import torch
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
-    ScaledMMLinearLayerConfig,
     choose_scaled_mm_linear_kernel,
 )
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
+    Int8ScaledMMLinearLayerConfig,
+)
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
 from vllm.model_executor.parameter import (
     BasevLLMParameter,
@@ -50,7 +53,7 @@ class QuarkW8A8Int8(QuarkScheme):
     ):
         layer.logical_widths = output_partition_sizes
 
-        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
+        scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
             is_channelwise=(self.qscheme == "per_channel"),
             is_static_input_scheme=(self.is_static_input_scheme is True),
             input_symmetric=(self.input_symmetric is True),