From 31d5c1797f320b2f407c893673330b3a8766ae47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 11 Jul 2025 00:56:28 -0400 Subject: [PATCH] [Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf (#19830) Signed-off-by: Luka Govedic Co-authored-by: mgoin --- .../kernels/bench_per_token_quant_fp8.py | 98 +++++++++++++++++ tests/compile/test_fusion.py | 11 +- tests/compile/test_fusion_attn.py | 2 + tests/compile/test_silu_mul_quant_fusion.py | 37 +++++-- vllm/attention/backends/abstract.py | 6 +- vllm/attention/backends/rocm_flash_attn.py | 6 +- vllm/compilation/fusion.py | 25 +---- vllm/model_executor/layers/fused_moe/utils.py | 2 + .../schemes/compressed_tensors_24.py | 24 ++-- .../schemes/compressed_tensors_w8a8_fp8.py | 8 +- .../layers/quantization/fbgemm_fp8.py | 6 +- .../model_executor/layers/quantization/fp8.py | 14 ++- .../layers/quantization/input_quant_fp8.py | 103 ++++++++++++++++++ .../layers/quantization/modelopt.py | 5 +- .../layers/quantization/ptpc_fp8.py | 8 +- .../quark/schemes/quark_w8a8_fp8.py | 16 ++- .../layers/quantization/utils/quant_utils.py | 35 ++++-- .../layers/quantization/utils/w8a8_utils.py | 66 ++++++----- 18 files changed, 368 insertions(+), 104 deletions(-) create mode 100644 benchmarks/kernels/bench_per_token_quant_fp8.py create mode 100644 vllm/model_executor/layers/quantization/input_quant_fp8.py diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py new file mode 100644 index 0000000000000..923d678f1f2db --- /dev/null +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +from typing import Callable + +import torch + +from vllm import _custom_ops as ops +from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.triton_utils import triton + + +# TODO(luka): use standalone_compile utility +def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int): + def inner(*args): + torch._dynamo.mark_dynamic(args[arg_index], dim_index) + return fn(*args) + + return inner + + +torch._dynamo.config.recompile_limit = 8888 +compilation_config = CompilationConfig(custom_ops=["none"]) +with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)): + torch_per_token_quant_fp8 = torch.compile( + QuantFP8(False, GroupShape.PER_TOKEN), + fullgraph=True, + dynamic=False, # recompile for different shapes + ) + + # First dim is explicitly dynamic to simulate vLLM usage + torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0) + + +def cuda_per_token_quant_fp8( + input: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + return ops.scaled_fp8_quant(input) + + +def calculate_diff(batch_size: int, seq_len: int): + """Calculate difference between Triton and CUDA implementations.""" + device = torch.device("cuda") + x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device) + + torch_out, torch_scale = torch_per_token_quant_fp8(x) + cuda_out, cuda_scale = cuda_per_token_quant_fp8(x) + + if torch.allclose( + cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5 + ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [1, 16, 32, 64, 128] +seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] + +configs = list(itertools.product(batch_size_range, seq_len_range)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=configs, + line_arg="provider", + line_vals=["torch", "cuda"], + line_names=["Torch", "CUDA"], + styles=[("blue", "-"), ("green", "-")], + ylabel="us", + plot_name="per-token-dynamic-quant-fp8-performance", + args={}, + ) +) +def benchmark_quantization(batch_size, seq_len, provider): + dtype = torch.float16 + device = torch.device("cuda") + + x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + fn = lambda: torch_per_token_quant_fp8(x.clone()) + elif provider == "cuda": + fn = lambda: cuda_per_token_quant_fp8(x.clone()) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + calculate_diff(batch_size=4, seq_len=4096) + benchmark_quantization.run(print_data=True) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 040fd176fec12..4a3820e20fd89 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -44,7 +44,9 @@ class TestModel(torch.nn.Module): ] self.fp8_linear = Fp8LinearOp( cutlass_fp8_supported=cutlass_fp8_enabled, - use_per_token_if_dynamic=True) + act_quant_static=static, + act_quant_group_shape=group_shape, + ) def forward(self, x): resid = torch.sqrt(x) @@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, maybe_create_device_identity() # needed for certain non-cutlass fp8 paths vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) - vllm_config.compilation_config.pass_config = \ - PassConfig(enable_fusion=True, enable_noop=True) + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + )) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 37ec753bbc9e4..70750eb9ac4ee 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # DYNAMO_ONCE does not properly propagate shapes. level=CompilationLevel.DYNAMO_AS_IS, backend="tests.compile.test_fusion_attn.backend_unfused", + custom_ops=["+quant_fp8"], ) vllm_config = VllmConfig(compilation_config=compile_config) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) @@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # DYNAMO_ONCE does not properly propagate shapes. level=CompilationLevel.DYNAMO_AS_IS, backend="tests.compile.test_fusion_attn.backend", + custom_ops=["+quant_fp8"], ) vllm_config = VllmConfig(compilation_config=compile_config) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index df36b86abdbe4..5351a3cf35ba5 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -4,33 +4,56 @@ import pytest import torch import vllm.envs as envs -from vllm._custom_ops import scaled_fp8_quant from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe +from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_FP8_SUPPORTED, Fp8LinearOp) +from vllm.platforms import current_platform from .backend import TestBackend class TestModel(torch.nn.Module): - def __init__(self, *args, **kwargs): + def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args, + **kwargs): super().__init__(*args, **kwargs) self.silu_and_mul = SiluAndMul() + self.wscale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32) + self.w = (torch.rand( + hidden_size, + hidden_size).to(dtype=current_platform.fp8_dtype()).t()) + + self.fp8_linear = Fp8LinearOp( + cutlass_fp8_supported=cutlass_fp8_enabled, + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) + def forward(self, x): y = self.silu_and_mul(x) - x2 = scaled_fp8_quant(y, self.scale) + x2 = self.fp8_linear.apply(y, + self.w, + self.wscale, + input_scale=self.wscale) return x2 @pytest.mark.parametrize("num_tokens", [256]) @pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("cutlass_fp8_enabled", + [True, False] if CUTLASS_FP8_SUPPORTED else [False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, + cutlass_fp8_enabled): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) @@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): pass_config=PassConfig(enable_fusion=True, enable_noop=True)) fusion_pass = ActivationQuantFusionPass(config) - backend = TestBackend(fusion_pass) - model = TestModel() + backend = TestBackend(NoOpEliminationPass(config), fusion_pass) + model = TestModel(hidden_size, cutlass_fp8_enabled) # First dimension dynamic - x = torch.rand(num_tokens, hidden_size) + x = torch.rand(num_tokens, hidden_size * 2) torch._dynamo.mark_dynamic(x, 0) result = model(x) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 990ea054f3380..05c098a58a0d2 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,6 +9,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, import torch +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: @@ -289,7 +291,7 @@ class AttentionImpl(ABC, Generic[T]): raise NotImplementedError def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: tuple[int, int]): + group_shape: GroupShape): """ Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization @@ -298,7 +300,7 @@ class AttentionImpl(ABC, Generic[T]): TODO(luka) merge parameters into QuantDescriptor :param dtype: quantized dtype :param static: static or dynamic quantization - :param group_shape: quant group shape. (-1, -1) for per-tensor. + :param group_shape: quant group shape. :return: is fusion supported for this type of quantization """ return False diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 1e2c21f4e69d6..0b7783758dda7 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -19,6 +19,8 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.config import get_current_vllm_config from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform from vllm.platforms.rocm import use_rocm_custom_paged_attention @@ -598,10 +600,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): head_dim)) def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: tuple[int, int]): + group_shape: GroupShape): if self.use_triton_flash_attn: return dtype == current_platform.fp8_dtype( - ) and static and group_shape == (-1, -1) # per-tensor + ) and static and group_shape == GroupShape.PER_TENSOR # Only supported in the Triton backend return False diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 951a2861e3a40..3dec939c28351 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, ClassVar, NamedTuple, Optional +from typing import Callable, NamedTuple, Optional import torch import torch._inductor.pattern_matcher as pm @@ -11,6 +11,8 @@ from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform from .fx_utils import find_getitem_maybe @@ -33,27 +35,6 @@ RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -# Use proxy as NamedTuple direct subclasses cannot have static members -class _GroupShape(NamedTuple): - row: int - col: int - - -class GroupShape(_GroupShape): - """ - This class describes the quantization group shape. - It includes static members for common shapes (per-tensor, per-token). - """ - - # Aliases for common quantization group shapes - PER_TENSOR: ClassVar['GroupShape'] - PER_TOKEN: ClassVar['GroupShape'] - - -GroupShape.PER_TENSOR = GroupShape(-1, -1) -GroupShape.PER_TOKEN = GroupShape(1, -1) - - class QuantKey(NamedTuple): """ Named tuple for identifying the type of quantization. diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 75228d3faf3d9..6638f423a32ef 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -111,6 +111,8 @@ def _fp8_quantize( is provided, the output will be blocked. """ if block_shape is None: + # TODO(luka): use QuantFP8 custom op + # https://github.com/vllm-project/vllm/issues/20711 A, A_scale = ops.scaled_fp8_quant( A, A_scale, use_per_token_if_dynamic=per_act_token) else: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 30ed55aee04f8..168b221a9cfe9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -15,6 +15,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, sparse_cutlass_supported) from vllm.model_executor.parameter import (BasevLLMParameter, @@ -24,6 +27,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter, __all__ = ["CompressedTensors24"] +from vllm.platforms import current_platform + class CompressedTensors24(CompressedTensorsScheme): @@ -45,6 +50,12 @@ class CompressedTensors24(CompressedTensorsScheme): and self.model_compressor.sparsity_config.format == CompressionFormat.sparse_24_bitmask.value) + if quantized and input_quant is not None and \ + self._get_quant_dtype() == current_platform.fp8_dtype(): + static = not input_quant.dynamic + g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN + self.quant_fp8 = QuantFP8(static, g_shape) + @classmethod def get_min_capability(cls) -> int: # Only cutlass 3.x kernels are implemented so far @@ -232,9 +243,7 @@ class CompressedTensors24(CompressedTensorsScheme): :return: The output tensor of the layer """ if self.quantized: - scale = None - if hasattr(layer, "input_scale"): - scale = layer.input_scale + scale = getattr(layer, 'input_scale', None) if self.weights_dtype == torch.int8: ops_output = ops.scaled_int8_quant(x, scale=scale) @@ -242,11 +251,7 @@ class CompressedTensors24(CompressedTensorsScheme): input_scale = ops_output[1] else: assert self.weights_dtype == torch.float8_e4m3fn - if scale is not None: - q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale) - else: - q_input, input_scale = ops.scaled_fp8_quant( - x, use_per_token_if_dynamic=True) + q_input, input_scale = self.quant_fp8(x, scale=scale) else: # Not quantized, nothing to do with the input_scales, use as is @@ -269,7 +274,10 @@ class CompressedTensors24(CompressedTensorsScheme): def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: if not self.quantized: return params_dtype + return self._get_quant_dtype() + def _get_quant_dtype(self) -> torch.dtype: + assert self.quantized assert self.weight_quant is not None assert self.input_quant is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 1e61e058cb84c..d984e89d9e02a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -9,6 +9,8 @@ from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) @@ -26,7 +28,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): self.strategy = strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + self.act_q_group_shape = GroupShape.PER_TENSOR \ + if is_static_input_scheme else GroupShape.PER_TOKEN + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_q_group_shape) @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 3e465ee2cdd21..b2cab7d4614ad 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, @@ -37,7 +37,6 @@ class FBGEMMFp8Config(QuantizationConfig): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = not current_platform.has_device_capability(89) - self.fp8_linear = Fp8LinearOp() @classmethod def get_name(cls) -> QuantizationMethods: @@ -76,7 +75,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN) self.out_dtype = torch.get_default_dtype() def create_weights( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1e98e6c713840..59db3e6c4449b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity, @@ -202,9 +202,17 @@ class Fp8LinearMethod(LinearMethodBase): and current_platform.is_fp8_fnuz()) self.block_quant = self.quant_config.weight_block_size is not None + self.act_q_static = self.quant_config.activation_scheme == "static" + # Use per-token quantization for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN + else: + self.act_q_group_shape = GroupShape.PER_TENSOR + self.fp8_linear = Fp8LinearOp( - # Default to using per_token quantization if cutlass is supported - use_per_token_if_dynamic=cutlass_fp8_supported()) + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape, + cutlass_fp8_supported=cutlass_fp8_supported()) def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py new file mode 100644 index 0000000000000..e1a9bdde9334d --- /dev/null +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.platforms import current_platform + +# Using the default value (240.0) from pytorch will cause accuracy +# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm. +_FP8_DTYPE = current_platform.fp8_dtype() +_FP8_FINFO = torch.finfo(_FP8_DTYPE) +_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max +_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min +_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) + + +@CustomOp.register("quant_fp8") +class QuantFP8(CustomOp): + """ + Quantize input tensor to per-tensor or per-token FP8. + This CustomOp supports both static and dynamic quantization. + """ + + def __init__(self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None): + """ + + :param static: static or dynamic quantization + :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR) + :param num_token_padding: Pad the token dimension of output to this size + """ + super().__init__() + self.num_token_padding = num_token_padding + assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} + assert not static or group_shape == GroupShape.PER_TENSOR, \ + "Only per-tensor scales supported for static quantization." + self.static = static + self.group_shape = group_shape + self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN + + def forward_cuda( + self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_ub: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert (scale is not None) == self.static + assert scale_ub is None or (not self.static and self.group_shape + == GroupShape.PER_TOKEN + and scale_ub.numel() == 1) + + return ops.scaled_fp8_quant( + x, + scale, + num_token_padding=self.num_token_padding, + scale_ub=scale_ub, + use_per_token_if_dynamic=self.use_per_token_if_dynamic) + + def forward_native( + self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_ub: Optional[torch.Tensor] = None, + ): + assert (scale is not None) == self.static + assert scale_ub is None or (not self.static and self.group_shape + == GroupShape.PER_TOKEN + and scale_ub.numel() == 1) + + if scale is None: + if self.group_shape == GroupShape.PER_TOKEN: + x_max, _ = x.abs().max(dim=-1) + x_max = x_max.unsqueeze(-1).to(torch.float32) + if scale_ub is not None: + x_max = x_max.clamp(max=scale_ub) + else: + x_max = x.abs().max().unsqueeze(-1).to(torch.float32) + + scale = x_max / _FP8_MAX + scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR) + + # Even for dynamic per-token scales, + # reciprocal performs slightly better than division + out = x.to(torch.float32) * scale.reciprocal() + out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + + # This currently generates an extra Triton kernel in compilation. + # Fortunately, we don't use padding if compiling. + # TODO(luka): benchmark torch._scaled_mm to hopefully remove padding + # in general. + if self.num_token_padding is not None: + padding = max(self.num_token_padding - out.size(0), 0) + out = F.pad(out, (0, 0, 0, padding), "constant", 0.0) + + return out, scale diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 2295c0e5fe9ff..0a4e36f19bf88 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, requantize_with_max_scale) from vllm.model_executor.parameter import (ModelWeightParameter, @@ -102,7 +102,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp() + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 32ba1055f9c83..d11cba2caba88 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, Fp8KVCacheMethod, Fp8LinearMethod) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform @@ -95,8 +95,10 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): super().__init__(quant_config=quant_config) # Force weight quantization self.quant_config.is_checkpoint_fp8_serialized = False - self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False, - use_per_token_if_dynamic=True) + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, + cutlass_fp8_supported=False, + act_quant_group_shape=GroupShape.PER_TOKEN) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index c7bc98184d0eb..2cb35249f49ef 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,6 +7,8 @@ import torch from torch.nn import Parameter from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, @@ -28,10 +30,14 @@ class QuarkW8A8Fp8(QuarkScheme): self.is_static_input_scheme = not cast( bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) - self.use_per_token_if_dynamic = (not self.is_static_input_scheme \ - and self.input_qscheme == "per_channel") + + per_token = (not self.is_static_input_scheme + and self.input_qscheme == "per_channel") + self.act_quant_group_shape = GroupShape.PER_TOKEN \ + if per_token else GroupShape.PER_TENSOR self.fp8_linear = Fp8LinearOp( - use_per_token_if_dynamic=self.use_per_token_if_dynamic) + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_quant_group_shape) self.out_dtype = torch.get_default_dtype() @classmethod @@ -44,7 +50,7 @@ class QuarkW8A8Fp8(QuarkScheme): # tensor scales (thus N scales being passed to the kernel), # requantize so we can always run per tensor if self.weight_qscheme == "per_tensor": - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): input_scale = getattr(layer, 'input_scale', None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, @@ -82,7 +88,7 @@ class QuarkW8A8Fp8(QuarkScheme): requires_grad=False) else: weight_scale = layer.weight_scale.data - if self.use_per_token_if_dynamic: + if self.act_quant_group_shape == GroupShape.PER_TOKEN: weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d6b96774b4e8b..54361a2323c28 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -3,7 +3,7 @@ """This file is used for /tests and /benchmarks""" from collections.abc import Mapping from types import MappingProxyType -from typing import Optional +from typing import ClassVar, NamedTuple, Optional import numpy import torch @@ -12,13 +12,30 @@ from vllm.model_executor.layers.quantization.qqq import ( MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.scalar_type import ScalarType, scalar_types -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# Use proxy as NamedTuple direct subclasses cannot have static members +class _GroupShape(NamedTuple): + row: int + col: int + + +class GroupShape(_GroupShape): + """ + This class describes the quantization group shape. + It includes static members for common shapes (per-tensor, per-token). + """ + + # Aliases for common quantization group shapes + PER_TENSOR: ClassVar['GroupShape'] + PER_TOKEN: ClassVar['GroupShape'] + + +GroupShape.PER_TENSOR = GroupShape(-1, -1) +GroupShape.PER_TOKEN = GroupShape(1, -1) # Normalize the group_shape to the full extent for any dims that are -1 -def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int, - int]): +def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], group_shape[1] if group_shape[1] > 0 else x.shape[-1]) @@ -58,7 +75,7 @@ def group_broadcast(t, shape): # (i.e. per-token-per-group) def scaled_quantize( x: torch.Tensor, - group_shape: tuple[int, int], + group_shape: GroupShape, quant_dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: group_shape = _normalize_quant_group_shape(x, group_shape) @@ -99,7 +116,7 @@ def scaled_quantize( def scaled_dequantize( x_q: torch.Tensor, x_s: torch.Tensor, - group_shape: Optional[tuple[int, int]] = None, + group_shape: Optional[GroupShape] = None, out_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: if group_shape is not None: @@ -332,6 +349,10 @@ def quantize_weights(w: torch.Tensor, ) +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, group_size: int, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index adc67aa64952d..47bb45793281c 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -8,6 +8,9 @@ import torch from vllm import _custom_ops as ops from vllm import envs from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform # Input scaling factors are no longer optional in _scaled_mm starting @@ -271,20 +274,21 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def dispatch_w8a8_scaled_mm( cutlass_fp8_supported: bool, per_tensor_weights: bool, - per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool] -) -> Callable[..., torch.Tensor]: + per_tensor_activations: bool) -> Callable[..., torch.Tensor]: + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: return cutlass_w8a8_scaled_mm if per_tensor_weights and per_tensor_activations: if current_platform.is_rocm(): return rocm_per_tensor_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm - # torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token - if (use_per_token_if_dynamic and not per_tensor_weights - and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + # If torch.scaled_mm supports per-channel (weights) per-token (inputs) + if not per_tensor_weights and not per_tensor_activations \ + and USE_ROWWISE_TORCH_SCALED_MM: return torch_per_token_w8a8_scaled_mm + # Normally, torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token return torch_channelwise_w8a8_scaled_mm @@ -299,11 +303,11 @@ class Fp8LinearOp: """ def __init__(self, + act_quant_static: bool, cutlass_fp8_supported: bool = cutlass_fp8_supported(), - use_per_token_if_dynamic: bool = False, + act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, pad_output: Optional[bool] = None): self.cutlass_fp8_supported = cutlass_fp8_supported - self.use_per_token_if_dynamic = use_per_token_if_dynamic # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. @@ -312,9 +316,16 @@ class Fp8LinearOp: # as it breaks with dynamic shapes. if pad_output is None: config = get_current_vllm_config().compilation_config - pad_output = config.level < CompilationLevel.PIECEWISE - self.output_padding = 17 if ( - pad_output and not current_platform.is_rocm()) else None + pad_output = config.level < CompilationLevel.PIECEWISE and \ + not cutlass_fp8_supported and \ + not current_platform.is_rocm() + + self.output_padding = 17 if pad_output else None + self.act_quant_static = act_quant_static + self.act_quant_group_shape = act_quant_group_shape + self.quant_fp8 = QuantFP8(static=act_quant_static, + group_shape=act_quant_group_shape, + num_token_padding=self.output_padding) def apply( self, @@ -325,8 +336,6 @@ class Fp8LinearOp: input_scale: Optional[torch.Tensor] = None, input_scale_ub: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, - # TODO(luka) remove this parameter in favor of __init__ - use_per_token_if_dynamic: Optional[bool] = None ) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. @@ -336,40 +345,27 @@ class Fp8LinearOp: input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[1]] - # TODO(luka) this is here because currently MLA only decides this - # during the forward method instead of in __init__. - if use_per_token_if_dynamic is None: - use_per_token_if_dynamic = self.use_per_token_if_dynamic - if out_dtype is None: out_dtype = input.dtype - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if self.cutlass_fp8_supported: - assert input.dtype != current_platform.fp8_dtype( - ), "FP8 input to cutlass is not currently implemented" - qinput, x_scale = ops.scaled_fp8_quant( + # If input not quantized + # TODO(luka) remove this path if not used anymore + if input.dtype != current_platform.fp8_dtype(): + qinput, x_scale = self.quant_fp8( input_2d, input_scale, - scale_ub=input_scale_ub, - use_per_token_if_dynamic=use_per_token_if_dynamic) + input_scale_ub, + ) else: - if input.dtype != current_platform.fp8_dtype(): - # Maybe apply padding to output, see comment in __init__ - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, - input_scale, - num_token_padding=self.output_padding, - use_per_token_if_dynamic=use_per_token_if_dynamic) - else: - qinput, x_scale = input_2d, input_scale + qinput, x_scale = input_2d, input_scale per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) + # TODO(luka) do this dispatch during init (after ScaledMM refactor) w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( self.cutlass_fp8_supported, per_tensor_weights, - per_tensor_activations, use_per_token_if_dynamic) + per_tensor_activations) return w8a8_scaled_mm_func(qinput=qinput, weight=weight,