mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 19:37:05 +08:00
fix types; reduce boilerplate for int8
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
e845035f4c
commit
d92c23b446
@ -11,11 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
_POSSIBLE_INT8_KERNELS,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
init_int8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
@ -51,15 +47,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
|
||||
self.kernel = init_int8_linear_kernel(
|
||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
input_symmetric=self.input_symmetric,
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config,
|
||||
_POSSIBLE_INT8_KERNELS,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -110,18 +101,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
layer_param_names = [
|
||||
"weight",
|
||||
"weight_scale",
|
||||
"input_scale",
|
||||
"input_zero_point",
|
||||
"azp_adj",
|
||||
]
|
||||
|
||||
self.kernel = kernel_type(
|
||||
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||
)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
@ -44,13 +44,13 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
out_dtype: torch.dtype
|
||||
|
||||
|
||||
FP8ParamsT = tuple[
|
||||
_FP8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_scale_ub,
|
||||
]
|
||||
Int8ParamsT = tuple[
|
||||
_Int8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
@ -58,11 +58,11 @@ Int8ParamsT = tuple[
|
||||
torch.Tensor | None, # azp_adj
|
||||
]
|
||||
|
||||
ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT)
|
||||
ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
|
||||
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
||||
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -70,10 +70,10 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
|
||||
def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None:
|
||||
def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
self.layer_param_names = layer_param_names
|
||||
@ -93,12 +93,12 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
||||
|
||||
# return a covariant type in the subclass
|
||||
@abstractmethod
|
||||
def _get_layer_params(self, layer) -> ParamsT:
|
||||
def _get_layer_params(self, layer) -> _ParamsT:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FP8ScaledMMLinearKernel(
|
||||
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC
|
||||
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC
|
||||
):
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
|
||||
@ -122,7 +122,7 @@ class FP8ScaledMMLinearKernel(
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def _get_layer_params(self, layer) -> FP8ParamsT:
|
||||
def _get_layer_params(self, layer) -> _FP8ParamsT:
|
||||
w, w_s, x_s, x_s_ub = self.layer_param_names
|
||||
return (
|
||||
getattr(layer, w),
|
||||
@ -133,9 +133,9 @@ class FP8ScaledMMLinearKernel(
|
||||
|
||||
|
||||
class Int8ScaledMMLinearKernel(
|
||||
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC
|
||||
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
|
||||
):
|
||||
def _get_layer_params(self, layer) -> Int8ParamsT:
|
||||
def _get_layer_params(self, layer) -> _Int8ParamsT:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
|
||||
return (
|
||||
getattr(layer, w_q),
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
@ -13,6 +14,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
CutlassScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
|
||||
@ -26,6 +28,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
@ -42,15 +46,16 @@ from vllm.platforms import PlatformEnum, current_platform
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||
}
|
||||
|
||||
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [CutlassFP8ScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [
|
||||
ROCmScaledMMLinearKernel,
|
||||
PerTensorTorchScaledMMLinearKernel,
|
||||
@ -59,21 +64,25 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
],
|
||||
}
|
||||
|
||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel, covariant=True)
|
||||
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
def choose_scaled_mm_linear_kernel(
|
||||
config: ScaledMMLinearLayerConfig,
|
||||
possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]],
|
||||
module_name: str,
|
||||
config: _KernelConfigT,
|
||||
possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
|
||||
compute_capability: int | None = None,
|
||||
) -> type[ScaledMMLinearKernel]:
|
||||
) -> type[_KernelT]:
|
||||
"""
|
||||
Choose an ScaledMMLinearKernel that can implement the given config for the
|
||||
Choose a _KernelT that can implement the given config for the
|
||||
given compute capability. Attempts to choose the best kernel in terms of
|
||||
performance.
|
||||
|
||||
Args:
|
||||
config (ScaledMMLinearLayerConfig): Description of the linear layer
|
||||
config (_KernelConfigT): Description of the linear layer
|
||||
to be implemented.
|
||||
possible_kernels (dict[PlatformEnum, list[_KernelT]]): A
|
||||
dictionary of platforms and their list list of possible kernels.
|
||||
compute_capability (Optional[int], optional): The compute capability of
|
||||
the target device, if None uses `current_platform` to get the
|
||||
compute capability. Defaults to None.
|
||||
@ -82,7 +91,7 @@ def choose_scaled_mm_linear_kernel(
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
type[ScaledMMLinearKernel]: Chosen kernel.
|
||||
_KernelT: Chosen kernel.
|
||||
"""
|
||||
|
||||
if compute_capability is None:
|
||||
@ -115,9 +124,6 @@ def choose_scaled_mm_linear_kernel(
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if can_implement:
|
||||
logger.info_once(
|
||||
"Selected %s for %s", kernel.__name__, module_name, scope="global"
|
||||
)
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
@ -147,10 +153,51 @@ def init_fp8_linear_kernel(
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config,
|
||||
_POSSIBLE_FP8_KERNELS,
|
||||
module_name=module_name,
|
||||
)
|
||||
|
||||
logger.info_once(
|
||||
"Selected %s for %s",
|
||||
kernel_type.__class__.__name__,
|
||||
module_name,
|
||||
scope="global",
|
||||
)
|
||||
|
||||
return kernel_type(
|
||||
scaled_mm_linear_kernel_config,
|
||||
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
|
||||
)
|
||||
|
||||
|
||||
def init_int8_linear_kernel(
|
||||
is_channelwise: bool,
|
||||
is_static_input_scheme: bool,
|
||||
input_symmetric: bool,
|
||||
module_name: str,
|
||||
) -> Int8ScaledMMLinearKernel:
|
||||
config = Int8ScaledMMLinearLayerConfig(
|
||||
is_channelwise=is_channelwise,
|
||||
is_static_input_scheme=is_static_input_scheme,
|
||||
input_symmetric=input_symmetric,
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
config, _POSSIBLE_INT8_KERNELS,
|
||||
)
|
||||
|
||||
logger.info_once(
|
||||
"Selected %s for %s",
|
||||
kernel_type.__class__.__name__,
|
||||
module_name,
|
||||
scope="global",
|
||||
)
|
||||
|
||||
return kernel_type(
|
||||
config,
|
||||
layer_param_names=[
|
||||
"weight",
|
||||
"weight_scale",
|
||||
"input_scale",
|
||||
"input_zero_point",
|
||||
"azp_adj",
|
||||
],
|
||||
)
|
||||
|
||||
@ -7,11 +7,7 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
_POSSIBLE_INT8_KERNELS,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
init_int8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.parameter import (
|
||||
@ -51,15 +47,10 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
|
||||
self.kernel = init_int8_linear_kernel(
|
||||
is_channelwise=(self.qscheme == "per_channel"),
|
||||
is_static_input_scheme=(self.is_static_input_scheme is True),
|
||||
input_symmetric=(self.input_symmetric is True),
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config,
|
||||
possible_kernels=_POSSIBLE_INT8_KERNELS,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -119,18 +110,6 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
layer_param_names = [
|
||||
"weight",
|
||||
"weight_scale",
|
||||
"input_scale",
|
||||
"input_zero_point",
|
||||
"azp_adj",
|
||||
]
|
||||
|
||||
self.kernel = kernel_type(
|
||||
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||
)
|
||||
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user