From 502640c3f944f8b9702989bd45f8add970b67ea0 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 2 Oct 2025 21:35:13 +0200 Subject: [PATCH] [Perf] Fix and reapply move apply w8a8 block fp8 linear to class (#25696) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ElizaWszola Signed-off-by: ElizaWszola Signed-off-by: Luka Govedič Signed-off-by: Luka Govedič Co-authored-by: Luka Govedič Co-authored-by: Michael Goin Co-authored-by: Luka Govedič --- .../cutlass_benchmarks/w8a8_benchmarks.py | 4 +- .../benchmark_fp8_block_dense_gemm.py | 4 +- tests/kernels/quantization/test_block_fp8.py | 5 +- .../quantization/test_fp8_quant_group.py | 30 +- .../model_executor/test_enabled_custom_ops.py | 30 -- tests/quantization/test_compressed_tensors.py | 35 ++ vllm/config/vllm.py | 17 + .../compressed_tensors/compressed_tensors.py | 8 + .../schemes/compressed_tensors_w8a8_fp8.py | 37 +- .../model_executor/layers/quantization/fp8.py | 40 +- .../layers/quantization/input_quant_fp8.py | 24 +- .../layers/quantization/utils/fp8_utils.py | 357 ++++++++++++------ vllm/utils/deep_gemm.py | 21 +- 13 files changed, 412 insertions(+), 200 deletions(-) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index a5a5b52f6039..02f8c593392c 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -17,7 +17,7 @@ from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.utils import FlexibleArgumentParser, cdiv @@ -158,7 +158,7 @@ def bench_fp8( "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) ), - "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm( a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) ), "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index db2398fc40a4..2010b8038563 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -9,7 +9,7 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.triton_utils import triton from vllm.utils.deep_gemm import ( @@ -63,7 +63,7 @@ def benchmark_shape(m: int, # === vLLM Triton Implementation === def vllm_triton_gemm(): - return w8a8_block_fp8_matmul(A_vllm, + return w8a8_triton_block_scaled_mm(A_vllm, B_vllm, A_scale_vllm, B_scale_vllm, diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 211d1ecfe6e4..e02df540ce9d 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, native_w8a8_block_matmul) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, 
per_token_group_quant_fp8, w8a8_block_fp8_matmul) + cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import (fp8_gemm_nt, @@ -91,7 +91,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 720eee62760d..8f2bc6e3cee5 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -20,9 +20,11 @@ from vllm.platforms import current_platform (8, 513, 64), # Non-divisible (native only) ]) @pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, - group_size: int, seed: int) -> None: + group_size: int, seed: int, + use_ue8m0: bool) -> None: """Test QuantFP8 group quantization with various configurations. Tests both CUDA and native implementations, column-major scales, @@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False) + column_major_scales=False, + use_ue8m0=use_ue8m0) # 1. Test native implementation (always available) x_quant_native, scales_native = quant_op.forward_native(x.clone()) @@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, # 2. Test column-major scales configuration quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True) + column_major_scales=True, + use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x.clone()) - assert scales_col.shape == (expected_num_groups, batch_size) + assert scales_col.shape == (batch_size, expected_num_groups) + assert scales_col.stride(0) == 1 + assert scales_col.stride(1) == batch_size + + # Test column-major scales consistency + assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) # 3. 
Test CUDA implementation (only for divisible dimensions) if is_divisible: @@ -68,21 +77,23 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, @pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() -def test_quantfp8_group_multidimensional(seed: int) -> None: +def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None: current_platform.seed_everything(seed) group_size = 64 # Test with 3D input - batch1, batch2, hidden_dim = 4, 8, 512 + batch1, batch2, hidden_dim = 4, 8, 1024 x_3d = torch.randn( (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False) + column_major_scales=False, + use_ue8m0=use_ue8m0) x_quant, scales = quant_op.forward_native(x_3d.clone()) assert x_quant.shape == x_3d.shape @@ -91,9 +102,10 @@ def test_quantfp8_group_multidimensional(seed: int) -> None: # Test column_major_scales with multi-dim quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True) + column_major_scales=True, + use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x_3d.clone()) - assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) + assert scales_col.shape == (batch1, batch2, hidden_dim // group_size) # Test with 4D input batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256 diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 92ce10a9efc0..200b6ecd5852 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -17,8 +17,6 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.layernorm import (RMSNorm, dispatch_rocm_rmsnorm_func, fused_add_rms_norm, rms_norm) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] @@ -111,34 +109,6 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() -@pytest.mark.skipif( - not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), - reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") -@pytest.mark.parametrize("use_cutlass", [True, False]) -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) -def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, - use_rocm_aiter_gemm_w8a8_blockscale: str, - monkeypatch): - - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", - use_rocm_aiter_gemm_w8a8_blockscale) - - use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( - int(use_rocm_aiter_gemm_w8a8_blockscale))) - block_scale_func = dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) - if use_cutlass: - assert block_scale_func == cutlass_scaled_mm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_gemm_w8a8_blockscale): - assert block_scale_func == ( - torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) - else: - assert block_scale_func == w8a8_block_fp8_matmul - - @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) def 
test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index c0ab3fbb1062..af8c7ec3b482 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -18,6 +18,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp) from vllm.model_executor.layers.quantization.utils.quant_utils import ( cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -742,3 +745,35 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, perplexity = llm.generate_prompt_perplexity([prompt])[0] print(perplexity) assert perplexity <= exp_perplexity + + +def test_compressed_tensors_fp8_block_enabled(vllm_runner): + model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK" + with vllm_runner(model_path) as llm: + + fp8_dtype = current_platform.fp8_dtype() + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) + assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear, + W8A8BlockFp8LinearOp) + + assert qkv_proj.weight.dtype is fp8_dtype + assert qkv_proj.weight_scale.dtype is torch.float32 + assert len(qkv_proj.weight.shape) == 2 + assert len(qkv_proj.weight_scale.shape) == 2 + + input_quant_op = \ + qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op + assert isinstance(input_quant_op, QuantFP8) + assert input_quant_op._forward_method == input_quant_op.forward_cuda + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 107df502e08e..ac40b0fd4783 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -516,6 +516,23 @@ class VllmConfig: " by VLLM_DEBUG_DUMP_PATH to %s", env_path) self.compilation_config.debug_dump_path = env_path + def has_blocked_weights(): + if self.quant_config is not None: + if hasattr(self.quant_config, "weight_block_size"): + return self.quant_config.weight_block_size is not None + elif hasattr(self.quant_config, "has_blocked_weights"): + return self.quant_config.has_blocked_weights() + return False + + # Enable quant_fp8 CUDA ops (TODO disable in follow up) + # On H100 the CUDA kernel is faster than + # native implementation + # https://github.com/vllm-project/vllm/issues/25094 + if has_blocked_weights(): + custom_ops = self.compilation_config.custom_ops + if "none" not in custom_ops and "-quant_fp8" not in custom_ops: + custom_ops.append("+quant_fp8") + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d6550dd16892..3f771ea2abd1 100644 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -644,6 +644,14 @@ class CompressedTensorsConfig(QuantizationConfig): # If no matches, return None return None + def has_blocked_weights(self) -> bool: + for scheme in self.target_scheme_map.values(): + weight_quant = scheme.get("weights") + if (weight_quant is not None + and weight_quant.strategy == QuantizationStrategy.BLOCK): + return True + return False + @staticmethod def supports_cutlass_24( weight_quant: Optional[QuantizationArgs], diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 5ad1b15b7ed5..4755c17c5967 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -11,7 +11,7 @@ from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_fp8_block_linear, check_aiter_fp8_linear_support, + W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy, @@ -41,16 +41,30 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.act_q_group_shape = GroupShape.PER_TENSOR \ - if is_static_input_scheme else GroupShape.PER_TOKEN - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape) self.weight_block_size = self.weight_quant.block_structure + if self.weight_block_size is not None: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) + else: + self.act_q_group_shape = GroupShape.PER_TENSOR \ + if is_static_input_scheme else GroupShape.PER_TOKEN + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + if self.weight_block_size is not None: + assert not self.is_static_input_scheme + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_q_group_shape) + @classmethod def get_min_capability(cls) -> int: # lovelace and up @@ -142,13 +156,14 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if layer.weight_block_size is not None: - return apply_fp8_block_linear( - layer, + if self.weight_block_size is not None: + return self.w8a8_block_fp8_linear.apply( input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, bias=bias, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - 
use_aiter_and_is_supported=self.use_aiter_and_is_supported) + ) return self.fp8_linear.apply(input=x, weight=layer.weight, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3ebb20de9996..9b7b3f18baa7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_fp8_block_linear, check_aiter_fp8_linear_support, + W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, expert_weight_is_col_major, maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, @@ -242,15 +242,28 @@ class Fp8LinearMethod(LinearMethodBase): self.weight_block_size = self.quant_config.weight_block_size self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" - # Use per-token quantization for better perf if dynamic and cutlass - if not self.act_q_static and cutlass_fp8_supported(): - self.act_q_group_shape = GroupShape.PER_TOKEN + if self.weight_block_size: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) else: - self.act_q_group_shape = GroupShape.PER_TENSOR + # Use per-token quantization for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN + else: + self.act_q_group_shape = GroupShape.PER_TENSOR - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape) + if self.block_quant: + assert not self.act_q_static + assert self.weight_block_size is not None + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, @@ -399,12 +412,15 @@ class Fp8LinearMethod(LinearMethodBase): bias=bias) if self.block_quant: - return apply_fp8_block_linear( - layer, + assert self.weight_block_size is not None + + return self.w8a8_block_fp8_linear.apply( input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, bias=bias, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported) + ) return self.fp8_linear.apply(input=x, weight=layer.weight, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 31182f40b48f..ece3e5817116 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -27,11 +27,14 @@ class QuantFP8(CustomOp): This CustomOp supports both static and dynamic quantization. 
""" - def __init__(self, - static: bool, - group_shape: GroupShape, - num_token_padding: Optional[int] = None, - column_major_scales: bool = False): + def __init__( + self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None, + column_major_scales: bool = False, + use_ue8m0: Optional[bool] = None, # for Torch compile + ): """ :param static: static or dynamic quantization :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, @@ -46,6 +49,7 @@ class QuantFP8(CustomOp): self.group_shape = group_shape self.num_token_padding = num_token_padding self.column_major_scales = column_major_scales + self.use_ue8m0 = use_ue8m0 self.is_group_quant = group_shape.is_per_group() if self.is_group_quant: @@ -70,7 +74,8 @@ class QuantFP8(CustomOp): x, group_size=self.group_size, column_major_scales=self.column_major_scales, - dtype=_FP8_DTYPE) + dtype=_FP8_DTYPE, + use_ue8m0=self.use_ue8m0) assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape @@ -137,7 +142,10 @@ class QuantFP8(CustomOp): x_grouped = x.view(-1, num_groups, self.group_size) absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() - scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + scales_raw = absmax / _FP8_MAX + if self.use_ue8m0: + scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw))) + scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR) x_scaled = x_grouped / scales x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) @@ -151,6 +159,6 @@ class QuantFP8(CustomOp): scales = scales.reshape(orig_shape[:-1] + (num_groups, )) if self.column_major_scales: - scales = scales.transpose(-2, -1).contiguous() + scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2) return x_quant, scales diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 828111dc299e..13bb69190eae 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -13,8 +13,9 @@ import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) + GroupShape, group_broadcast) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.model_executor.parameter import (BlockQuantScaleParameter, @@ -24,6 +25,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used, + is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -35,6 +37,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz +# We need to pass in the is_hopper flag as argument because the function +# current_platform.is_device_capability() is not supported by Torch compiler. 
def cutlass_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -42,15 +46,17 @@ def cutlass_scaled_mm( Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, + is_hopper: Optional[bool] = None, ) -> torch.Tensor: + if is_hopper is None: + is_hopper = current_platform.is_device_capability(90) return ops.cutlass_scaled_mm( A, B.T, out_dtype=output_dtype, scale_a=As, # SM90 block FP8 requires row-major scale_b, which we do ahead of time - scale_b=Bs if block_size is not None - and current_platform.is_device_capability(90) else Bs.T) + scale_b=Bs if block_size is not None and is_hopper else Bs.T) def rocm_aiter_gemm_w8a8_blockscale_impl( @@ -96,115 +102,251 @@ if current_platform.is_rocm(): aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -def dispatch_w8a8_blockscale_func( - use_cutlass: bool, use_aiter_and_is_supported: bool -) -> Callable[[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - list[int], - torch.dtype, -], torch.Tensor]: - if use_cutlass: - return cutlass_scaled_mm - if (use_aiter_and_is_supported): - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale - return w8a8_block_fp8_matmul +# TODO we should be able to change the type of block_size to GroupShape +# after we resolve GroupShape compilation issue +# https://github.com/vllm-project/vllm/issues/25270 +def _w8a8_triton_block_scaled_mm_func( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale, + block_size, output_dtype) + + +def _w8a8_triton_block_scaled_mm_fake( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty((qx.size(0), weight.size(0)), + dtype=output_dtype, + device=qx.device) + + +direct_register_custom_op( + "w8a8_triton_block_scaled_mm_func", + _w8a8_triton_block_scaled_mm_func, + fake_impl=_w8a8_triton_block_scaled_mm_fake, +) + + +def _padded_cutlass( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + pad_multiple = 4 + dim = qx.shape[0] + padded = dim if dim % pad_multiple == 0 else dim + pad_multiple - ( + dim % pad_multiple) + + padded_shape = [padded, *qx.shape[1:]] + padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype) + padded_qx[0:qx.shape[0], ...].copy_(qx) + + padded_x_scale_shape = [*x_scale.shape[1:], padded] + padded_x_scale = torch.ones(padded_x_scale_shape, + device=x_scale.device, + dtype=x_scale.dtype).permute(-1, -2) + padded_x_scale[0:x_scale.shape[0], ...].copy_(x_scale) + + output = cutlass_scaled_mm(padded_qx, weight, padded_x_scale, weight_scale, + block_size, output_dtype, True) + return output[0:qx.shape[0], ...] 
+ + +def _padded_cutlass_fake( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty((qx.size(0), weight.size(0)), + dtype=output_dtype, + device=qx.device) + + +direct_register_custom_op( + "padded_cutlass", + _padded_cutlass, + fake_impl=_padded_cutlass_fake, +) + + +def _fp8_gemm_nt_op(q_input: torch.Tensor, input_scale: torch.Tensor, + weight: torch.Tensor, weight_scale: torch.Tensor, + output: torch.Tensor, use_deep_gemm_e8m0: bool) -> None: + fp8_gemm_nt((q_input, input_scale), (weight, weight_scale), + output, + is_deep_gemm_e8m0_used=use_deep_gemm_e8m0) + + +def _fp8_gemm_nt_op_fake(q_input: torch.Tensor, input_scale: torch.Tensor, + weight: torch.Tensor, weight_scale: torch.Tensor, + output: torch.Tensor, + use_deep_gemm_e8m0: bool) -> None: + return None + + +direct_register_custom_op( + "fp8_gemm_nt_op", + _fp8_gemm_nt_op, + mutates_args=["output"], + fake_impl=_fp8_gemm_nt_op_fake, +) # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 -def apply_w8a8_block_fp8_linear( - input: torch.Tensor, - weight: torch.Tensor, - block_size: list[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, -) -> torch.Tensor: - assert input_scale is None - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[0]] - output_dtype = input.dtype +class W8A8BlockFp8LinearOp: + """ + This class executes a Blocked FP8 linear layer using cutlass if supported + and torch.scaled_mm otherwise. + """ - if should_use_deepgemm_for_fp8_linear(output_dtype, weight): + def __init__( + self, + weight_group_shape: GroupShape, + act_quant_group_shape: GroupShape, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, + ): + self.weight_group_shape = weight_group_shape + self.act_quant_group_shape = act_quant_group_shape + self.is_deep_gemm_supported = is_deep_gemm_supported() + self.is_hopper = current_platform.is_device_capability(90) + self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() + # Get the correct blockscale mul and input quant operations. + # We can't use _dispatch_w8a8_blockscale_op to figure out if we want + # to use deepgemm because we don't know the shape of weights (and + # whether deepgemm supports it) at the init time. 
+ self.w8a8_blockscale_op, self.input_quant_op = \ + self._dispatch_w8a8_blockscale_op( + cutlass_block_fp8_supported, use_aiter_and_is_supported) + self.deepgemm_input_quant_op = (QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=self.use_deep_gemm_e8m0) if self.is_deep_gemm_supported + else None) + + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] + output_dtype = input.dtype - q_input, x_scale = per_token_group_quant_fp8( - input_2d, - block_size[1], - column_major_scales=True, - ) + if should_use_deepgemm_for_fp8_linear(output_dtype, weight, + self.is_deep_gemm_supported): + output = self._run_deepgemm(input_2d, weight, weight_scale) + else: + output = self.w8a8_blockscale_op(input_2d, weight, weight_scale) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + def _run_deepgemm( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.deepgemm_input_quant_op is not None + q_input, input_scale = self.deepgemm_input_quant_op(input_2d) output = torch.empty((q_input.shape[0], weight.shape[0]), dtype=torch.bfloat16, device=q_input.device) - fp8_gemm_nt((q_input, x_scale), (weight, weight_scale), output) - if bias is not None: - output += bias - return output.to(dtype=output_dtype).view(*output_shape) + torch.ops.vllm.fp8_gemm_nt_op(q_input, input_scale, weight, + weight_scale, output, + self.use_deep_gemm_e8m0) + return output - w8a8_blockscale_func = dispatch_w8a8_blockscale_func( - cutlass_block_fp8_supported, use_aiter_and_is_supported) - if cutlass_block_fp8_supported: - num_pad = 0 - if current_platform.is_device_capability(90): - # pad first dimension to be divisible by 4 due to - # cutlass blockwise gemm limitation for hopper - num_pad = 4 - (input_2d.shape[0] % 4) - if num_pad > 0: - input_2d = torch.nn.functional.pad(input_2d, - (0, 0, 0, num_pad), - "constant", 0) - q_input, x_scale = per_token_group_quant_fp8(input_2d, - block_size[1], - column_major_scales=True) - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) - if num_pad > 0: - output = output[:-num_pad] - else: - if use_aiter_and_is_supported: - q_input, x_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + def _run_cutlass( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + if self.is_hopper: + return torch.ops.vllm.padded_cutlass(q_input, weight, input_scale, + weight_scale, + list(self.weight_group_shape), + input_2d.dtype) else: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=False) + return cutlass_scaled_mm(q_input, weight, + input_scale, weight_scale, + list(self.weight_group_shape), + input_2d.dtype, False) - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) + def _run_aiter( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert 
self.act_quant_group_shape == GroupShape(1, 128) + q_input, input_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( + q_input, weight, input_scale, weight_scale, + self.weight_group_shape, input_2d.dtype) - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + def _run_triton( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( + q_input, weight, input_scale, weight_scale, + self.weight_group_shape, input_2d.dtype) - -def apply_w8a8_block_fp8_linear_fake( - input: torch.Tensor, - weight: torch.Tensor, - block_size: list[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, -) -> torch.Tensor: - output_shape = [*input.shape[:-1], weight.shape[0]] - return torch.empty(output_shape, dtype=input.dtype, device=input.device) - - -if not current_platform.is_cpu(): - direct_register_custom_op( - op_name="apply_w8a8_block_fp8_linear", - op_func=apply_w8a8_block_fp8_linear, - mutates_args=[], - fake_impl=apply_w8a8_block_fp8_linear_fake, - ) + def _dispatch_w8a8_blockscale_op( + self, + use_cutlass: bool, + use_aiter_and_is_supported: bool, + ) -> tuple[Callable[[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], torch.Tensor], Optional[QuantFP8]]: + if use_cutlass: + return self._run_cutlass, (QuantFP8(False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=False)) + if use_aiter_and_is_supported: + return self._run_aiter, None + return self._run_triton, (QuantFP8(False, + self.act_quant_group_shape, + column_major_scales=False, + use_ue8m0=False)) def input_to_float8( @@ -456,7 +598,7 @@ def per_token_group_quant_fp8( @triton.jit -def _w8a8_block_fp8_matmul( +def _w8a8_triton_block_scaled_mm( # Pointers to inputs and output A, B, @@ -581,7 +723,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, return None -def w8a8_block_fp8_matmul( +def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -641,7 +783,7 @@ def w8a8_block_fp8_matmul( return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - _w8a8_block_fp8_matmul[grid]( + _w8a8_triton_block_scaled_mm[grid]( A, B, C, @@ -924,25 +1066,6 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module, layer.weight_scale.data.T.contiguous(), requires_grad=False) -def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor, - bias: Optional[torch.Tensor], - cutlass_block_fp8_supported: bool, - use_aiter_and_is_supported: bool) -> torch.Tensor: - """Apply block-wise FP8 linear operation.""" - assert layer.weight_block_size is not None - - return torch.ops.vllm.apply_w8a8_block_fp8_linear( - input=input, - weight=layer.weight, - block_size=layer.weight_block_size, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias, - cutlass_block_fp8_supported=cutlass_block_fp8_supported, - use_aiter_and_is_supported=use_aiter_and_is_supported, - ) - - def expert_weight_is_col_major(x: torch.Tensor) -> bool: assert x.dim() == 3 b, m, n = x.shape diff --git a/vllm/utils/deep_gemm.py 
b/vllm/utils/deep_gemm.py index 4f05f0bc35cc..7e28d5e452af 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools import importlib import os -from typing import Any, Callable, NoReturn +from typing import Any, Callable, NoReturn, Optional import torch @@ -136,9 +136,12 @@ def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: return _missing(*args, **kwargs) - return _fp8_gemm_nt_impl(*args, - disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), - **kwargs) + if "is_deep_gemm_e8m0_used" in kwargs: + use_ue8m0 = kwargs["is_deep_gemm_e8m0_used"] + del kwargs["is_deep_gemm_e8m0_used"] + else: + use_ue8m0 = is_deep_gemm_e8m0_used() + return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs) def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): @@ -301,9 +304,13 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim -def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, - weight: torch.Tensor): - return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 +def should_use_deepgemm_for_fp8_linear( + output_dtype: torch.dtype, + weight: torch.Tensor, + supports_deep_gemm: Optional[bool] = None): + if supports_deep_gemm is None: + supports_deep_gemm = is_deep_gemm_supported() + return (supports_deep_gemm and output_dtype == torch.bfloat16 and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
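
Note on the UE8M0 option added to QuantFP8 above: when use_ue8m0 is set, the native group-quantization path rounds each group scale up to the nearest power of two (exp2(ceil(log2(scale)))) before clamping, so the scale exponent fits the E8M0 format used by the DeepGEMM path. The following is a minimal standalone sketch of that step in plain PyTorch, not the vLLM API itself; the clamp constant mirrors what vLLM uses but is an assumption here.

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max
FP8_MIN = torch.finfo(FP8_DTYPE).min
# Assumption: same guard as vLLM's _FP8_MIN_SCALING_FACTOR.
MIN_SCALING_FACTOR = 1.0 / (FP8_MAX * 512.0)


def group_quant_fp8(x: torch.Tensor, group_size: int, use_ue8m0: bool):
    """Per-(1 x group_size) dynamic FP8 quantization, mirroring the
    forward_native path in the patch (illustrative only)."""
    orig_shape = x.shape
    num_groups = x.shape[-1] // group_size
    x_grouped = x.view(-1, num_groups, group_size)

    absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
    scales = absmax / FP8_MAX
    if use_ue8m0:
        # Round each scale up to the nearest power of two (UE8M0).
        scales = torch.exp2(torch.ceil(torch.log2(scales)))
    scales = scales.clamp(min=MIN_SCALING_FACTOR)

    x_quant = (x_grouped / scales).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
    return (x_quant.view(orig_shape),
            scales.squeeze(-1).view(*orig_shape[:-1], num_groups))


if __name__ == "__main__":
    x = torch.randn(8, 512, dtype=torch.bfloat16) * 8
    q, s = group_quant_fp8(x, group_size=128, use_ue8m0=True)
    # Every UE8M0 scale is an exact power of two (frexp mantissa == 0.5).
    mantissa, _ = torch.frexp(s)
    assert torch.all(mantissa == 0.5)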
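
The column_major_scales change in QuantFP8 (transpose(-2, -1).contiguous().transpose(-1, -2)) keeps the scales' logical shape as (tokens, num_groups) while laying them out column-major in memory, which is what the updated test asserts via stride(0) == 1. A small sketch of that layout trick in isolation:

import torch

tokens, num_groups = 8, 4
scales = torch.randn(tokens, num_groups)

# Materialize column-major storage, then view it back under the row-major
# logical shape: shape stays (tokens, num_groups), strides become (1, tokens).
scales_cm = scales.transpose(-2, -1).contiguous().transpose(-1, -2)

assert scales_cm.shape == (tokens, num_groups)
assert scales_cm.stride() == (1, tokens)
assert torch.equal(scales_cm, scales)  # same values, different memory layout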
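
The _padded_cutlass wrapper exists because the SM90 blockwise CUTLASS kernel requires the M (token) dimension to be a multiple of 4, so the patch zero-pads the quantized activations and their scales up to that multiple and slices the result back. Below is a sketch of just the padding arithmetic, with torch.matmul standing in for the CUTLASS kernel (an assumption made purely so the example runs anywhere).

import torch


def pad_rows_to_multiple(x: torch.Tensor, multiple: int = 4) -> torch.Tensor:
    """Zero-pad the first (M) dimension up to the next multiple,
    mirroring the rounding done in _padded_cutlass."""
    m = x.shape[0]
    padded_m = m if m % multiple == 0 else m + multiple - (m % multiple)
    out = torch.zeros((padded_m, *x.shape[1:]), dtype=x.dtype, device=x.device)
    out[:m].copy_(x)
    return out


def padded_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Run the GEMM on padded inputs, then drop the padded rows."""
    m = a.shape[0]
    a_pad = pad_rows_to_multiple(a, 4)
    out = a_pad @ b  # placeholder for the blockwise CUTLASS GEMM
    return out[:m]


if __name__ == "__main__":
    a = torch.randn(6, 128)   # M=6 is not a multiple of 4
    b = torch.randn(128, 256)
    assert torch.allclose(padded_matmul(a, b), a @ b, atol=1e-5)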