From 974e6820ceac90e1b70fa3f285ce4441f44c6049 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 28 Oct 2025 16:26:51 +0000 Subject: [PATCH] first try Signed-off-by: vllmellm --- .../schemes/compressed_tensors_w8a8_fp8.py | 47 ++- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 66 ++-- .../kernels/scaled_mm/__init__.py | 26 +- .../quantization/kernels/scaled_mm/aiter.py | 12 +- .../quantization/kernels/scaled_mm/cutlass.py | 256 +++++++++---- .../kernels/scaled_mm/flash_infer.py | 120 ++++++ .../quantization/kernels/scaled_mm/rocm.py | 179 +++++++++ .../quantization/kernels/scaled_mm/torch.py | 343 ++++++++++++++++++ 8 files changed, 924 insertions(+), 125 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ee431c9148b86..c1108e96d2135 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -10,6 +10,14 @@ from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + _POSSIBLE_FP8_KERNELS, + choose_scaled_mm_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( + ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, @@ -24,7 +32,6 @@ from 
vllm.model_executor.layers.quantization.utils.fp8_utils import ( ) from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity, ) @@ -72,9 +79,32 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape, + param_name_list = ["weight", "weight_scale", "input_scale"] + layer_mapping_function = lambda layer: ( + tuple(getattr(layer, param_name) for param_name in param_name_list), + param_name_list, + ) + + # TODO: clean up + if self.strategy == QuantizationStrategy.TENSOR: + weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR + elif self.strategy == QuantizationStrategy.CHANNEL: + weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL + + scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), + is_static_input_scheme=self.is_static_input_scheme, + input_symmetric=True, + weight_quant_strategy=weight_quant_strategy, + activation_group_shape=self.act_q_group_shape, + out_dtype=self.out_dtype, + ) + kernel = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config, + _POSSIBLE_FP8_KERNELS, + ) + self.fp8_linear = kernel( + scaled_mm_linear_kernel_config, layer_mapping_function ) @classmethod @@ -190,11 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 2a885ec899458..0445223526c9e 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -2,16 +2,36 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass +from enum import Enum import torch +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape + + +class ScaledMMLinearQuantStrategy(Enum): + TENSOR = "tensor" + CHANNEL = "channel" + BLOCK = "block" + + def is_per_token(self) -> bool: + return self.row == 1 and self.col == -1 + + def is_per_group(self) -> bool: + return self.row == 1 and self.col >= 1 + @dataclass class ScaledMMLinearLayerConfig: + # TODO: remove is channelwise is_channelwise: bool is_static_input_scheme: bool input_symmetric: bool + out_dtype: torch.dtype | None + weight_quant_strategy: ScaledMMLinearQuantStrategy + activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR class ScaledMMLinearKernel(ABC): @@ -26,21 +46,11 @@ class ScaledMMLinearKernel(ABC): raise NotImplementedError def __init__( - self, - c: ScaledMMLinearLayerConfig, - w_q_param_name: str, - w_s_param_name: str, - i_s_param_name: str, - i_zp_param_name: str, - azp_adj_param_name: str, + self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable ) -> None: assert self.can_implement(c) self.config = c - self.w_q_name = w_q_param_name - self.w_s_name = w_s_param_name - self.i_s_name = i_s_param_name - self.i_zp_name = i_zp_param_name - self.azp_adj_name = azp_adj_param_name + self.layer_mapping_function = layer_mapping_function @abstractmethod def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -55,19 +65,19 @@ class ScaledMMLinearKernel(ABC): ) -> torch.Tensor: raise 
NotImplementedError - def _get_weight_params( - self, layer: torch.nn.Module - ) -> tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - torch.Tensor | None, # input_scale, - torch.Tensor | None, # input_zp - torch.Tensor | None, # azp_adj - ]: - return ( - getattr(layer, self.w_q_name), - getattr(layer, self.w_s_name), - getattr(layer, self.i_s_name), - getattr(layer, self.i_zp_name), - getattr(layer, self.azp_adj_name), - ) + # def _get_weight_params( + # self, layer: torch.nn.Module + # ) -> tuple[ + # torch.Tensor, # weight + # torch.Tensor, # weight_scale + # torch.Tensor | None, # input_scale, + # torch.Tensor | None, # input_zp + # torch.Tensor | None, # azp_adj + # ]: + # return ( + # getattr(layer, self.w_q_name), + # getattr(layer, self.w_s_name), + # getattr(layer, self.i_s_name), + # getattr(layer, self.i_zp_name), + # getattr(layer, self.azp_adj_name), + # ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index dd59e5d935dcb..2ad21162995fe 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -12,10 +12,18 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassScaledMMLinearKernel, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( + ROCmScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 ScaledMMLinearKernel, ScaledMMLinearLayerConfig, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( + ChannelWiseTorchScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( 
TritonScaledMMLinearKernel, ) @@ -25,16 +33,30 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) -_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { +_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } +from .cutlass import CutlassFP8ScaledMMLinearKernel + +_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { + PlatformEnum.CUDA: [CutlassFP8ScaledMMLinearKernel], + PlatformEnum.ROCM: [ + ROCmScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, + ChannelWiseTorchScaledMMLinearKernel, + ], +} + def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, compute_capability: int | None = None + config: ScaledMMLinearLayerConfig, + possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]], + compute_capability: int | None = None, ) -> type[ScaledMMLinearKernel]: """ Choose an ScaledMMLinearKernel that can implement the given config for the @@ -61,7 +81,7 @@ def choose_scaled_mm_linear_kernel( compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] - for kernel in _POSSIBLE_KERNELS[current_platform._enum]: + for kernel in possible_kernels[current_platform._enum]: if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): failure_reasons.append( f" {kernel.__name__} disabled by environment variable" ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index a19396a162bcb..7dc1a57f1ecd3 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -9,8 +9,8 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op -from .cutlass import CutlassScaledMMLinearKernel -from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig +from .cutlass import process_weights_after_loading +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig def rocm_aiter_gemm_w8a8_impl( @@ -52,7 +52,7 @@ if current_platform.is_rocm(): ) -class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): +class AiterScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: return 90 @@ -92,7 +92,9 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer) + _, param_names = self.layer_mapping_function(layer) + + process_weights_after_loading(self.config, layer, *param_names) def apply_weights( self, @@ -110,7 +112,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support ATIER block scaled GEMM and mix-precision GEMM. """ - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index e8769916b4cef..6e88d65acd453 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -2,10 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, ) @@ -14,6 +18,111 @@ from vllm.platforms import current_platform from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +def cutlass_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) + return output.view(*output_shape) + + +def process_weights_after_loading( + config: ScaledMMLinearLayerConfig, + layer: torch.nn.Module, + w_q_name: str, + w_s_name: str, + i_s_name: str, + i_zp_name: str, + azp_adj_name: str, +): + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, w_q_name) + replace_parameter( + layer, + w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False), + ) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. 
+ is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, w_s_name) + if is_fused_module and not config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) + replace_parameter( + layer, + w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) + + # INPUT SCALE + if config.is_static_input_scheme: + input_scale = getattr(layer, i_s_name) + + if config.input_symmetric: + replace_parameter( + layer, + i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False), + ) + setattr(layer, i_zp_name, None) + else: + input_zero_point = getattr(layer, i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) + replace_parameter( + layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False) + ) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) + replace_parameter( + layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + ) + + else: + setattr(layer, i_s_name, None) + setattr(layer, i_zp_name, None) + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. 
+ # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md + # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md + if not config.input_symmetric: + weight = getattr(layer, w_q_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) + if config.is_static_input_scheme: + # cutlass_w8a8 requires azp to be folded into azp_adj + # in the per-tensor case + azp_adj = getattr(layer, i_zp_name) * azp_adj + setattr( + layer, + azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False), + ) + else: + setattr(layer, azp_adj_name, None) + + class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -27,83 +136,9 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # WEIGHT - # Cutlass kernels need transposed weight. - weight = getattr(layer, self.w_q_name) - replace_parameter( - layer, - self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False), - ) + _, param_names = self.layer_mapping_function(layer) - # WEIGHT SCALE - # Cutlass kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. 
- is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) - if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) - replace_parameter( - layer, - self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False), - ) - - # INPUT SCALE - if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) - - if self.config.input_symmetric: - replace_parameter( - layer, - self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False), - ) - setattr(layer, self.i_zp_name, None) - else: - input_zero_point = getattr(layer, self.i_zp_name) - - # reconstruct the ranges - int8_traits = torch.iinfo(torch.int8) - azps = input_zero_point.to(dtype=torch.int32) - range_max = (input_scale * (int8_traits.max - azps)).max() - range_min = (input_scale * (int8_traits.min - azps)).min() - - scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) - replace_parameter( - layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) - ) - - # AZP loaded as int8 but used as int32 - azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) - replace_parameter( - layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) - ) - - else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - - # azp_adj is the AZP adjustment term, used to account for weights. - # It does not depend on scales or azp, so it is the same for - # static and dynamic quantization. 
- # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md - # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md - if not self.config.input_symmetric: - weight = getattr(layer, self.w_q_name) - azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) - if self.config.is_static_input_scheme: - # cutlass_w8a8 requires azp to be folded into azp_adj - # in the per-tensor case - azp_adj = getattr(layer, self.i_zp_name) * azp_adj - setattr( - layer, - self.azp_adj_name, - torch.nn.Parameter(azp_adj, requires_grad=False), - ) - else: - setattr(layer, self.azp_adj_name, None) + process_weights_after_loading(self.config, layer, *param_names) def apply_weights( self, @@ -111,7 +146,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. 
@@ -138,3 +173,70 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): return ops.cutlass_scaled_mm( x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias ) + + +class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel): + def __init__( + self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + ) -> None: + self.quant_fp8 = QuantFP8( + static=c.is_static_input_scheme, + group_shape=GroupShape.PER_TENSOR, + num_token_padding=None, + ) + super().__init__(c, layer_mapping_function) + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + @classmethod + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + if not current_platform.is_cuda(): + return ( + False, + "CutlassFP8ScaledMMLinearKernel is supported " + + "on CUDA platforms Only.", + ) + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ): + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. 
+ (w, w_s, x_s), _ = self.layer_mapping_function(layer) + # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + + out_dtype = self.config.out_dtype + out_dtype = x.dtype if out_dtype is None else out_dtype + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != current_platform.fp8_dtype(): + x_2d_q, x_s = self.quant_fp8( + x_2d, + x_s, + ) + + output_shape = [*x_2d_q.shape[:-1], w.shape[1]] + + return cutlass_w8a8_scaled_mm( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py new file mode 100644 index 0000000000000..9940ef49bb3e0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer + +from .ScaledMMLinearKernel import ( + ScaledMMLinearKernel, + ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, +) + + +def flashinfer_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + output = flashinfer_scaled_fp8_mm( + A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias + ) + return output.view(*output_shape) + + +class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel): + def __init__( + self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + ) -> None: + self.quant_fp8 = QuantFP8(
+ static=c.is_static_input_scheme, + group_shape=GroupShape.PER_TENSOR, + num_token_padding=None, + ) + super().__init__(c, layer_mapping_function) + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + @classmethod + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() + per_tensor_weight_scales = ( + c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + ) + + if not current_platform.is_cuda(): + return ( + False, + "FlashInferScaledMMLinearKernel is supported " + + "on CUDA platforms Only.", + ) + + if not has_flashinfer(): + return ( + False, + "FlashInferScaledMMLinearKernel requires " + + "FlashInfer to be installed.", + ) + if not has_flashinfer(): + return ( + False, + "FlashInferScaledMMLinearKernel requires " + + "FlashInfer to be installed.", + ) + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return ( + False, + "FlashInferScaledMMLinearKernel requires " + + "per tensor activation and weight scales.", + ) + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ): + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. 
+ (w, w_s, x_s), _ = self.layer_mapping_function(layer) + # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + + out_dtype = self.config.out_dtype + out_dtype = x.dtype if out_dtype is None else out_dtype + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != current_platform.fp8_dtype(): + x_2d_q, x_s = self.quant_fp8( + x_2d, + x_s, + ) + + output_shape = [*x_2d_q.shape[:-1], w.shape[1]] + + return flashinfer_w8a8_scaled_mm( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py new file mode 100644 index 0000000000000..74454743fb0db --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op + +from .ScaledMMLinearKernel import ( + ScaledMMLinearKernel, + ScaledMMLinearLayerConfig, + ScaledMMLinearQuantStrategy, +) + + +def rocm_per_tensor_float_w8a8_scaled_mm_impl( + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + if ( + A.shape[0] == 1 + and B.shape[1] % 16 == 0 + and ((bias is None) or (bias.dtype == out_dtype)) + ): + output = ops.wvSplitKQ( + B.t(), + A, + out_dtype, + As, + Bs, + current_platform.get_cu_count(), + bias, + ) + # Fallabck + else: + output = 
torch._scaled_mm( + A, + B, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs, + bias=bias, + ) + return output + + + def rocm_per_tensor_float_w8a8_scaled_mm_fake( + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype) + + + def rocm_per_tensor_float_w8a8_scaled_mm( + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor, + output_shape: list[int], +) -> torch.Tensor: + output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl( + A, B, out_dtype, As, Bs, bias + ) + return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape) + + + if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl, + fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake, + ) + + + class ROCmScaledMMLinearKernel(ScaledMMLinearKernel): + def __init__( + self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable + ) -> None: + self.quant_fp8 = QuantFP8( + static=c.is_static_input_scheme, + group_shape=GroupShape.PER_TENSOR, + num_token_padding=None, + ) + super().__init__(c, layer_mapping_function) + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + # TODO: check if this causes an issue on non-ROCM platforms + from vllm.platforms.rocm import on_mi3xx + + per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() + per_tensor_weight_scales = ( + c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR + ) + + if not current_platform.is_rocm(): + return ( + False, + "ROCmScaledMMLinearFP8Kernel is supported " + "on ROCm platforms Only.", + ) + if not on_mi3xx(): + return ( + False, + "ROCmScaledMMLinearFP8Kernel is 
supported " + + "on MI3xx architures only.", + ) + if not envs.VLLM_ROCM_USE_SKINNY_GEMM: + return ( + False, + "VLLM_ROCM_USE_SKINNY_GEMM must be enabled " + + "to use ROCmScaledMMLinearKernel ", + ) + + if not (per_tensor_activation_scales and per_tensor_weight_scales): + return ( + False, + "ROCmScaledMMLinearKernel requires " + + "per tensor activation and weight scales.", + ) + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ): + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. + (w, w_s, x_s), _ = self.layer_mapping_function(layer) + # View input as 2D matrix for fp8 methods + x_2d = x.view(-1, x.shape[-1]) + + out_dtype = self.config.out_dtype + out_dtype = x.dtype if out_dtype is None else out_dtype + # If input not quantized + # TODO(luka) remove this path if not used anymore + x_2d_q = x_2d + if x.dtype != current_platform.fp8_dtype(): + x_2d_q, x_s = self.quant_fp8( + x_2d, + x_s, + ) + + output_shape = [*x_2d_q.shape[:-1], w.shape[1]] + + return rocm_per_tensor_float_w8a8_scaled_mm( + A=x_2d_q, + B=w, + out_dtype=out_dtype, + As=x_s, + Bs=w_s, + bias=bias, + output_shape=output_shape, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py new file mode 100644 index 0000000000000..0b2c0a8b49fd1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/torch.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch +from packaging import version + +from vllm.config import CompilationMode, 
# NOTE(review): the first import statement below is reconstructed — the patch
# hunk containing its opening was truncated and only the trailing name
# "get_current_vllm_config" is visible in this view. Confirm against the full
# file (Callable and packaging's `version`, used below, are also imported
# above this region).
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform

from .ScaledMMLinearKernel import (
    ScaledMMLinearKernel,
    ScaledMMLinearLayerConfig,
    ScaledMMLinearQuantStrategy,
)

# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5, so a dummy unit scale is passed when the GEMM itself
# must not rescale (the channelwise fallback dequantizes manually after
# the matmul instead).
TORCH_DEVICE_IDENTITY = None


def maybe_create_device_identity():
    """Allocate the dummy ones tensor passed as a scale to torch._scaled_mm."""
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


def _device_identity(device: torch.device) -> torch.Tensor:
    """Return the unit scale tensor, created and moved to ``device`` on demand.

    Fix: the channelwise path previously read TORCH_DEVICE_IDENTITY directly,
    but nothing in this module ever called maybe_create_device_identity()
    (process_weights_after_loading is a no-op here), so scale_a/scale_b would
    have been None on the first call. The tensor is also allocated on the CPU
    and must follow the activation's device before being handed to the GEMM.
    """
    global TORCH_DEVICE_IDENTITY
    maybe_create_device_identity()
    if TORCH_DEVICE_IDENTITY.device != device:
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(device)
    return TORCH_DEVICE_IDENTITY


def _unwrap_scaled_mm(output):
    """Normalize torch._scaled_mm's return value across torch versions.

    torch < 2.5 returns a (result, amax) tuple; torch >= 2.5 returns the
    result tensor directly.
    """
    if isinstance(output, tuple) and len(output) == 2:
        return output[0]
    return output


def torch_per_tensor_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
    num_out_rows: int | None = None,
) -> torch.Tensor:
    """Fused per-tensor-scaled fp8 GEMM: C = As * Bs * (A @ B) + bias.

    Args:
        A: fp8 activation matrix, possibly num-token padded along dim 0.
        B: fp8 weight matrix.
        out_dtype: dtype of the returned tensor.
        As / Bs: per-tensor activation / weight scales.
        bias: optional bias, applied inside the fused GEMM.
        output_shape: shape the (de-padded) result is viewed as.
        num_out_rows: number of valid rows in ``A`` when it was padded;
            defaults to ``A.shape[0]`` (no rows stripped), which preserves the
            previous behavior for existing callers.
    """
    output = torch._scaled_mm(
        A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
    )
    output = _unwrap_scaled_mm(output)
    # Undo num_token_padding before reshaping to the caller-visible shape.
    rows = A.shape[0] if num_out_rows is None else num_out_rows
    return torch.narrow(output, 0, 0, rows).view(*output_shape)


def torch_row_wise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
    num_out_rows: int | None = None,
) -> torch.Tensor:
    """Fused rowwise-scaled fp8 GEMM (per-token As, per-channel Bs).

    Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
    when using it. For now it has only been validated on the ROCm platform:
    fp8 rowwise scaling in torch._scaled_mm was introduced in
    https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt and
    ROCm 6.3, which only exists in torch 2.7 and above.

    For the CUDA platform please validate that torch._scaled_mm supports
    rowwise scaled GEMM before using it.
    """
    # Fused GEMM_DQ rowwise GEMM; Bs arrives as a column vector and the GEMM
    # expects it transposed.
    output = torch._scaled_mm(
        A,
        B,
        out_dtype=out_dtype,
        scale_a=As,
        scale_b=Bs.t(),
        bias=bias,
    )
    output = _unwrap_scaled_mm(output)
    # Undo num_token_padding before reshaping to the caller-visible shape.
    rows = A.shape[0] if num_out_rows is None else num_out_rows
    return torch.narrow(output, 0, 0, rows).view(*output_shape)


def torch_channelwise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
    num_out_rows: int | None = None,
) -> torch.Tensor:
    """Unfused fp8 GEMM + dequant for per-channel weight scales.

    Symmetric quantized GEMM by definition computes:
        C = (s_x * X) (s_w * W) + bias
    which a quantized kernel rewrites as:
        C = s_w * s_x * (X * W) + bias
    torch._scaled_mm does not support a vector s_w, so this fallback runs the
    raw GEMM with unit scales and applies the dequantization afterwards.
    """
    # GEMM with unit scales: computes C = (X * W) in fp32 so that the
    # subsequent dequant ops can happen in-place.
    identity = _device_identity(A.device)
    output = torch._scaled_mm(
        A,
        B,
        scale_a=identity,
        scale_b=identity,
        out_dtype=torch.float32,
    )
    output = _unwrap_scaled_mm(output)
    # Unpad (undo num_token_padding). NOTE(review): when As is a per-tensor
    # scalar the narrow below only works for single-row inputs — presumably
    # this path expects per-token activation scales; confirm with callers.
    rows = A.shape[0] if num_out_rows is None else num_out_rows
    output = torch.narrow(output, 0, 0, rows)
    x_scale = torch.narrow(As, 0, 0, rows)

    # DQ: C = s_w * s_x * (X * W) + bias
    output = output * x_scale * Bs.t()
    if bias is not None:
        output = output + bias
    return output.to(out_dtype).view(*output_shape)


class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
    """Fallback w8a8 linear kernel lowered onto torch._scaled_mm.

    Subclasses select the scaling flavor (per-tensor / rowwise / channelwise)
    and share the input-quantization prologue via ``_prepare_inputs``.
    """

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        compilation_config = get_current_vllm_config().compilation_config
        # Outside of VLLM_COMPILE, pad the token dim to at least 17 —
        # presumably a torch._scaled_mm/cuBLASLt alignment requirement;
        # TODO confirm the constant's origin.
        pad_output = compilation_config.mode < CompilationMode.VLLM_COMPILE

        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=GroupShape.PER_TENSOR,
            num_token_padding=17 if pad_output else None,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        # Lovelace (SM 8.9) and up.
        return 89

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Weights are consumed exactly as loaded; nothing to repack.
        return

    def _prepare_inputs(
        self, layer: torch.nn.Module, x: torch.Tensor
    ) -> tuple[
        torch.Tensor,
        int,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.dtype,
        list,
    ]:
        """Shared apply_weights prologue (previously triplicated verbatim).

        self.quant_fp8 supports both dynamic and static quant: if dynamic,
        layer.input_scale is None and x_scale is computed from x; if static,
        layer.input_scale is a scalar and x_scale is that input_scale.

        Returns (x_q, num_valid_rows, w, w_s, x_s, out_dtype, output_shape).
        Fix: num_valid_rows / output_shape are derived from the *unpadded*
        2D input — QuantFP8 may token-pad x_q, and sizing the output from the
        padded tensor (as the previous code did) leaks padding rows to
        callers.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)
        # View input as a 2D matrix for the fp8 GEMM helpers.
        x_2d = x.view(-1, x.shape[-1])

        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input is not already fp8-quantized, quantize it here.
        # TODO(luka) remove this path if not used anymore
        x_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_q, x_s = self.quant_fp8(x_2d, x_s)

        output_shape = [*x_2d.shape[:-1], w.shape[1]]
        return x_q, x_2d.shape[0], w, w_s, x_s, out_dtype, output_shape


class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """Per-tensor activation and weight scales: single fused _scaled_mm."""

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        assert c.activation_group_shape is not None
        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return (
                False,
                "PerTensorTorchScaledMMLinearKernel requires "
                + "per tensor activation and weight scales.",
            )
        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        x_q, num_rows, w, w_s, x_s, out_dtype, output_shape = self._prepare_inputs(
            layer, x
        )
        return torch_per_tensor_w8a8_scaled_mm(
            A=x_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
            num_out_rows=num_rows,
        )


class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """Rowwise (per-token x, per-channel w) scales; ROCm + torch>=2.7 only."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        assert c.activation_group_shape is not None

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        # Per-tensor-everything is the per-tensor kernel's job.
        if per_tensor_activation_scales and per_tensor_weight_scales:
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel cannot be used with "
                + "per tensor activation and weight scales.",
            )

        if not current_platform.is_rocm():
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel is only supported "
                + "in ROCm platforms.",
            )

        # Rowwise _scaled_mm landed with hipBLASLt/ROCm 6.3 in torch 2.7.
        if version.parse(torch.__version__) < version.parse("2.7"):
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel requires "
                "pytorch version >=2.7.",
            )

        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        x_q, num_rows, w, w_s, x_s, out_dtype, output_shape = self._prepare_inputs(
            layer, x
        )
        return torch_row_wise_w8a8_scaled_mm(
            A=x_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
            num_out_rows=num_rows,
        )


class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """Per-channel weight scales via the unfused GEMM + dequant fallback."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        assert c.activation_group_shape is not None

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        # Per-tensor-everything is the per-tensor kernel's job.
        if per_tensor_activation_scales and per_tensor_weight_scales:
            return (
                False,
                "ChannelWiseTorchScaledMMLinearKernel cannot be used with "
                + "per tensor activation and weight scales.",
            )

        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        x_q, num_rows, w, w_s, x_s, out_dtype, output_shape = self._prepare_inputs(
            layer, x
        )
        return torch_channelwise_w8a8_scaled_mm(
            A=x_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
            num_out_rows=num_rows,
        )