fix types; reduce boilerplate for int8

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-01 09:59:00 +00:00
parent e845035f4c
commit d92c23b446
4 changed files with 77 additions and 72 deletions

View File

@ -11,11 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_INT8_KERNELS,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
Int8ScaledMMLinearLayerConfig,
init_int8_linear_kernel,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
@ -51,15 +47,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
):
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
self.kernel = init_int8_linear_kernel(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric,
)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_INT8_KERNELS,
module_name=self.__class__.__name__,
)
@ -110,18 +101,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
layer_param_names = [
"weight",
"weight_scale",
"input_scale",
"input_zero_point",
"azp_adj",
]
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
)
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

View File

@ -44,13 +44,13 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
out_dtype: torch.dtype
FP8ParamsT = tuple[
_FP8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_scale_ub,
]
Int8ParamsT = tuple[
_Int8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
@ -58,11 +58,11 @@ Int8ParamsT = tuple[
torch.Tensor | None, # azp_adj
]
ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT)
ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig)
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
@ -70,10 +70,10 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
@classmethod
@abstractmethod
def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]:
raise NotImplementedError
def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None:
def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None:
assert self.can_implement(c)
self.config = c
self.layer_param_names = layer_param_names
@ -93,12 +93,12 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
# return a covariant type in the subclass
@abstractmethod
def _get_layer_params(self, layer) -> ParamsT:
def _get_layer_params(self, layer) -> _ParamsT:
raise NotImplementedError
class FP8ScaledMMLinearKernel(
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC
):
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
@ -122,7 +122,7 @@ class FP8ScaledMMLinearKernel(
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def _get_layer_params(self, layer) -> FP8ParamsT:
def _get_layer_params(self, layer) -> _FP8ParamsT:
w, w_s, x_s, x_s_ub = self.layer_param_names
return (
getattr(layer, w),
@ -133,9 +133,9 @@ class FP8ScaledMMLinearKernel(
class Int8ScaledMMLinearKernel(
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
):
def _get_layer_params(self, layer) -> Int8ParamsT:
def _get_layer_params(self, layer) -> _Int8ParamsT:
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
return (
getattr(layer, w_q),

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import TypeVar
import torch
@ -13,6 +14,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel,
CutlassScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
@ -26,6 +28,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
@ -42,15 +46,16 @@ from vllm.platforms import PlatformEnum, current_platform
logger = init_logger(__name__)
# in priority/performance order (when available)
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
# in priority/performance order (when available)
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
PlatformEnum.CUDA: [CutlassFP8ScaledMMLinearKernel],
PlatformEnum.ROCM: [
ROCmScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel,
@ -59,21 +64,25 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
],
}
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel, covariant=True)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig,
possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]],
module_name: str,
config: _KernelConfigT,
possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
compute_capability: int | None = None,
) -> type[ScaledMMLinearKernel]:
) -> type[_KernelT]:
"""
Choose an ScaledMMLinearKernel that can implement the given config for the
Choose a kernel class that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (ScaledMMLinearLayerConfig): Description of the linear layer
config (_KernelConfigT): Description of the linear layer
to be implemented.
possible_kernels (dict[PlatformEnum, list[type[_KernelT]]]): A
dictionary of platforms and their lists of possible kernels.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the
compute capability. Defaults to None.
@ -82,7 +91,7 @@ def choose_scaled_mm_linear_kernel(
ValueError: If no kernel can implement the given config.
Returns:
type[ScaledMMLinearKernel]: Chosen kernel.
_KernelT: Chosen kernel.
"""
if compute_capability is None:
@ -115,9 +124,6 @@ def choose_scaled_mm_linear_kernel(
can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
logger.info_once(
"Selected %s for %s", kernel.__name__, module_name, scope="global"
)
return kernel
else:
failure_reasons.append(
@ -147,10 +153,51 @@ def init_fp8_linear_kernel(
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_FP8_KERNELS,
module_name=module_name,
)
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)
return kernel_type(
scaled_mm_linear_kernel_config,
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
)
def init_int8_linear_kernel(
    is_channelwise: bool,
    is_static_input_scheme: bool,
    input_symmetric: bool,
    module_name: str,
) -> Int8ScaledMMLinearKernel:
    """Select and construct an int8 scaled-mm kernel for a linear layer.

    Args:
        is_channelwise: Whether weight quantization scales are
            per-channel (vs. per-tensor) — forwarded to the kernel config.
        is_static_input_scheme: Whether the input scale is static —
            forwarded to the kernel config.
        input_symmetric: Whether input quantization is symmetric —
            forwarded to the kernel config.
        module_name: Name of the calling module; used only for the
            one-time selection log message.

    Returns:
        An instance of the chosen ``Int8ScaledMMLinearKernel`` subclass,
        picked from ``_POSSIBLE_INT8_KERNELS`` for the current platform.
    """
    config = Int8ScaledMMLinearLayerConfig(
        is_channelwise=is_channelwise,
        is_static_input_scheme=is_static_input_scheme,
        input_symmetric=input_symmetric,
    )
    kernel_type = choose_scaled_mm_linear_kernel(
        config,
        _POSSIBLE_INT8_KERNELS,
    )
    # kernel_type is itself a class, so log its __name__ directly;
    # kernel_type.__class__.__name__ would log the metaclass name
    # (e.g. "ABCMeta"), not the selected kernel.
    logger.info_once(
        "Selected %s for %s",
        kernel_type.__name__,
        module_name,
        scope="global",
    )
    # Parameter names follow the compressed-tensors / quark int8 layout.
    return kernel_type(
        config,
        layer_param_names=[
            "weight",
            "weight_scale",
            "input_scale",
            "input_zero_point",
            "azp_adj",
        ],
    )

View File

@ -7,11 +7,7 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_INT8_KERNELS,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
Int8ScaledMMLinearLayerConfig,
init_int8_linear_kernel,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import (
@ -51,15 +47,10 @@ class QuarkW8A8Int8(QuarkScheme):
):
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
self.kernel = init_int8_linear_kernel(
is_channelwise=(self.qscheme == "per_channel"),
is_static_input_scheme=(self.is_static_input_scheme is True),
input_symmetric=(self.input_symmetric is True),
)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
possible_kernels=_POSSIBLE_INT8_KERNELS,
module_name=self.__class__.__name__,
)
@ -119,18 +110,6 @@ class QuarkW8A8Int8(QuarkScheme):
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
layer_param_names = [
"weight",
"weight_scale",
"input_scale",
"input_zero_point",
"azp_adj",
]
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
)
# Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: