From c05027f67a8d8cc645207163a43838ffbf90174a Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 30 Oct 2025 12:27:04 +0000 Subject: [PATCH] clean up; fix quark path Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 26 +- .../schemes/compressed_tensors_w8a8_int8.py | 8 +- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 91 +++++-- .../quantization/kernels/scaled_mm/aiter.py | 13 +- .../quantization/kernels/scaled_mm/cpu.py | 66 +++-- .../quantization/kernels/scaled_mm/cutlass.py | 232 +++++++----------- .../kernels/scaled_mm/flash_infer.py | 60 ++--- .../quantization/kernels/scaled_mm/rocm.py | 64 ++--- .../quantization/kernels/scaled_mm/torch.py | 146 +++-------- .../quantization/kernels/scaled_mm/utils.py | 44 ++++ .../quantization/kernels/scaled_mm/xla.py | 23 +- .../quark/schemes/quark_w8a8_int8.py | 17 +- 12 files changed, 357 insertions(+), 433 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index bd9a6bd0ef043..53e7ed2fb3fcb 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -6,6 +6,7 @@ from collections.abc import Callable import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, @@ -17,6 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( FP8ScaledMMLinearLayerConfig, 
ScaledMMLinearQuantStrategy, + QUANT_STRATEGY_MAP, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -49,8 +51,11 @@ strategy_to_parameter_type = { QuantizationStrategy.TENSOR: PerTensorScaleParameter, } +logger = init_logger(__name__) class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): self.weight_quant = weight_quant self.strategy = weight_quant.strategy @@ -79,19 +84,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - param_name_list = ["weight", "weight_scale", "input_scale"] - layer_mapping_function = lambda layer: ( - tuple(getattr(layer, param_name) for param_name in param_name_list), - param_name_list, - ) - - # TODO: clean up - if self.strategy == QuantizationStrategy.TENSOR: - weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR - elif self.strategy == QuantizationStrategy.CHANNEL: - weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL - + layer_param_names = ["weight", "weight_scale", "input_scale"] + weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy] scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + is_static_input_scheme=self.is_static_input_scheme, weight_quant_strategy=weight_quant_strategy, activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, @@ -101,9 +97,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): _POSSIBLE_FP8_KERNELS, ) self.fp8_linear = kernel( - scaled_mm_linear_kernel_config, layer_mapping_function + scaled_mm_linear_kernel_config, layer_param_names = layer_param_names ) + if kernel.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__) + self._kernel_backends_being_used.add(kernel.__name__) + @classmethod def get_min_capability(cls) -> int: # 
lovelace and up diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 049f96f1faa35..a0ae8655ca650 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -113,15 +113,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] + layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] - layer_mapping_function = lambda layer: ( - tuple(getattr(layer, param_name) for param_name in param_name_list), - param_name_list, - ) self.kernel = kernel_type( c=scaled_mm_linear_kernel_config, - layer_mapping_function = layer_mapping_function + layer_param_names = layer_param_names ) # Checkpoints are serialized in compressed-tensors format, which is diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index e12aa2c5c4d2c..27af30ae131c8 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -5,10 +5,12 @@ from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass from enum import Enum -from typing import Generic, TypeVar +from typing import Generic, Sequence, TypeVar import torch +from compressed_tensors.quantization import QuantizationStrategy from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from 
class ScaledMMLinearQuantStrategy(Enum):
    """Weight-quantization granularity understood by scaled-mm linear kernels."""

    TENSOR = "tensor"
    CHANNEL = "channel"
    BLOCK = "block"


# Maps compressed-tensors QuantizationStrategy values onto the kernel-local
# strategy enum.
# BUG FIX: the patch listed QuantizationStrategy.CHANNEL twice (the second
# entry, pointing at BLOCK, silently overwrote the CHANNEL mapping and
# QuantizationStrategy.BLOCK was never mapped at all). Each source strategy
# must appear exactly once.
QUANT_STRATEGY_MAP = {
    QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
    QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
    QuantizationStrategy.BLOCK: ScaledMMLinearQuantStrategy.BLOCK,
}


@dataclass
class ScaledMMLinearLayerConfig:
    # True when the activation scale is a static value loaded from the
    # checkpoint rather than computed per batch.
    is_static_input_scheme: bool


@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
    is_channelwise: bool
    input_symmetric: bool


@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
    weight_quant_strategy: ScaledMMLinearQuantStrategy
    activation_group_shape: GroupShape
    out_dtype: torch.dtype


# Parameter tuples returned by _get_layer_params(), in the order the kernels
# consume them.
# BUG FIX: in the patch these two aliases were swapped (Int8ParamsT had three
# fields and FP8ParamsT five). The int8 kernels unpack five attributes
# (weight, weight_scale, input_scale, input_zero_point, azp_adj) and the fp8
# kernels three (weight, weight_scale, input_scale) — see the
# _get_layer_params implementations below.
FP8ParamsT = tuple[
    torch.Tensor,  # weight
    torch.Tensor,  # weight_scale
    torch.Tensor | None,  # input_scale
]
Int8ParamsT = tuple[
    torch.Tensor,  # weight
    torch.Tensor,  # weight_scale
    torch.Tensor | None,  # input_scale
    torch.Tensor | None,  # input_zero_point
    torch.Tensor | None,  # azp_adj
]

ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT)
ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig)


class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
    """Abstract base for w8a8 scaled-mm linear kernels.

    Subclasses declare which hardware they support (get_min_capability),
    which configs they can serve (can_implement), and how to massage and
    apply layer weights.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
        raise NotImplementedError

    def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None:
        """
        Args:
            c: kernel configuration; must satisfy can_implement().
            layer_param_names: attribute names on the layer that hold the
                quantized parameters, in the order _get_layer_params()
                unpacks them.
        """
        assert self.can_implement(c)
        self.config = c
        self.layer_param_names = layer_param_names

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        raise NotImplementedError
torch.nn.Module) -> None: @@ -72,4 +86,53 @@ class ScaledMMLinearKernel(Generic[ConfigT], ABC): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + # return a covariant type in the subclass + @abstractmethod + def _get_layer_params(self, layer) -> ParamsT: + raise NotImplementedError + + +class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC): + def __init__( + self, c: ConfigT, layer_param_names: Sequence[str] + ) -> None: + self.quant_fp8 = QuantFP8( + static=c.is_static_input_scheme, + group_shape=c.activation_group_shape, + num_token_padding=self.get_ouput_padding(), + ) + super().__init__(c, layer_param_names) + + @abstractmethod + def get_ouput_padding(self) -> int | None: + raise NotImplementedError + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def _get_layer_params(self, layer) -> FP8ParamsT: + w, w_s, x_s = self.layer_param_names + return ( + getattr(layer, w), + getattr(layer, w_s), + getattr(layer, x_s), + ) + + +class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC): + def _get_layer_params(self, layer) -> Int8ParamsT: + w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names + return ( + getattr(layer, w_q), + getattr(layer, w_s), + getattr(layer, i_s), + getattr(layer, i_zp), + getattr(layer, azp_adj), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index a39e96bca614b..3ac90553bbc74 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -9,8 +9,8 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform 
from vllm.utils.torch_utils import direct_register_custom_op -from .cutlass import process_weights_after_loading -from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig def rocm_aiter_gemm_w8a8_impl( @@ -52,7 +52,7 @@ if current_platform.is_rocm(): ) -class AiterScaledMMLinearKernel(ScaledMMLinearKernel): +class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: return 90 @@ -91,11 +91,6 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel): ) return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - _, param_names = self.layer_mapping_function(layer) - - process_weights_after_loading(self.config, layer, *param_names) - def apply_weights( self, layer: torch.nn.Module, @@ -112,7 +107,7 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel): w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support ATIER block scaled GEMM and mix-precision GEMM. """ - (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index 9c8ece8559b48..b84ef7814f0a2 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -14,10 +14,10 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig -class CPUScaledMMLinearKernel(ScaledMMLinearKernel): +class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: return 75 @@ -30,7 +30,8 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - weight = getattr(layer, self.w_q_name) + w_q_name, _, _, _, _ = self.layer_param_names + weight = getattr(layer, w_q_name) dtype = weight.dtype N, K = weight.size() if ( @@ -48,10 +49,13 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # WEIGHT # Transpose to [K, N] for convenience - weight = getattr(layer, self.w_q_name) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( + self.layer_param_names + ) + weight = getattr(layer, w_q_name) replace_parameter( layer, - self.w_q_name, + w_q_name, torch.nn.Parameter(weight.t().data, requires_grad=False), ) @@ -60,28 +64,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): # If we have a fused module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. 
is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if is_fused_module and not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # INPUT SCALE if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) + input_scale = getattr(layer, i_s_name) if self.config.input_symmetric: replace_parameter( layer, - self.i_s_name, + i_s_name, torch.nn.Parameter(input_scale.max(), requires_grad=False), ) - setattr(layer, self.i_zp_name, None) else: - input_zero_point = getattr(layer, self.i_zp_name) + input_zero_point = getattr(layer, i_zp_name) # reconstruct the ranges int8_traits = torch.iinfo(torch.int8) @@ -91,20 +94,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) ) azp = ( (int8_traits.min - range_min / scale).round().to(dtype=torch.int32) ) replace_parameter( - layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) ) - else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - # Different from cutlass, oneDNN kernels only need the AZP adjustment # term for dynamic quantization. And s_b should be folded into the # term. 
Such as: @@ -112,38 +111,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = # s_a * GEMM_output - s_a * zp_a * adj + bias if not (self.config.input_symmetric and self.config.is_static_input_scheme): - weight = getattr(layer, self.w_q_name) - weight_scale = getattr(layer, self.w_s_name) + weight = getattr(layer, w_q_name) + weight_scale = getattr(layer, w_s_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) azp_adj = azp_adj * weight_scale.squeeze() setattr( layer, - self.azp_adj_name, + azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False), ) - else: - setattr(layer, self.azp_adj_name, None) - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) self.dnnl_handler = ops.create_onednn_scaled_mm( weight, - getattr(layer, self.w_s_name), + getattr(layer, w_s_name), torch.get_default_dtype(), - getattr(layer, self.i_s_name) is None, + getattr(layer, i_s_name) is None, not self.config.input_symmetric, 32, ) # weight is prepacked and maintained by the dnnl_handler, # release the original weight - setattr(layer, self.w_q_name, None) + setattr(layer, w_q_name, None) del weight def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: + w_q_name, w_s_name, _, _, _ = self.layer_param_names # WEIGHT - weight = getattr(layer, self.w_q_name) + weight = getattr(layer, w_q_name) packed_weight = torch.ops._C.convert_weight_packed(weight) replace_parameter( - layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) + layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) ) if layer.bias is not None: @@ -155,19 +153,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): # WEIGHT SCALE # CPU SGL kernels only support per-channel. # For per-tensor quant, convert to the per-channel case. 
- weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - setattr(layer, self.azp_adj_name, None) - def apply_weights( self, layer: torch.nn.Module, @@ -186,7 +180,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. @@ -208,7 +202,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, _, _, _ = self._get_weight_params(layer) + w_q, w_s, _, _, _ = self._get_layer_params(layer) return torch.ops._C.int8_scaled_mm_with_quant( x, w_q, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index b81d670686930..2a8b68980949a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -15,10 +15,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel +from .utils import apply_weights_fp8 - -def cutlass_w8a8_scaled_mm( +def 
cutlass_w8a8_scaled_mm_fp8( *, A: torch.Tensor, B: torch.Tensor, @@ -34,91 +34,7 @@ def cutlass_w8a8_scaled_mm( ) return output.view(*output_shape) - -def process_weights_after_loading( - config: Int8ScaledMMLinearLayerConfig, - layer: torch.nn.Module, - w_q_name: str, - w_s_name: str, - i_s_name: str, - i_zp_name: str, - azp_adj_name: str, -): - # WEIGHT - # Cutlass kernels need transposed weight. - weight = getattr(layer, w_q_name) - replace_parameter( - layer, - w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False), - ) - - # WEIGHT SCALE - # Cutlass kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. - is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, w_s_name) - if is_fused_module and not config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) - replace_parameter( - layer, - w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False), - ) - - # INPUT SCALE - if config.is_static_input_scheme: - input_scale = getattr(layer, i_s_name) - - if config.input_symmetric: - replace_parameter( - layer, - i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False), - ) - setattr(layer, i_zp_name, None) - else: - input_zero_point = getattr(layer, i_zp_name) - - # reconstruct the ranges - int8_traits = torch.iinfo(torch.int8) - azps = input_zero_point.to(dtype=torch.int32) - range_max = (input_scale * (int8_traits.max - azps)).max() - range_min = (input_scale * (int8_traits.min - azps)).min() - - scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) - replace_parameter( - layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) - ) - - # AZP loaded as int8 but used as int32 - azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) - replace_parameter( - layer, i_zp_name, torch.nn.Parameter(azp, 
requires_grad=False) - ) - - - # azp_adj is the AZP adjustment term, used to account for weights. - # It does not depend on scales or azp, so it is the same for - # static and dynamic quantization. - # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md - # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md - if not config.input_symmetric: - weight = getattr(layer, w_q_name) - azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) - if config.is_static_input_scheme: - # cutlass_w8a8 requires azp to be folded into azp_adj - # in the per-tensor case - azp_adj = getattr(layer, i_zp_name) * azp_adj - setattr( - layer, - azp_adj_name, - torch.nn.Parameter(azp_adj, requires_grad=False), - ) - - -class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): +class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: return 75 @@ -131,9 +47,83 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - _, param_names = self.layer_mapping_function(layer) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( + self.layer_param_names + ) + config = self.config + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, w_q_name) + replace_parameter( + layer, + w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False), + ) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. 
+ is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, w_s_name) + if is_fused_module and not config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) + replace_parameter( + layer, + w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) + + # INPUT SCALE + if config.is_static_input_scheme: + input_scale = getattr(layer, i_s_name) + + if config.input_symmetric: + replace_parameter( + layer, + i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False), + ) + setattr(layer, i_zp_name, None) + else: + input_zero_point = getattr(layer, i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) + replace_parameter( + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) + ) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) + replace_parameter( + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + ) + + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. 
+ # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md + # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md + if not config.input_symmetric: + weight = getattr(layer, w_q_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) + if config.is_static_input_scheme: + # cutlass_w8a8 requires azp to be folded into azp_adj + # in the per-tensor case + azp_adj = getattr(layer, i_zp_name) * azp_adj + setattr( + layer, + azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False), + ) - process_weights_after_loading(self.config, layer, *param_names) def apply_weights( self, @@ -141,7 +131,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer) + w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. 
@@ -170,21 +160,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): ) -class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel): - def __init__( - self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable - ) -> None: - self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=GroupShape.PER_TENSOR, - num_token_padding=None, - ) - super().__init__(c, layer_mapping_function) +class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): - @classmethod - def get_min_capability(cls) -> int: - # lovelace and up - return 89 + def get_ouput_padding(self) -> int | None: + return None @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: @@ -197,41 +176,20 @@ class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel): return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass - def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - - return cutlass_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, - ) + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + cutlass_w8a8_scaled_mm_fp8, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype + ) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py index 9fcbb2ff8ec8a..5cb4fa7150d41 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -10,10 +10,11 @@ from vllm.platforms import current_platform from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer from .ScaledMMLinearKernel import ( - ScaledMMLinearKernel, + FP8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) +from .utils import apply_weights_fp8 def flashinfer_w8a8_scaled_mm( @@ -30,16 +31,10 @@ def flashinfer_w8a8_scaled_mm( ) -class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel): - def __init__( - self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable - ) -> None: - self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=GroupShape.PER_TENSOR, - num_token_padding=None, - ) - super().__init__(c, layer_mapping_function) +class 
FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): + + def get_ouput_padding(self) -> int | None: + return None @classmethod def get_min_capability(cls) -> int: @@ -80,41 +75,20 @@ class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel): ) return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass - def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. - (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - - return flashinfer_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, - ) + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + flashinfer_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype + ) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 17b932f2336db..8abe124c4b6f4 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -12,11 +12,11 @@ from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op from .ScaledMMLinearKernel import ( - 
ScaledMMLinearKernel, + FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) - +from .utils import apply_weights_fp8 def rocm_per_tensor_float_w8a8_scaled_mm_impl( A: torch.Tensor, @@ -88,20 +88,9 @@ if current_platform.is_rocm(): ) -class ROCmScaledMMLinearKernel(ScaledMMLinearKernel): - def __init__( - self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable - ) -> None: - self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=GroupShape.PER_TENSOR, - num_token_padding=None, - ) - super().__init__(c, layer_mapping_function) - - @classmethod - def get_min_capability(cls) -> int: - return 90 +class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): + def get_ouput_padding(self) -> int | None: + return None @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: @@ -128,7 +117,7 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel): return ( False, "VLLM_ROCM_USE_SKINNY_GEMM must be enabled " - + "to use ROCmScaledMMLinearKernel ", + + "to use ROCmScaledMMLinearKernel.", ) if not (per_tensor_activation_scales and per_tensor_weight_scales): @@ -139,41 +128,20 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel): ) return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass - def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - - return rocm_per_tensor_float_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + rocm_per_tensor_float_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py index 7d82496dca023..8e5fc66e4fed8 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py @@ -11,11 +11,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from .ScaledMMLinearKernel import ( - ScaledMMLinearKernel, + FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) +from .utils import apply_weights_fp8 # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = None @@ -134,35 +135,16 @@ def torch_channelwise_w8a8_scaled_mm( return output.to(out_dtype).view(*output_shape) -class TorchScaledMMLinearKernel(ScaledMMLinearKernel): - def __init__( - self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable - ) -> None: +class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel): + def get_ouput_padding(self) -> int | None: vllm_config = get_current_vllm_config().compilation_config pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE - output_padding = 17 if pad_output else None - - self.quant_fp8 = QuantFP8( - static=c.is_static_input_scheme, - group_shape=GroupShape.PER_TENSOR, - num_token_padding=output_padding, - ) - super().__init__(c, layer_mapping_function) - - @classmethod - def get_min_capability(cls) -> int: - # lovelace and up - return 89 - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - return - + return output_padding class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - assert c.activation_group_shape is not None per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_weight_scales = ( c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR @@ -182,36 +164,18 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - return torch_per_tensor_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + torch_per_tensor_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype ) - class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -219,14 +183,12 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - assert c.activation_group_shape is not None - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_weight_scales = ( c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR ) - if per_tensor_activation_scales and per_tensor_weight_scales: + if per_tensor_activation_scales or per_tensor_weight_scales: return ( False, "RowWiseTorchScaledMMLinearKernel cannot be used with " @@ -254,33 +216,16 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - return torch_row_wise_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + torch_row_wise_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype ) @@ -291,8 +236,6 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - assert c.activation_group_shape is not None - per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_weight_scales = ( c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR @@ -313,31 +256,14 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ): - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x. - # If static, layer.input_scale is scalar and x_scale is input_scale. 
- (w, w_s, x_s), _ = self.layer_mapping_function(layer) - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - - out_dtype = self.config.out_dtype - out_dtype = x.dtype if out_dtype is None else out_dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = self.quant_fp8( - x_2d, - x_s, - ) - output_shape = [*x_2d_q.shape[:-1], w.shape[1]] - return torch_channelwise_w8a8_scaled_mm( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, + w, w_s, x_s = self._get_layer_params(layer) + return apply_weights_fp8( + torch_channelwise_w8a8_scaled_mm, + self.quant_fp8, + w, + x, + w_s, + x_s, + bias, + self.config.out_dtype ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py new file mode 100644 index 0000000000000..e1d5a291b8463 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -0,0 +1,44 @@ +from collections.abc import Callable +import torch +from vllm.platforms import current_platform + +FP8ScaledMMCallBack = Callable[..., torch.Tensor] +FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]] + +def apply_weights_fp8( + scaled_mm_func: FP8ScaledMMCallBack, + quant_fp8_func: FP8QuantCallback, + w:torch.Tensor, + x:torch.Tensor, + w_s:torch.Tensor, + x_s:torch.Tensor, + bias:torch.Tensor, + maybe_out_dtype: torch.dtype | None, + ) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_s computed from x. + # If static, layer.input_scale is scalar and x_s is input_scale. 
def apply_weights_fp8(
    scaled_mm_func: FP8ScaledMMCallBack,
    quant_fp8_func: FP8QuantCallback,
    w: torch.Tensor,
    x: torch.Tensor,
    w_s: torch.Tensor,
    x_s: torch.Tensor | None,
    bias: torch.Tensor | None,
    maybe_out_dtype: torch.dtype | None,
) -> torch.Tensor:
    """Quantize ``x`` to FP8 if needed, then dispatch ``scaled_mm_func``.

    Shared application path for the per-tensor / row-wise / channel-wise
    scaled-mm kernels.

    Args:
        scaled_mm_func: kernel invoked as
            ``f(A=..., B=..., out_dtype=..., As=..., Bs=..., bias=...,
            output_shape=...)``.
        quant_fp8_func: QuantFP8-style callable returning
            ``(quantized_x, x_scale)``.
        w: FP8 weight; ``w.shape[1]`` is the output-feature dimension.
        x: activation tensor; flattened to 2D over its last dimension.
        w_s: weight scale(s).
        x_s: static input scale, or ``None`` for dynamic quantization
            (the scale is then computed from ``x`` by ``quant_fp8_func``).
        bias: optional bias forwarded to the kernel.
        maybe_out_dtype: output dtype; defaults to ``x.dtype`` when ``None``.
    """
    # ops.scaled_fp8_quant supports both dynamic and static quant.
    # If dynamic, x_s is None and the scale is computed from x.
    # If static, x_s is the layer's scalar input_scale.
    # View input as a 2D matrix for the fp8 kernels.
    x_2d = x.view(-1, x.shape[-1])
    # NOTE(review): output is reshaped back to the caller's leading batch
    # dims (x.shape[:-1]); the pre-refactor per-kernel code used the
    # flattened 2D shape here -- confirm callers expect the original
    # shape restored.
    output_shape = [*x.shape[:-1], w.shape[1]]

    out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype

    # If input is not already FP8, quantize it here.
    # TODO(luka) remove this path if not used anymore
    x_2d_q = x_2d
    if x.dtype != current_platform.fp8_dtype():
        x_2d_q, x_s = quant_fp8_func(x_2d, x_s)

    return scaled_mm_func(
        A=x_2d_q,
        B=w,
        out_dtype=out_dtype,
        As=x_s,
        Bs=w_s,
        bias=bias,
        output_shape=output_shape,
    )
module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) + weight_scale = getattr(layer, w_s_name) if is_fused_module and not self.config.is_channelwise: weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) @@ -60,14 +63,14 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): weight_scale = weight_scale.squeeze(-1) replace_parameter( layer, - self.w_s_name, + w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False), ) # Only support symmetric dynamic activation quantization. - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - setattr(layer, self.azp_adj_name, None) + setattr(layer, i_s_name, None) + setattr(layer, i_zp_name, None) + setattr(layer, azp_adj_name, None) # Filter warning for cond usage in apply_weights. It is okay # to specialize the graph since bias is not dynamic. @@ -88,7 +91,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, _, _, _ = self._get_weight_params(layer) + w_q, w_s, _, _, _ = self._get_layer_params(layer) # Required to register custom ops. 
import torch_xla.experimental.custom_kernel # noqa: F401 diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index 3d51ea2cd9580..856d7fb32c096 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -102,24 +102,27 @@ class QuarkW8A8Int8(QuarkScheme): layer.register_parameter("weight_zero_point", weight_zero_point) # INPUT SCALE + input_zero_point=None + input_scale=None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader ) - layer.register_parameter("input_scale", input_scale) input_zero_point = BasevLLMParameter( data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader ) - layer.register_parameter("input_zero_point", input_zero_point) + + layer.register_parameter("input_scale", input_scale) + layer.register_parameter("input_zero_point", input_zero_point) + if not hasattr(layer, "azp_adj"): + layer.register_parameter("azp_adj", None) + + layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] self.kernel = kernel_type( c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj", + layer_param_names = layer_param_names ) # Checkpoints are serialized in quark format, which is