diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 914d0e1bd08a0..652feb1964575 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -11,11 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - _POSSIBLE_INT8_KERNELS, - choose_scaled_mm_linear_kernel, -) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - Int8ScaledMMLinearLayerConfig, + init_int8_linear_kernel, ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -51,15 +47,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig( + self.kernel = init_int8_linear_kernel( is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), is_static_input_scheme=self.is_static_input_scheme, input_symmetric=self.input_symmetric, - ) - - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - _POSSIBLE_INT8_KERNELS, module_name=self.__class__.__name__, ) @@ -110,18 +101,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - layer_param_names = [ - "weight", - "weight_scale", - "input_scale", - "input_zero_point", - "azp_adj", - ] - - self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names - ) - # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. 
Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 9798f88b140a5..329078f0a489c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -44,13 +44,13 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): out_dtype: torch.dtype -FP8ParamsT = tuple[ +_FP8ParamsT = tuple[ torch.Tensor, # weight torch.Tensor, # weight_scale torch.Tensor | None, # input_scale, torch.Tensor | None, # input_scale_ub, ] -Int8ParamsT = tuple[ +_Int8ParamsT = tuple[ torch.Tensor, # weight torch.Tensor, # weight_scale torch.Tensor | None, # input_scale, @@ -58,11 +58,11 @@ Int8ParamsT = tuple[ torch.Tensor | None, # azp_adj ] -ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT) -ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig) +_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT) +_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig) -class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): +class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC): @classmethod @abstractmethod def get_min_capability(cls) -> int: @@ -70,10 +70,10 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): @classmethod @abstractmethod - def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]: + def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]: raise NotImplementedError - def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None: + def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None: assert self.can_implement(c) self.config = c self.layer_param_names = layer_param_names @@ -93,12 +93,12 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): # 
return a covariant type in the subclass @abstractmethod - def _get_layer_params(self, layer) -> ParamsT: + def _get_layer_params(self, layer) -> _ParamsT: raise NotImplementedError class FP8ScaledMMLinearKernel( - ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC + ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC ): def __init__( self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str] @@ -122,7 +122,7 @@ class FP8ScaledMMLinearKernel( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass - def _get_layer_params(self, layer) -> FP8ParamsT: + def _get_layer_params(self, layer) -> _FP8ParamsT: w, w_s, x_s, x_s_ub = self.layer_param_names return ( getattr(layer, w), @@ -133,9 +133,9 @@ class FP8ScaledMMLinearKernel( class Int8ScaledMMLinearKernel( - ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC + ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC ): - def _get_layer_params(self, layer) -> Int8ParamsT: + def _get_layer_params(self, layer) -> _Int8ParamsT: w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names return ( getattr(layer, w_q), diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 3c0ee8323c555..2e00775b90d6e 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from typing import TypeVar import torch @@ -13,6 +14,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( CPUScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassFP8ScaledMMLinearKernel, CutlassScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import 
( @@ -26,6 +28,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, @@ -42,15 +46,16 @@ from vllm.platforms import PlatformEnum, current_platform logger = init_logger(__name__) # in priority/performance order (when available) -_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { +_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } -_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { - PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], +# in priority/performance order (when available) +_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = { + PlatformEnum.CUDA: [CutlassFP8ScaledMMLinearKernel], PlatformEnum.ROCM: [ ROCmScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, @@ -59,21 +64,25 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { ], } +_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel, covariant=True) +_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) + def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, - possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]], - module_name: str, + config: _KernelConfigT, + possible_kernels: dict[PlatformEnum, list[type[_KernelT]]], compute_capability: int | None = None, -) -> type[ScaledMMLinearKernel]: +) -> type[_KernelT]: """ - Choose an ScaledMMLinearKernel that 
can implement the given config for the + Choose a kernel that can implement the given config for the given compute capability. Attempts to choose the best kernel in terms of performance. Args: - config (ScaledMMLinearLayerConfig): Description of the linear layer + config (_KernelConfigT): Description of the linear layer to be implemented. + possible_kernels (dict[PlatformEnum, list[type[_KernelT]]]): A + dictionary of platforms and their list of possible kernels. compute_capability (Optional[int], optional): The compute capability of the target device, if None uses `current_platform` to get the compute capability. Defaults to None. @@ -82,7 +91,7 @@ def choose_scaled_mm_linear_kernel( ValueError: If no kernel can implement the given config. Returns: - type[ScaledMMLinearKernel]: Chosen kernel. + type[_KernelT]: Chosen kernel. """ if compute_capability is None: @@ -115,9 +124,6 @@ def choose_scaled_mm_linear_kernel( can_implement, failure_reason = kernel.can_implement(config) if can_implement: - logger.info_once( - "Selected %s for %s", kernel.__name__, module_name, scope="global" - ) return kernel else: failure_reasons.append( @@ -147,10 +153,51 @@ def choose_scaled_mm_linear_kernel( kernel_type = choose_scaled_mm_linear_kernel( scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, - module_name=module_name, + ) + + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", ) return kernel_type( scaled_mm_linear_kernel_config, layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"], ) + + +def init_int8_linear_kernel( + is_channelwise: bool, + is_static_input_scheme: bool, + input_symmetric: bool, + module_name: str, +) -> Int8ScaledMMLinearKernel: + config = Int8ScaledMMLinearLayerConfig( + is_channelwise=is_channelwise, + is_static_input_scheme=is_static_input_scheme, + input_symmetric=input_symmetric, + ) + + kernel_type = choose_scaled_mm_linear_kernel( + config, _POSSIBLE_INT8_KERNELS, + ) + + 
logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) + + return kernel_type( + config, + layer_param_names=[ + "weight", + "weight_scale", + "input_scale", + "input_zero_point", + "azp_adj", + ], + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index ea8db2456f865..a7a7726bae0e2 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -7,11 +7,7 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - _POSSIBLE_INT8_KERNELS, - choose_scaled_mm_linear_kernel, -) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - Int8ScaledMMLinearLayerConfig, + init_int8_linear_kernel, ) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.parameter import ( @@ -51,15 +47,10 @@ class QuarkW8A8Int8(QuarkScheme): ): layer.logical_widths = output_partition_sizes - scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig( + self.kernel = init_int8_linear_kernel( is_channelwise=(self.qscheme == "per_channel"), is_static_input_scheme=(self.is_static_input_scheme is True), input_symmetric=(self.input_symmetric is True), - ) - - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - possible_kernels=_POSSIBLE_INT8_KERNELS, module_name=self.__class__.__name__, ) @@ -119,18 +110,6 @@ class QuarkW8A8Int8(QuarkScheme): if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - layer_param_names = [ - "weight", - "weight_scale", - "input_scale", - "input_zero_point", - "azp_adj", - ] - - self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names - ) 
- # Checkpoints are serialized in quark format, which is # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: