fix int8 path

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
vllmellm 2025-10-30 08:04:24 +00:00
parent 974e6820ce
commit e54e572085
13 changed files with 67 additions and 72 deletions

View File

@@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
ScaledMMLinearLayerConfig,
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -91,10 +91,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.CHANNEL:
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=True,
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
weight_quant_strategy=weight_quant_strategy,
activation_group_shape=self.act_q_group_shape,
out_dtype=self.out_dtype,

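For reference, a minimal self-contained sketch of the strategy mapping this hunk introduces: the scheme translates its compressed-tensors QuantizationStrategy into a ScaledMMLinearQuantStrategy before building the FP8 config. The classes below are simplified stand-ins that mirror the diff, not imports from vLLM.

from dataclasses import dataclass
from enum import Enum, auto

import torch


class QuantizationStrategy(Enum):
    # Stand-in for the compressed-tensors strategy enum.
    TENSOR = auto()
    CHANNEL = auto()


class ScaledMMLinearQuantStrategy(Enum):
    # Mirrors the strategy enum referenced in this hunk.
    TENSOR = auto()
    CHANNEL = auto()


@dataclass
class FP8ScaledMMLinearLayerConfig:
    # Simplified: the real config in this commit also carries
    # activation_group_shape.
    weight_quant_strategy: ScaledMMLinearQuantStrategy
    out_dtype: torch.dtype


def build_fp8_config(strategy: QuantizationStrategy,
                     out_dtype: torch.dtype) -> FP8ScaledMMLinearLayerConfig:
    # CHANNEL maps to channelwise weight scales, everything else to per-tensor.
    if strategy == QuantizationStrategy.CHANNEL:
        weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
    else:
        weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
    return FP8ScaledMMLinearLayerConfig(
        weight_quant_strategy=weight_quant_strategy, out_dtype=out_dtype)


cfg = build_fp8_config(QuantizationStrategy.CHANNEL, torch.bfloat16)
assert cfg.weight_quant_strategy is ScaledMMLinearQuantStrategy.CHANNEL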
View File

@@ -11,8 +11,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
choose_scaled_mm_linear_kernel,
_POSSIBLE_INT8_KERNELS
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
logger = init_logger(__name__)
@@ -50,13 +51,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
):
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric,
)
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_INT8_KERNELS
)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
@@ -90,12 +94,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
input_zero_point=None
input_scale=None
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
@@ -103,15 +107,21 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
)
layer.register_parameter("input_zero_point", input_zero_point)
layer.register_parameter("input_zero_point", input_zero_point)
layer.register_parameter("input_scale", input_scale)
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
layer_mapping_function = lambda layer: (
tuple(getattr(layer, param_name) for param_name in param_name_list),
param_name_list,
)
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj",
layer_mapping_function = layer_mapping_function
)
# Checkpoints are serialized in compressed-tensors format, which is

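The constructor call above replaces the five fixed *_param_name arguments with a single layer_mapping_function. A hypothetical sketch of that pattern, with illustrative names and a bare nn.Module standing in for the real layer:

from collections.abc import Callable

import torch


def make_layer_mapping(param_names: list[str]) -> Callable:
    # Mirrors the lambda in this hunk: given a layer, return the parameter
    # tensors in a fixed order together with their names.
    def mapping(layer: torch.nn.Module):
        return (tuple(getattr(layer, name, None) for name in param_names),
                param_names)
    return mapping


layer = torch.nn.Module()
layer.weight = torch.nn.Parameter(
    torch.zeros(4, 4, dtype=torch.int8), requires_grad=False)
layer.weight_scale = torch.nn.Parameter(torch.ones(4), requires_grad=False)

mapping = make_layer_mapping(
    ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"])
params, names = mapping(layer)
w_q, w_s, i_s, i_zp, azp_adj = params  # missing attributes come back as None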
View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from typing import Generic, TypeVar
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
@@ -25,16 +25,25 @@ class ScaledMMLinearQuantStrategy(Enum):
@dataclass
class ScaledMMLinearLayerConfig:
# TODO: remove is channelwise
pass
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
is_channelwise: bool
is_static_input_scheme: bool
input_symmetric: bool
out_dtype: torch.dtype | None
@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_strategy: ScaledMMLinearQuantStrategy
activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR
activation_group_shape: GroupShape
out_dtype: torch.dtype
class ScaledMMLinearKernel(ABC):
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[ConfigT], ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
@@ -42,11 +51,11 @@ class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
raise NotImplementedError
def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
self, c: ConfigT, layer_mapping_function: Callable
) -> None:
assert self.can_implement(c)
self.config = c
@@ -63,21 +72,4 @@ class ScaledMMLinearKernel(ABC):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
# def _get_weight_params(
# self, layer: torch.nn.Module
# ) -> tuple[
# torch.Tensor, # weight
# torch.Tensor, # weight_scale
# torch.Tensor | None, # input_scale,
# torch.Tensor | None, # input_zp
# torch.Tensor | None, # azp_adj
# ]:
# return (
# getattr(layer, self.w_q_name),
# getattr(layer, self.w_s_name),
# getattr(layer, self.i_s_name),
# getattr(layer, self.i_zp_name),
# getattr(layer, self.azp_adj_name),
# )
raise NotImplementedError

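A compact sketch of the Generic[ConfigT] pattern adopted here, using toy classes rather than vLLM's: the TypeVar is bound to the base config, so each subclass's can_implement and __init__ are type-checked against its concrete Int8/FP8 config.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass
class BaseConfig:
    pass


@dataclass
class Int8Config(BaseConfig):
    is_static_input_scheme: bool
    input_symmetric: bool


ConfigT = TypeVar("ConfigT", bound=BaseConfig)


class Kernel(Generic[ConfigT], ABC):
    def __init__(self, c: ConfigT) -> None:
        ok, reason = self.can_implement(c)
        assert ok, reason
        self.config = c

    @classmethod
    @abstractmethod
    def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
        raise NotImplementedError


class ToyInt8Kernel(Kernel[Int8Config]):
    @classmethod
    def can_implement(cls, c: Int8Config) -> tuple[bool, str | None]:
        if not c.input_symmetric:
            return False, "this toy kernel only supports symmetric input"
        return True, None


kernel = ToyInt8Kernel(Int8Config(is_static_input_scheme=True,
                                  input_symmetric=True))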
View File

@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
ChannelWiseTorchScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel,

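As used in CompressedTensorsW8A8Int8 above, choose_scaled_mm_linear_kernel now receives the candidate list (_POSSIBLE_INT8_KERNELS) explicitly. The dispatcher below is only a guess at that selection logic, sketched for illustration; the real implementation is not shown in this diff.

# Assumed dispatch logic, for illustration only: walk the ordered candidate
# kernel classes and return the first one whose can_implement() accepts the
# config; otherwise report every rejection reason.
def choose_kernel(config, candidate_kernels):
    failure_reasons = []
    for kernel_cls in candidate_kernels:
        ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        failure_reasons.append(f"  {kernel_cls.__name__}: {reason}")
    raise ValueError("No compatible ScaledMM kernel for this config:\n"
                     + "\n".join(failure_reasons))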
View File

@@ -10,7 +10,7 @@ from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import process_weights_after_loading
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl(
@@ -58,7 +58,7 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
return 90
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return (
False,

View File

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
@@ -23,7 +23,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
return 75
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "CPUScaledMM requires running on CPU."

View File

@@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig
def cutlass_w8a8_scaled_mm(
@@ -36,7 +36,7 @@ def cutlass_w8a8_scaled_mm(
def process_weights_after_loading(
config: ScaledMMLinearLayerConfig,
config: Int8ScaledMMLinearLayerConfig,
layer: torch.nn.Module,
w_q_name: str,
w_s_name: str,
@@ -98,9 +98,6 @@ def process_weights_after_loading(
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, i_s_name, None)
setattr(layer, i_zp_name, None)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
@@ -119,8 +116,6 @@ def process_weights_after_loading(
azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, azp_adj_name, None)
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@@ -129,7 +124,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return 75
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "CutlassScaledMM requires running on CUDA."
@@ -177,7 +172,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None:
self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme,
@@ -192,7 +187,7 @@ class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
return 89
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return (
False,

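The azp_adj comment in this file notes that the adjustment term depends only on the weights. A short worked example of why, in plain PyTorch with illustrative names: with an asymmetric zero point azp, (x_q - azp) @ W equals x_q @ W minus azp times the per-column weight sums, so those sums can be precomputed once at load time.

import torch

torch.manual_seed(0)
x_q = torch.randint(0, 255, (2, 8)).float()      # asymmetric uint8 activations
azp = 128.0                                       # activation zero point
w_q = torch.randint(-128, 127, (8, 4)).float()    # int8 weights

azp_adj = w_q.sum(dim=0)                  # depends only on the weights
direct = (x_q - azp) @ w_q                # subtract the zero point up front
folded = x_q @ w_q - azp * azp_adj        # or fold it out via azp_adj

assert torch.allclose(direct, folded)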
View File

@@ -11,7 +11,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
@@ -32,7 +32,7 @@ def flashinfer_w8a8_scaled_mm(
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None:
self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme,
@@ -46,7 +46,7 @@ class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
return 100
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR

View File

@@ -13,7 +13,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
@@ -90,7 +90,7 @@ if current_platform.is_rocm():
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None:
self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme,
@@ -104,7 +104,7 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
return 90
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
# TODO: check if this causes an issue on non-ROCM platforms
from vllm.platforms.rocm import on_mi3xx

View File

@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
@@ -136,7 +136,7 @@ def torch_channelwise_w8a8_scaled_mm(
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__(
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None:
vllm_config = get_current_vllm_config().compilation_config
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
@@ -161,7 +161,7 @@ class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
assert c.activation_group_shape is not None
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
@@ -218,7 +218,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return 94
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
assert c.activation_group_shape is not None
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
@@ -290,7 +290,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return 94
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
assert c.activation_group_shape is not None
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()

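The Torch kernels above dispatch on whether scales are per-tensor, per-token (row-wise), or per-channel. As a reference for what the channelwise case computes, here is the plain dequantization math with illustrative shapes, not the fused op itself:

import torch

torch.manual_seed(0)
M, K, N = 4, 16, 8
x_q = torch.randint(-128, 127, (M, K)).float()   # quantized activations
w_q = torch.randint(-128, 127, (K, N)).float()   # quantized weights
x_scale = torch.rand(M, 1) * 0.01                # one scale per token (row)
w_scale = torch.rand(1, N) * 0.01                # one scale per output channel

# Scaled matmul: quantized matmul, then broadcast both scale vectors.
out = (x_q @ w_q) * x_scale * w_scale

# Equivalent: dequantize the operands first, then do a float matmul.
ref = (x_q * x_scale) @ (w_q * w_scale)
assert torch.allclose(out, ref, atol=1e-5)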
View File

@@ -7,7 +7,7 @@ import torch
from vllm.platforms import current_platform
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@@ -16,7 +16,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
return 75
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if current_platform.is_cpu():
return (
False,

View File

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@@ -24,7 +24,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
)
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU."

View File

@@ -7,9 +7,9 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import (
BasevLLMParameter,
@@ -50,7 +50,7 @@ class QuarkW8A8Int8(QuarkScheme):
):
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
is_channelwise=(self.qscheme == "per_channel"),
is_static_input_scheme=(self.is_static_input_scheme is True),
input_symmetric=(self.input_symmetric is True),