From c089ea5753cf5dff4d26fc21f9c729dd3485ef6c Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 30 Oct 2025 14:24:19 +0000 Subject: [PATCH] update quark fp8 path; format Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 27 +++++---- .../schemes/compressed_tensors_w8a8_int8.py | 24 +++++--- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 60 ++++++++++--------- .../kernels/scaled_mm/__init__.py | 1 - .../quantization/kernels/scaled_mm/cpu.py | 9 +-- .../quantization/kernels/scaled_mm/cutlass.py | 24 ++++---- .../kernels/scaled_mm/flash_infer.py | 12 ++-- .../quantization/kernels/scaled_mm/rocm.py | 8 +-- .../quantization/kernels/scaled_mm/torch.py | 15 +++-- .../quantization/kernels/scaled_mm/utils.py | 25 ++++---- .../quantization/kernels/scaled_mm/xla.py | 9 +-- .../quark/schemes/quark_w8a8_fp8.py | 54 +++++++++++++---- .../quark/schemes/quark_w8a8_int8.py | 26 +++++--- 13 files changed, 172 insertions(+), 122 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 53e7ed2fb3fcb..a872ee15c7ae6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -6,8 +6,8 @@ from collections.abc import Callable import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter -from vllm.logger import init_logger +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) @@ -15,10 +15,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( _POSSIBLE_FP8_KERNELS, choose_scaled_mm_linear_kernel, ) -from 
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( - FP8ScaledMMLinearLayerConfig, - ScaledMMLinearQuantStrategy, +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 QUANT_STRATEGY_MAP, + FP8ScaledMMLinearLayerConfig, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -53,6 +52,7 @@ strategy_to_parameter_type = { logger = init_logger(__name__) + class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): _kernel_backends_being_used: set[str] = set() @@ -92,17 +92,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): activation_group_shape=self.act_q_group_shape, out_dtype=self.out_dtype, ) - kernel = choose_scaled_mm_linear_kernel( + kernel_type = choose_scaled_mm_linear_kernel( scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, ) - self.fp8_linear = kernel( - scaled_mm_linear_kernel_config, layer_param_names = layer_param_names - ) - if kernel.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__) - self._kernel_backends_being_used.add(kernel.__name__) + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info( + "Using %s for CompressedTensorsW8A8FP8", kernel_type.__name__ + ) + self._kernel_backends_being_used.add(kernel_type.__name__) + + self.kernel = kernel_type( + scaled_mm_linear_kernel_config, layer_param_names=layer_param_names + ) @classmethod def get_min_capability(cls) -> int: @@ -217,4 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): bias=bias, ) - return self.fp8_linear.apply_weights(layer, x, bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index a0ae8655ca650..e662a1af7f1ff 100644 
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -11,8 +11,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_INT8_KERNELS, choose_scaled_mm_linear_kernel, - _POSSIBLE_INT8_KERNELS +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + Int8ScaledMMLinearLayerConfig, ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -20,7 +23,6 @@ from vllm.model_executor.parameter import ( ModelWeightParameter, PerTensorScaleParameter, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig logger = init_logger(__name__) @@ -58,8 +60,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ) kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config, - _POSSIBLE_INT8_KERNELS + scaled_mm_linear_kernel_config, _POSSIBLE_INT8_KERNELS ) if kernel_type.__name__ not in self._kernel_backends_being_used: @@ -94,8 +95,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE - input_zero_point=None - input_scale=None + input_zero_point = None + input_scale = None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader @@ -113,11 +114,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] + layer_param_names = [ + "weight", + "weight_scale", + "input_scale", + "input_zero_point", + "azp_adj", + 
] self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, - layer_param_names = layer_param_names + c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names ) # Checkpoints are serialized in compressed-tensors format, which is diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 27af30ae131c8..b9acd89f69d82 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -2,15 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import Generic, Sequence, TypeVar +from typing import Generic, TypeVar + import torch from compressed_tensors.quantization import QuantizationStrategy -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape class ScaledMMLinearQuantStrategy(Enum): @@ -18,21 +19,24 @@ class ScaledMMLinearQuantStrategy(Enum): CHANNEL = "channel" BLOCK = "block" + QUANT_STRATEGY_MAP = { QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR, QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL, - QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.BLOCK, } + @dataclass class ScaledMMLinearLayerConfig: is_static_input_scheme: bool + @dataclass class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): is_channelwise: bool input_symmetric: bool + @dataclass class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): weight_quant_strategy: ScaledMMLinearQuantStrategy @@ -40,22 
+44,22 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): out_dtype: torch.dtype - -Int8ParamsT = tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - torch.Tensor | None, # input_scale, -] FP8ParamsT = tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - torch.Tensor | None, # input_scale, - torch.Tensor | None, # input_zp - torch.Tensor | None, # azp_adj - ] + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, +] +Int8ParamsT = tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + torch.Tensor | None, # input_scale, + torch.Tensor | None, # input_zp + torch.Tensor | None, # azp_adj +] + +ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT) +ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig) -ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT) -ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig) class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): @classmethod @@ -68,9 +72,7 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]: raise NotImplementedError - def __init__( - self, c: ConfigT, layer_param_names: Sequence[str] - ) -> None: + def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None: assert self.can_implement(c) self.config = c self.layer_param_names = layer_param_names @@ -87,16 +89,18 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): bias: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError - + # return a covariant type in the subclass @abstractmethod def _get_layer_params(self, layer) -> ParamsT: raise NotImplementedError -class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC): +class FP8ScaledMMLinearKernel( + ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC +): def __init__( - self, c: ConfigT, layer_param_names: Sequence[str] + self, c: 
FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str] ) -> None: self.quant_fp8 = QuantFP8( static=c.is_static_input_scheme, @@ -104,7 +108,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, num_token_padding=self.get_ouput_padding(), ) super().__init__(c, layer_param_names) - + @abstractmethod def get_ouput_padding(self) -> int | None: raise NotImplementedError @@ -113,7 +117,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, def get_min_capability(cls) -> int: # lovelace and up return 89 - + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass @@ -126,7 +130,9 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, ) -class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC): +class Int8ScaledMMLinearKernel( + ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC +): def _get_layer_params(self, layer) -> Int8ParamsT: w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names return ( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 85aaf51ae844c..2ad21162995fe 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer ScaledMMLinearKernel, ScaledMMLinearLayerConfig, ) - from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( ChannelWiseTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index b84ef7814f0a2..7fa47dd854af6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py 
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -14,7 +14,10 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel): @@ -49,9 +52,7 @@ class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel): def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # WEIGHT # Transpose to [K, N] for convenience - w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( - self.layer_param_names - ) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names weight = getattr(layer, w_q_name) replace_parameter( layer, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 2a8b68980949a..28348f50fc273 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -2,22 +2,24 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable - import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel +from .ScaledMMLinearKernel import 
( + FP8ScaledMMLinearKernel, + FP8ScaledMMLinearLayerConfig, + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) from .utils import apply_weights_fp8 + def cutlass_w8a8_scaled_mm_fp8( *, A: torch.Tensor, @@ -34,6 +36,7 @@ def cutlass_w8a8_scaled_mm_fp8( ) return output.view(*output_shape) + class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -47,9 +50,7 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( - self.layer_param_names - ) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names config = self.config # WEIGHT # Cutlass kernels need transposed weight. @@ -105,7 +106,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) ) - # azp_adj is the AZP adjustment term, used to account for weights. # It does not depend on scales or azp, so it is the same for # static and dynamic quantization. 
@@ -124,7 +124,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): torch.nn.Parameter(azp_adj, requires_grad=False), ) - def apply_weights( self, layer: torch.nn.Module, @@ -161,7 +160,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: return None @@ -191,5 +189,5 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): w_s, x_s, bias, - self.config.out_dtype - ) \ No newline at end of file + self.config.out_dtype, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py index 5cb4fa7150d41..8fd2c88857cab 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -1,17 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer from .ScaledMMLinearKernel import ( FP8ScaledMMLinearKernel, - Int8ScaledMMLinearLayerConfig, + FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) from .utils import apply_weights_fp8 @@ -32,7 +29,6 @@ def flashinfer_w8a8_scaled_mm( class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): - def get_ouput_padding(self) -> int | None: return None @@ -41,7 +37,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): return 100 @classmethod - def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) 
-> tuple[bool, str | None]: per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_weight_scales = ( c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR @@ -90,5 +86,5 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): w_s, x_s, bias, - self.config.out_dtype - ) \ No newline at end of file + self.config.out_dtype, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index 8abe124c4b6f4..6144a94b7fb91 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op @@ -18,6 +15,7 @@ from .ScaledMMLinearKernel import ( ) from .utils import apply_weights_fp8 + def rocm_per_tensor_float_w8a8_scaled_mm_impl( A: torch.Tensor, B: torch.Tensor, @@ -40,7 +38,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl( current_platform.get_cu_count(), bias, ) - # Fallabck + # Fallback else: output = torch._scaled_mm( A, @@ -143,5 +141,5 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): w_s, x_s, bias, - self.config.out_dtype + self.config.out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py index 8e5fc66e4fed8..c2a8474ac5b47 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py +++ 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch from packaging import version from vllm.config import CompilationMode, get_current_vllm_config -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from .ScaledMMLinearKernel import ( @@ -15,8 +12,8 @@ from .ScaledMMLinearKernel import ( FP8ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) - from .utils import apply_weights_fp8 + # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = None @@ -142,6 +139,7 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel): output_padding = 17 if pad_output else None return output_padding + class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: @@ -173,9 +171,10 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): w_s, x_s, bias, - self.config.out_dtype + self.config.out_dtype, ) + class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -199,7 +198,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): return ( False, "RowWiseTorchScaledMMLinearKernel is only supported " - + "in ROCm platforms.", + + "on ROCm platforms.", ) if not version.parse(torch.__version__) >= version.parse("2.7"): @@ -225,7 +224,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): w_s, x_s, bias, - self.config.out_dtype + self.config.out_dtype, ) @@ -265,5 +264,5 @@ class 
ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): w_s, x_s, bias, - self.config.out_dtype + self.config.out_dtype, ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py index e1d5a291b8463..9f4e9a7befc45 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py @@ -1,20 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable -import torch + +import torch + from vllm.platforms import current_platform FP8ScaledMMCallBack = Callable[..., torch.Tensor] FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]] + def apply_weights_fp8( - scaled_mm_func: FP8ScaledMMCallBack, - quant_fp8_func: FP8QuantCallback, - w:torch.Tensor, - x:torch.Tensor, - w_s:torch.Tensor, - x_s:torch.Tensor, - bias:torch.Tensor, - maybe_out_dtype: torch.dtype | None, - ) -> torch.Tensor: + scaled_mm_func: FP8ScaledMMCallBack, + quant_fp8_func: FP8QuantCallback, + w: torch.Tensor, + x: torch.Tensor, + w_s: torch.Tensor, + x_s: torch.Tensor, + bias: torch.Tensor, + maybe_out_dtype: torch.dtype | None, +) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_s computed from x. # If static, layer.input_scale is scalar and x_s is input_scale. 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index bafaf06ed7962..02ec0d931bfdd 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -12,7 +12,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ( + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel): @@ -42,9 +45,7 @@ class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT # [out, in] (different than cutlass_scaled_mm) - w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( - self.layer_param_names - ) + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names weight = getattr(layer, w_q_name) replace_parameter( layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 1e5ee93b61f2b..6c296fe9a5807 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,10 +7,18 @@ from typing import Any, cast import torch from torch.nn import Parameter +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_FP8_KERNELS, + choose_scaled_mm_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + FP8ScaledMMLinearLayerConfig, + 
ScaledMMLinearQuantStrategy,
+)
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp,
     normalize_e4m3fn_to_e4m3fnuz,
     requantize_with_max_scale,
 )
@@ -23,8 +31,17 @@ from vllm.platforms import current_platform
 
 __all__ = ["QuarkW8A8Fp8"]
 
+logger = init_logger(__name__)
+
+QUANT_STRATEGY_MAP = {
+    "per_tensor": ScaledMMLinearQuantStrategy.TENSOR,
+    "per_channel": ScaledMMLinearQuantStrategy.CHANNEL,
+}
+
 
 class QuarkW8A8Fp8(QuarkScheme):
+    _kernel_backends_being_used: set[str] = set()
+
     def __init__(
         self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
     ):
@@ -41,10 +58,6 @@ class QuarkW8A8Fp8(QuarkScheme):
         self.act_quant_group_shape = (
             GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
         )
-        self.fp8_linear = Fp8LinearOp(
-            act_quant_static=self.is_static_input_scheme,
-            act_quant_group_shape=self.act_quant_group_shape,
-        )
         self.out_dtype = torch.get_default_dtype()
 
     @classmethod
@@ -163,17 +176,31 @@ class QuarkW8A8Fp8(QuarkScheme):
             input_scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("input_scale", input_scale)
 
+        weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
+        scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
+            is_static_input_scheme=self.is_static_input_scheme,
+            weight_quant_strategy=weight_quant_strategy,
+            activation_group_shape=self.act_quant_group_shape,
+            out_dtype=self.out_dtype,
+        )
+        kernel_type = choose_scaled_mm_linear_kernel(
+            scaled_mm_linear_kernel_config,
+            _POSSIBLE_FP8_KERNELS,
+        )
+
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for QuarkW8A8FP8", kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
+
+        layer_param_names = ["weight", 
"weight_scale", "input_scale"] + self.kernel = kernel_type( + c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names + ) + def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index 856d7fb32c096..2fb69fe5e40e2 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -7,9 +7,12 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_INT8_KERNELS, choose_scaled_mm_linear_kernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + Int8ScaledMMLinearLayerConfig, +) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -56,7 +59,9 @@ class QuarkW8A8Int8(QuarkScheme): input_symmetric=(self.input_symmetric is True), ) - kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, possible_kernels=_POSSIBLE_INT8_KERNELS + ) if kernel_type.__name__ not in self._kernel_backends_being_used: logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) @@ -102,8 +107,8 @@ class QuarkW8A8Int8(QuarkScheme): layer.register_parameter("weight_zero_point", 
weight_zero_point) # INPUT SCALE - input_zero_point=None - input_scale=None + input_zero_point = None + input_scale = None if self.is_static_input_scheme: input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader @@ -117,12 +122,17 @@ class QuarkW8A8Int8(QuarkScheme): layer.register_parameter("input_zero_point", input_zero_point) if not hasattr(layer, "azp_adj"): layer.register_parameter("azp_adj", None) - - layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] + + layer_param_names = [ + "weight", + "weight_scale", + "input_scale", + "input_zero_point", + "azp_adj", + ] self.kernel = kernel_type( - c=scaled_mm_linear_kernel_config, - layer_param_names = layer_param_names + c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names ) # Checkpoints are serialized in quark format, which is