fix int8 path

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-30 08:04:24 +00:00
parent 974e6820ce
commit e54e572085
13 changed files with 67 additions and 72 deletions

View File

@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
choose_scaled_mm_linear_kernel, choose_scaled_mm_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy, ScaledMMLinearQuantStrategy,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@ -91,10 +91,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.CHANNEL: elif self.strategy == QuantizationStrategy.CHANNEL:
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=True,
weight_quant_strategy=weight_quant_strategy, weight_quant_strategy=weight_quant_strategy,
activation_group_shape=self.act_q_group_shape, activation_group_shape=self.act_q_group_shape,
out_dtype=self.out_dtype, out_dtype=self.out_dtype,

View File

@ -11,8 +11,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
choose_scaled_mm_linear_kernel, choose_scaled_mm_linear_kernel,
_POSSIBLE_INT8_KERNELS
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BasevLLMParameter, BasevLLMParameter,
@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@ -50,13 +51,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
): ):
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme, is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric, input_symmetric=self.input_symmetric,
) )
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_INT8_KERNELS
)
if kernel_type.__name__ not in self._kernel_backends_being_used: if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__) logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
@ -90,12 +94,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE # INPUT SCALE
input_zero_point=None
input_scale=None
if self.is_static_input_scheme: if self.is_static_input_scheme:
input_scale = BasevLLMParameter( input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
) )
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric: if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype # Note: compressed-tensors stores the zp using the same dtype
# as the weights # as the weights
@ -103,15 +107,21 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point = BasevLLMParameter( input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
) )
layer.register_parameter("input_zero_point", input_zero_point)
layer.register_parameter("input_zero_point", input_zero_point)
layer.register_parameter("input_scale", input_scale)
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
layer_mapping_function = lambda layer: (
tuple(getattr(layer, param_name) for param_name in param_name_list),
param_name_list,
)
self.kernel = kernel_type( self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config, c=scaled_mm_linear_kernel_config,
w_q_param_name="weight", layer_mapping_function = layer_mapping_function
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj",
) )
# Checkpoints are serialized in compressed-tensors format, which is # Checkpoints are serialized in compressed-tensors format, which is

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Generic, TypeVar
import torch import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
@ -25,16 +25,25 @@ class ScaledMMLinearQuantStrategy(Enum):
@dataclass @dataclass
class ScaledMMLinearLayerConfig: class ScaledMMLinearLayerConfig:
# TODO: remove is channelwise pass
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
is_channelwise: bool is_channelwise: bool
is_static_input_scheme: bool is_static_input_scheme: bool
input_symmetric: bool input_symmetric: bool
out_dtype: torch.dtype | None
@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_strategy: ScaledMMLinearQuantStrategy weight_quant_strategy: ScaledMMLinearQuantStrategy
activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR activation_group_shape: GroupShape
out_dtype: torch.dtype
class ScaledMMLinearKernel(ABC): ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[ConfigT], ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
@ -42,11 +51,11 @@ class ScaledMMLinearKernel(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
raise NotImplementedError raise NotImplementedError
def __init__( def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable self, c: ConfigT, layer_mapping_function: Callable
) -> None: ) -> None:
assert self.can_implement(c) assert self.can_implement(c)
self.config = c self.config = c
@ -64,20 +73,3 @@ class ScaledMMLinearKernel(ABC):
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
# def _get_weight_params(
# self, layer: torch.nn.Module
# ) -> tuple[
# torch.Tensor, # weight
# torch.Tensor, # weight_scale
# torch.Tensor | None, # input_scale,
# torch.Tensor | None, # input_zp
# torch.Tensor | None, # azp_adj
# ]:
# return (
# getattr(layer, self.w_q_name),
# getattr(layer, self.w_s_name),
# getattr(layer, self.i_s_name),
# getattr(layer, self.i_zp_name),
# getattr(layer, self.azp_adj_name),
# )

View File

@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
ScaledMMLinearKernel, ScaledMMLinearKernel,
ScaledMMLinearLayerConfig, ScaledMMLinearLayerConfig,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
ChannelWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel,

View File

@ -10,7 +10,7 @@ from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import process_weights_after_loading from .cutlass import process_weights_after_loading
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl( def rocm_aiter_gemm_w8a8_impl(
@ -58,7 +58,7 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
return 90 return 90
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_rocm(): if not current_platform.is_rocm():
return ( return (
False, False,

View File

@ -14,7 +14,7 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
class CPUScaledMMLinearKernel(ScaledMMLinearKernel): class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
@ -23,7 +23,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
return 75 return 75
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cpu(): if not current_platform.is_cpu():
return False, "CPUScaledMM requires running on CPU." return False, "CPUScaledMM requires running on CPU."

View File

@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig
def cutlass_w8a8_scaled_mm( def cutlass_w8a8_scaled_mm(
@ -36,7 +36,7 @@ def cutlass_w8a8_scaled_mm(
def process_weights_after_loading( def process_weights_after_loading(
config: ScaledMMLinearLayerConfig, config: Int8ScaledMMLinearLayerConfig,
layer: torch.nn.Module, layer: torch.nn.Module,
w_q_name: str, w_q_name: str,
w_s_name: str, w_s_name: str,
@ -98,9 +98,6 @@ def process_weights_after_loading(
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
) )
else:
setattr(layer, i_s_name, None)
setattr(layer, i_zp_name, None)
# azp_adj is the AZP adjustment term, used to account for weights. # azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for # It does not depend on scales or azp, so it is the same for
@ -119,8 +116,6 @@ def process_weights_after_loading(
azp_adj_name, azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False), torch.nn.Parameter(azp_adj, requires_grad=False),
) )
else:
setattr(layer, azp_adj_name, None)
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@ -129,7 +124,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return 75 return 75
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
return False, "CutlassScaledMM requires running on CUDA." return False, "CutlassScaledMM requires running on CUDA."
@ -177,7 +172,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel): class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__( def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None: ) -> None:
self.quant_fp8 = QuantFP8( self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme, static=c.is_static_input_scheme,
@ -192,7 +187,7 @@ class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
return 89 return 89
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
return ( return (
False, False,

View File

@ -11,7 +11,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
ScaledMMLinearKernel, ScaledMMLinearKernel,
ScaledMMLinearLayerConfig, Int8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy, ScaledMMLinearQuantStrategy,
) )
@ -32,7 +32,7 @@ def flashinfer_w8a8_scaled_mm(
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel): class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__( def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None: ) -> None:
self.quant_fp8 = QuantFP8( self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme, static=c.is_static_input_scheme,
@ -46,7 +46,7 @@ class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
return 100 return 100
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = ( per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR

View File

@ -13,7 +13,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
ScaledMMLinearKernel, ScaledMMLinearKernel,
ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy, ScaledMMLinearQuantStrategy,
) )
@ -90,7 +90,7 @@ if current_platform.is_rocm():
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel): class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__( def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None: ) -> None:
self.quant_fp8 = QuantFP8( self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme, static=c.is_static_input_scheme,
@ -104,7 +104,7 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
return 90 return 90
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
# TODO: check if this causes an issue on non-ROCM platforms # TODO: check if this causes an issue on non-ROCM platforms
from vllm.platforms.rocm import on_mi3xx from vllm.platforms.rocm import on_mi3xx

View File

@ -12,7 +12,7 @@ from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
ScaledMMLinearKernel, ScaledMMLinearKernel,
ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy, ScaledMMLinearQuantStrategy,
) )
@ -136,7 +136,7 @@ def torch_channelwise_w8a8_scaled_mm(
class TorchScaledMMLinearKernel(ScaledMMLinearKernel): class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__( def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None: ) -> None:
vllm_config = get_current_vllm_config().compilation_config vllm_config = get_current_vllm_config().compilation_config
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
@ -161,7 +161,7 @@ class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
assert c.activation_group_shape is not None assert c.activation_group_shape is not None
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = ( per_tensor_weight_scales = (
@ -218,7 +218,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return 94 return 94
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
assert c.activation_group_shape is not None assert c.activation_group_shape is not None
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
@ -290,7 +290,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return 94 return 94
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
assert c.activation_group_shape is not None assert c.activation_group_shape is not None
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()

View File

@ -7,7 +7,7 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .cutlass import CutlassScaledMMLinearKernel from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@ -16,7 +16,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
return 75 return 75
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if current_platform.is_cpu(): if current_platform.is_cpu():
return ( return (
False, False,

View File

@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
class XLAScaledMMLinearKernel(ScaledMMLinearKernel): class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@ -24,7 +24,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
) )
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_tpu(): if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU." return False, "ScaledMMXLA requires running on TPU."

View File

@ -7,9 +7,9 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
choose_scaled_mm_linear_kernel, choose_scaled_mm_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BasevLLMParameter, BasevLLMParameter,
@ -50,7 +50,7 @@ class QuarkW8A8Int8(QuarkScheme):
): ):
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
is_channelwise=(self.qscheme == "per_channel"), is_channelwise=(self.qscheme == "per_channel"),
is_static_input_scheme=(self.is_static_input_scheme is True), is_static_input_scheme=(self.is_static_input_scheme is True),
input_symmetric=(self.input_symmetric is True), input_symmetric=(self.input_symmetric is True),