From 7d361487f7372199a5a8fdf307fd2afd0161f296 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 31 Oct 2025 14:52:51 +0000 Subject: [PATCH] update ptpc path; bug fixes Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 5 ++-- .../model_executor/layers/quantization/fp8.py | 5 ++-- .../kernels/scaled_mm/__init__.py | 7 ++--- .../scaled_mm/{torch.py => pytorch.py} | 0 .../layers/quantization/ptpc_fp8.py | 26 +++++++++++-------- .../quark/schemes/quark_w8a8_fp8.py | 5 ++-- 6 files changed, 28 insertions(+), 20 deletions(-) rename vllm/model_executor/layers/quantization/kernels/scaled_mm/{torch.py => pytorch.py} (100%) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 56fee0523a87b..f4ec97804fd48 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -82,10 +82,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): else: weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy] self.fp8_linear_kernel = init_fp8_linear_kernel( - is_static_input_scheme=self.is_static_input_scheme, + act_q_static=self.is_static_input_scheme, + act_q_group_shape=self.act_q_group_shape, weight_quant_strategy=weight_quant_strategy, - activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, + module_name=self.__class__.__name__ ) @classmethod diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0744f82ed27f3..91988a18fda71 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -394,10 +394,11 @@ class Fp8LinearMethod(LinearMethodBase): ) else: self.fp8_linear_kernel = init_fp8_linear_kernel( - 
is_static_input_scheme=self.act_q_static, + act_q_static=self.act_q_static, + act_q_group_shape=self.act_q_group_shape, weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, - activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, + module_name=self.__class__.__name__ ) def create_weights( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index c4cadecb3af58..629ed790b9666 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( ChannelWiseTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, RowWiseTorchScaledMMLinearKernel, @@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( XLAScaledMMLinearKernel, ) from vllm.platforms import PlatformEnum, current_platform -from vllm.vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape logger = init_logger(__name__) @@ -133,12 +133,13 @@ def choose_scaled_mm_linear_kernel( def init_fp8_linear_kernel( act_q_static: bool, act_q_group_shape: GroupShape, + weight_quant_strategy: ScaledMMLinearQuantStrategy, out_dtype: torch.dtype, module_name: str, ) -> FP8ScaledMMLinearKernel: scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( is_static_input_scheme=act_q_static, - weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + weight_quant_strategy=weight_quant_strategy, activation_group_shape=act_q_group_shape, out_dtype=out_dtype, 
) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py rename to vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 26ba8e5b16bc0..5352ba9c45009 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -16,11 +16,16 @@ from vllm.model_executor.layers.quantization.fp8 import ( Fp8KVCacheMethod, Fp8LinearMethod, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -98,11 +103,15 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): ) super().__init__(quant_config=quant_config) # Force weight quantization - self.quant_config.is_checkpoint_fp8_serialized = False - self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + self.fp8_linear_kernel = init_fp8_linear_kernel( + act_q_static=False, + act_q_group_shape=GroupShape.PER_TOKEN, + weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__ ) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) @@ -127,11 +136,6 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): x: 
torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - input_scale_ub=None, - bias=bias, + return self.fp8_linear_kernel.apply_weights( + layer, x, bias ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index f053b1c438e6c..94a90747e2814 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -174,10 +174,11 @@ class QuarkW8A8Fp8(QuarkScheme): weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme] self.fp8_linear_kernel = init_fp8_linear_kernel( - is_static_input_scheme=self.is_static_input_scheme, + act_q_static=self.is_static_input_scheme, + act_q_group_shape=self.act_quant_group_shape, weight_quant_strategy=weight_quant_strategy, - activation_group_shape=self.act_quant_group_shape, out_dtype=self.out_dtype, + module_name=self.__class__.__name__ ) def apply_weights(