From e845035f4c6c491914203c018c0ea51a564f780a Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Fri, 31 Oct 2025 16:38:26 +0000
Subject: [PATCH] bug fix

Signed-off-by: vllmellm
---
 .../compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py | 2 ++
 vllm/model_executor/layers/quantization/fbgemm_fp8.py         | 2 +-
 vllm/model_executor/layers/quantization/fp8.py                | 1 +
 .../layers/quantization/kernels/scaled_mm/__init__.py         | 2 +-
 .../layers/quantization/quark/schemes/quark_w8a8_fp8.py       | 2 ++
 5 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 1d0e36a3fc551..58ea30edcd639 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -146,6 +146,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
             layer.register_parameter("input_scale", input_scale)
 
+        layer.register_parameter("input_scale_ub", None)
+
     def process_weights_after_loading(self, layer) -> None:
         if self.strategy == QuantizationStrategy.TENSOR:
             weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index fb16681f03a0c..a7b8e6ddda719 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
     init_fp8_linear_kernel,
 )
-from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
     ScaledMMLinearQuantStrategy,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 484a8d7ab3af7..48697e3849e05 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -451,6 +451,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_loader=weight_loader,
         )
         layer.register_parameter("weight", weight)
+        layer.register_parameter("input_scale_ub", None)
 
         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index 26baba602945d..3c0ee8323c555 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -152,5 +152,5 @@ def init_fp8_linear_kernel(
 
     return kernel_type(
         scaled_mm_linear_kernel_config,
-        layer_param_names=["weight", "weight_scale", "input_scale"],
+        layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
     )
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
index f32c14e27f68f..6fff449000075 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@@ -172,6 +172,8 @@ class QuarkW8A8Fp8(QuarkScheme):
             input_scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("input_scale", input_scale)
 
+        layer.register_parameter("input_scale_ub", None)
+
         weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
         self.fp8_linear_kernel = init_fp8_linear_kernel(
             act_q_static=self.is_static_input_scheme,
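The fix works because torch.nn.Module.register_parameter accepts None: the
name is recorded in the module's _parameters dict, getattr(layer,
"input_scale_ub") then resolves to None instead of raising AttributeError,
and None parameters are skipped by named_parameters() and state_dict().
Below is a minimal sketch of the failure mode this avoids; FakeLayer and
fetch_params are hypothetical stand-ins for the quantized layer and the
kernel's by-name parameter lookup, not vLLM code.

    import torch
    from torch import nn

    class FakeLayer(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_parameter("weight", nn.Parameter(torch.empty(4, 4)))
            self.register_parameter("weight_scale", nn.Parameter(torch.ones(1)))
            self.register_parameter("input_scale", nn.Parameter(torch.ones(1)))
            # The fix: record the name with a None value so by-name lookup
            # succeeds even when no upper-bound scale was ever created.
            self.register_parameter("input_scale_ub", None)

    def fetch_params(layer: nn.Module, names: list[str]) -> list:
        # Hypothetical stand-in for how a kernel resolves its inputs from
        # the layer via layer_param_names.
        return [getattr(layer, name) for name in names]

    layer = FakeLayer()
    params = fetch_params(
        layer, ["weight", "weight_scale", "input_scale", "input_scale_ub"]
    )
    # A registered-but-unset parameter resolves to None; without the
    # register_parameter("input_scale_ub", None) call above, the same
    # lookup raises AttributeError.
    assert params[-1] is None

A kernel handed None can simply branch on it, which is presumably why each
scheme registers the placeholder while init_fp8_linear_kernel only adds
"input_scale_ub" to layer_param_names.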