From 423e2a625e5fbbf2e35029d092649a147a72f5af Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 31 Oct 2025 14:07:09 +0000 Subject: [PATCH] reduce logging boilerplate; update fp8 path Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 10 +----- .../schemes/compressed_tensors_w8a8_int8.py | 10 ++---- .../model_executor/layers/quantization/fp8.py | 36 ++++++++++++------- .../kernels/scaled_mm/__init__.py | 7 ++++ .../quantization/kernels/scaled_mm/utils.py | 4 +-- .../quark/schemes/quark_w8a8_fp8.py | 7 +--- .../quark/schemes/quark_w8a8_int8.py | 10 ++---- 7 files changed, 41 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index a872ee15c7ae6..633a41261ca9d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -54,8 +54,6 @@ logger = init_logger(__name__) class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): - _kernel_backends_being_used: set[str] = set() - def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): self.weight_quant = weight_quant self.strategy = weight_quant.strategy @@ -95,14 +93,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): kernel_type = choose_scaled_mm_linear_kernel( scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, + module_name=self.__class__.__name__, ) - - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info( - "Using %s for CompressedTensorsW8A8FP8", kernel_type.__name__ - ) - self._kernel_backends_being_used.add(kernel_type.__name__) - self.kernel = kernel_type( scaled_mm_linear_kernel_config, layer_param_names=layer_param_names ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index e662a1af7f1ff..914d0e1bd08a0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -28,8 +28,6 @@ logger = init_logger(__name__) class CompressedTensorsW8A8Int8(CompressedTensorsScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool ): @@ -60,13 +58,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ) kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, _POSSIBLE_INT8_KERNELS + scaled_mm_linear_kernel_config, + _POSSIBLE_INT8_KERNELS, + module_name=self.__class__.__name__, ) - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - # WEIGHT weight = ModelWeightParameter( data=torch.empty( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f82eccb88ce09..2bec9bf553c8c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -42,6 +42,14 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_FP8_KERNELS, + choose_scaled_mm_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501 + FP8ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -77,7 +85,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, @@ -387,9 +394,21 @@ class Fp8LinearMethod(LinearMethodBase): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape, + scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( + is_static_input_scheme=self.act_q_static, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + activation_group_shape=self.act_q_group_shape, + out_dtype=self.out_dtype, + ) + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, + _POSSIBLE_FP8_KERNELS, + module_name=self.__class__.__name__, + ) + + self.fp8_linear_kernel = kernel_type( + scaled_mm_linear_kernel_config, + layer_param_names=["weight", "weight_scale", "input_scale"], ) def create_weights( @@ -674,14 +693,7 @@ class Fp8LinearMethod(LinearMethodBase): bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear_kernel.apply_weights(layer, x, bias) class Fp8MoEMethod(FusedMoEMethodBase): diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 2ad21162995fe..35f9034cacdbd 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -3,6 +3,7 @@ import os +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel, ) @@ -32,6 +33,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( ) from vllm.platforms import PlatformEnum, current_platform +logger = init_logger(__name__) + # in priority/performance order (when available) _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], @@ -54,6 +57,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { def choose_scaled_mm_linear_kernel( config: ScaledMMLinearLayerConfig, possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]], + module_name: str, compute_capability: int | None = None, ) -> type[ScaledMMLinearKernel]: """ @@ -105,6 +109,9 @@ def choose_scaled_mm_linear_kernel( can_implement, failure_reason = kernel.can_implement(config) if can_implement: + logger.info_once( + "Selected %s for %s", kernel.__name__, module_name, scope="global" + ) return kernel else: failure_reasons.append( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py index 9f4e9a7befc45..ca1a2c5b4f29b 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -4,15 +4,15 @@ from collections.abc import Callable import torch +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.platforms import current_platform FP8ScaledMMCallBack = Callable[..., torch.Tensor] -FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]] def apply_weights_fp8( scaled_mm_func: FP8ScaledMMCallBack, - quant_fp8_func: FP8QuantCallback, + quant_fp8_func: QuantFP8, w: torch.Tensor, x: torch.Tensor, w_s: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 6c296fe9a5807..e8145f261b9bd 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -40,8 +40,6 @@ QUANT_STRATEGY_MAP = { class QuarkW8A8Fp8(QuarkScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, weight_config: dict[str, Any], input_config: dict[str, Any] | None ): @@ -187,12 +185,9 @@ class QuarkW8A8Fp8(QuarkScheme): kernel_type = choose_scaled_mm_linear_kernel( scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, + module_name=self.__class__.__name__, ) - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for QuarkW8A8FP8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - layer_param_names = ["weight", "weight_scale", "input_scale"] self.kernel = kernel_type( c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index 2fb69fe5e40e2..ea8db2456f865 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -25,8 +25,6 @@ logger = init_logger(__name__) class QuarkW8A8Int8(QuarkScheme): - _kernel_backends_being_used: set[str] = set() - def __init__( self, qscheme: str, @@ -60,13 +58,11 @@ class QuarkW8A8Int8(QuarkScheme): ) kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, possible_kernels=_POSSIBLE_INT8_KERNELS + scaled_mm_linear_kernel_config, + possible_kernels=_POSSIBLE_INT8_KERNELS, + module_name=self.__class__.__name__, ) - if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) - self._kernel_backends_being_used.add(kernel_type.__name__) - # WEIGHT weight = ModelWeightParameter( data=torch.empty(