From dd001064c03c5a4dd9e179ea1886e9fb2b17d796 Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Fri, 31 Oct 2025 14:21:49 +0000
Subject: [PATCH] reduce kernel init boilerplate

Signed-off-by: vllmellm
---
 .../schemes/compressed_tensors_w8a8_fp8.py    | 18 ++-----
 .../model_executor/layers/quantization/fp8.py | 16 ++-----
 .../kernels/scaled_mm/__init__.py             | 48 +++++++++++++++++++
 .../quark/schemes/quark_w8a8_fp8.py           | 20 ++------
 4 files changed, 59 insertions(+), 43 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 633a41261ca9d..56fee0523a87b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -12,12 +12,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
-    _POSSIBLE_FP8_KERNELS,
-    choose_scaled_mm_linear_kernel,
+    init_fp8_linear_kernel,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
     QUANT_STRATEGY_MAP,
-    FP8ScaledMMLinearLayerConfig,
 )
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     W8A8BlockFp8LinearOp,
@@ -82,22 +80,14 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                 use_aiter_and_is_supported=self.use_aiter_and_is_supported,
             )
         else:
-            layer_param_names = ["weight", "weight_scale", "input_scale"]
             weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
-            scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
+            self.fp8_linear_kernel = init_fp8_linear_kernel(
                 is_static_input_scheme=self.is_static_input_scheme,
                 weight_quant_strategy=weight_quant_strategy,
                 activation_group_shape=self.act_q_group_shape,
                 out_dtype=self.out_dtype,
+                module_name=self.__class__.__name__,
             )
-            kernel_type = choose_scaled_mm_linear_kernel(
-                scaled_mm_linear_kernel_config,
-                _POSSIBLE_FP8_KERNELS,
-                module_name=self.__class__.__name__,
-            )
-            self.kernel = kernel_type(
-                scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
-            )
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -212,4 +202,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                 bias=bias,
             )
 
-        return self.kernel.apply_weights(layer, x, bias)
+        return self.fp8_linear_kernel.apply_weights(layer, x, bias)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 2bec9bf553c8c..0744f82ed27f3 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -43,8 +43,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 )
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
-    _POSSIBLE_FP8_KERNELS,
-    choose_scaled_mm_linear_kernel,
+    init_fp8_linear_kernel,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa E501
     FP8ScaledMMLinearLayerConfig,
@@ -394,22 +393,13 @@ class Fp8LinearMethod(LinearMethodBase):
                 use_aiter_and_is_supported=self.use_aiter_and_is_supported,
             )
         else:
-            scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
+            self.fp8_linear_kernel = init_fp8_linear_kernel(
                 is_static_input_scheme=self.act_q_static,
                 weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
                 activation_group_shape=self.act_q_group_shape,
                 out_dtype=self.out_dtype,
+                module_name=self.__class__.__name__,
             )
-            kernel_type = choose_scaled_mm_linear_kernel(
-                scaled_mm_linear_kernel_config,
-                _POSSIBLE_FP8_KERNELS,
-                module_name=self.__class__.__name__,
-            )
-
-            self.fp8_linear_kernel = kernel_type(
-                scaled_mm_linear_kernel_config,
-                layer_param_names=["weight", "weight_scale", "input_scale"],
-            )
 
     def create_weights(
         self,
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index 35f9034cacdbd..c4cadecb3af58 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -3,6 +3,8 @@
 
 import os
 
+import torch
+
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
     AiterScaledMMLinearKernel,
@@ -17,8 +19,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
     ROCmScaledMMLinearKernel,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
+    FP8ScaledMMLinearKernel,
+    FP8ScaledMMLinearLayerConfig,
     ScaledMMLinearKernel,
     ScaledMMLinearLayerConfig,
+    ScaledMMLinearQuantStrategy,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
     ChannelWiseTorchScaledMMLinearKernel,
@@ -32,6 +37,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
     XLAScaledMMLinearKernel,
 )
+from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.platforms import PlatformEnum, current_platform
 
 logger = init_logger(__name__)
 
@@ -122,3 +128,45 @@ def choose_scaled_mm_linear_kernel(
         "Failed to find a kernel that can implement the "
         "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons)
     )
+
+
+def init_fp8_linear_kernel(
+    is_static_input_scheme: bool,
+    weight_quant_strategy: ScaledMMLinearQuantStrategy,
+    activation_group_shape: GroupShape,
+    out_dtype: torch.dtype,
+    module_name: str,
+) -> FP8ScaledMMLinearKernel:
+    """Build the FP8 scaled-mm linear kernel for a quantized linear layer.
+
+    Bundles the steps every FP8 linear method repeated: build the layer
+    config, pick a kernel class for it, and instantiate that class.
+
+    Args:
+        is_static_input_scheme: Passed through to the layer config.
+        weight_quant_strategy: Passed through to the layer config.
+        activation_group_shape: Passed through to the layer config.
+        out_dtype: Passed through to the layer config.
+        module_name: Passed to ``choose_scaled_mm_linear_kernel``.
+
+    Returns:
+        The chosen kernel, constructed over the ``weight``,
+        ``weight_scale`` and ``input_scale`` layer parameters.
+    """
+    scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
+        is_static_input_scheme=is_static_input_scheme,
+        weight_quant_strategy=weight_quant_strategy,
+        activation_group_shape=activation_group_shape,
+        out_dtype=out_dtype,
+    )
+
+    kernel_type = choose_scaled_mm_linear_kernel(
+        scaled_mm_linear_kernel_config,
+        _POSSIBLE_FP8_KERNELS,
+        module_name=module_name,
+    )
+
+    return kernel_type(
+        scaled_mm_linear_kernel_config,
+        layer_param_names=["weight", "weight_scale", "input_scale"],
+    )
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
index e8145f261b9bd..f053b1c438e6c 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@@ -9,11 +9,9 @@ from torch.nn import Parameter
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
-    _POSSIBLE_FP8_KERNELS,
-    choose_scaled_mm_linear_kernel,
+    init_fp8_linear_kernel,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
-    FP8ScaledMMLinearLayerConfig,
     ScaledMMLinearQuantStrategy,
 )
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
@@ -174,24 +172,14 @@ class QuarkW8A8Fp8(QuarkScheme):
             input_scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("input_scale", input_scale)
 
-        layer_param_names = ["weight", "weight_scale", "input_scale"]
         weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
-        scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
+        self.fp8_linear_kernel = init_fp8_linear_kernel(
             is_static_input_scheme=self.is_static_input_scheme,
             weight_quant_strategy=weight_quant_strategy,
             activation_group_shape=self.act_quant_group_shape,
             out_dtype=self.out_dtype,
+            module_name=self.__class__.__name__,
         )
-        kernel_type = choose_scaled_mm_linear_kernel(
-            scaled_mm_linear_kernel_config,
-            _POSSIBLE_FP8_KERNELS,
-            module_name=self.__class__.__name__,
-        )
-
-        layer_param_names = ["weight", "weight_scale", "input_scale"]
-        self.kernel = kernel_type(
-            c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
-        )
 
     def apply_weights(
         self,
@@ -199,4 +187,4 @@ class QuarkW8A8Fp8(QuarkScheme):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        return self.kernel.apply_weights(layer, x, bias)
+        return self.fp8_linear_kernel.apply_weights(layer, x, bias)