From 7fb465744c12b0d10372cbb1e514606aebbbcc88 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 7 Nov 2025 07:17:41 +0000 Subject: [PATCH] implement apply func in base FP8ScaledMMLinearKernel class Signed-off-by: vllmellm --- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 55 ++++++++++++++-- .../quantization/kernels/scaled_mm/cutlass.py | 29 +++------ .../kernels/scaled_mm/flashinfer.py | 37 ++++------- .../quantization/kernels/scaled_mm/pytorch.py | 63 +++---------------- .../quantization/kernels/scaled_mm/rocm.py | 29 +++------ .../quantization/kernels/scaled_mm/utils.py | 49 --------------- 6 files changed, 83 insertions(+), 179 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index a8a2fc245f62d..5baa7f73077aa 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Generic, TypeVar @@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, ) +from vllm.platforms import current_platform @dataclass @@ -98,12 +99,9 @@ class FP8ScaledMMLinearKernel( group_shape=act_scale_descriptor.group_shape, num_token_padding=self.get_ouput_padding(), ) + self.fp8_dtype = current_platform.fp8_dtype() super().__init__(c, layer_param_names) - @abstractmethod - def get_ouput_padding(self) -> int | None: - raise NotImplementedError - @classmethod def get_min_capability(cls) -> int: # lovelace and up @@ -121,6 +119,53 @@ class FP8ScaledMMLinearKernel( getattr(layer, x_s_ub), ) + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + scaled_mm_func = self.get_scaled_mm_func() + quant_fp8 = self.quant_fp8 + fp8_dtype = self.fp8_dtype + maybe_out_dtype = self.config.out_dtype + w, w_s, x_s, x_s_ub = self._get_layer_params(layer) + + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_s computed from x. + # If static, layer.input_scale is scalar and x_s is input_scale. + # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + output_shape = [*x.shape[:-1], w.shape[1]] + out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype + + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != fp8_dtype: + x_2d_q, x_s = quant_fp8( + x_2d, + x_s, + x_s_ub, + ) + return scaled_mm_func( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) + + @abstractmethod + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def get_ouput_padding(self) -> int | None: + raise NotImplementedError + class Int8ScaledMMLinearKernel( ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index fc8893cb7e1b0..dbed970785568 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch from vllm import _custom_ops as ops @@ -17,7 +19,6 @@ from .ScaledMMLinearKernel import ( Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, ) -from .utils import apply_weights_fp8 def cutlass_w8a8_scaled_mm_fp8( @@ -160,9 +161,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: - return None - @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_cuda(): @@ -174,21 +172,8 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - cutlass_w8a8_scaled_mm_fp8, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return cutlass_w8a8_scaled_mm_fp8 + + def get_ouput_padding(self) -> int | None: + return None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py index 3bac71950dda2..e816f5d2c156e 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flashinfer.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch from vllm.platforms import current_platform @@ -10,7 +12,6 @@ from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ) -from .utils import apply_weights_fp8 def flashinfer_w8a8_scaled_mm( @@ -29,13 +30,6 @@ def flashinfer_w8a8_scaled_mm( class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: - return None - - @classmethod - def get_min_capability(cls) -> int: - return 100 - @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: per_tensor_activation_scales = ( @@ -71,21 +65,12 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): ) return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - flashinfer_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + @classmethod + def get_min_capability(cls) -> int: + return 100 + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return flashinfer_w8a8_scaled_mm + + def get_ouput_padding(self) -> int | None: + return None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index 10293c445a347..b7aed6105d10c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch from packaging import version @@ -11,7 +13,6 @@ from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ) -from .utils import apply_weights_fp8 # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale @@ -155,24 +156,8 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): ) return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - torch_per_tensor_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return torch_per_tensor_w8a8_scaled_mm class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @@ -209,24 +194,8 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - torch_row_wise_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return torch_row_wise_w8a8_scaled_mm class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @@ -245,21 +214,5 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - torch_channelwise_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return torch_channelwise_w8a8_scaled_mm diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 26463a19c6f48..852e0088d0d97 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch import vllm.envs as envs @@ -12,7 +14,6 @@ from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, FP8ScaledMMLinearLayerConfig, ) -from .utils import apply_weights_fp8 def rocm_per_tensor_float_w8a8_scaled_mm_impl( @@ -86,9 +87,6 @@ direct_register_custom_op( class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: - return None - @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: # TODO: check if this causes an issue on non-ROCM platforms @@ -125,21 +123,8 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): ) return True, None - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ): - w, w_s, x_s, x_s_ub = self._get_layer_params(layer) - return apply_weights_fp8( - rocm_per_tensor_float_w8a8_scaled_mm, - self.quant_fp8, - w, - x, - w_s, - x_s, - bias, - x_s_ub, - self.config.out_dtype, - ) + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return rocm_per_tensor_float_w8a8_scaled_mm + + def get_ouput_padding(self) -> int | None: + return None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py deleted file mode 100644 index e5ab5ad4d47cf..0000000000000 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable - -import torch - -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.platforms import current_platform - - -def apply_weights_fp8( - scaled_mm_func: Callable[..., torch.Tensor], - quant_fp8_func: QuantFP8, - w: torch.Tensor, - x: torch.Tensor, - w_s: torch.Tensor, - x_s: torch.Tensor, - bias: torch.Tensor, - x_s_ub: torch.Tensor | None, - maybe_out_dtype: torch.dtype | None = None, -) -> torch.Tensor: - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_s computed from x. - # If static, layer.input_scale is scalar and x_s is input_scale. - # View input as 2D matrix for fp8 methods - x_2d = x.view(-1, x.shape[-1]) - output_shape = [*x.shape[:-1], w.shape[1]] - - out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype - - # If input not quantized - # TODO(luka) remove this path if not used anymore - x_2d_q = x_2d - if x.dtype != current_platform.fp8_dtype(): - x_2d_q, x_s = quant_fp8_func( - x_2d, - x_s, - x_s_ub, - ) - - return scaled_mm_func( - A=x_2d_q, - B=w, - out_dtype=out_dtype, - As=x_s, - Bs=w_s, - bias=bias, - output_shape=output_shape, - )