From f9a4087182ffcd9404779fcda876f820b3b26d5f Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Tue, 11 Nov 2025 09:46:04 -0700
Subject: [PATCH] Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (#28431)

Signed-off-by: mgoin
---
 benchmarks/kernels/bench_block_fp8_gemm.py    | 43 +++++++++++++------
 .../scaled_mm_blockwise_sm90_fp8_dispatch.cuh |  3 +-
 .../schemes/compressed_tensors_w8a8_fp8.py    |  2 +-
 .../model_executor/layers/quantization/fp8.py |  2 +-
 .../layers/quantization/utils/fp8_utils.py    | 22 ++--------
 5 files changed, 36 insertions(+), 36 deletions(-)

diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py
index f1e504499eaf6..11e3ac7f0c1fa 100644
--- a/benchmarks/kernels/bench_block_fp8_gemm.py
+++ b/benchmarks/kernels/bench_block_fp8_gemm.py
@@ -1,10 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+
+# Disable DeepGEMM for this benchmark to use CUTLASS
+os.environ["VLLM_USE_DEEP_GEMM"] = "0"
+
 
 import torch
 
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_w8a8_block_fp8_linear,
+    W8A8BlockFp8LinearOp,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED,
@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
-    # Create random FP8 tensors
+    # Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
     A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
 
+    # Create quantized weight tensor
     B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
     B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
-    # Create scales
+    # Create weight scales
     block_n, block_k = block_size[0], block_size[1]
     n_tiles = (N + block_n - 1) // block_n
     k_tiles = (K + block_k - 1) // block_k
@@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
         * factor_for_scale
     )
 
-    # SM90 CUTLASS requires row-major format for scales
-    if use_cutlass and current_platform.is_device_capability(90):
-        Bs = Bs.T.contiguous()
+    # Create W8A8BlockFp8LinearOp instance
+    weight_group_shape = GroupShape(block_n, block_k)
+    act_quant_group_shape = GroupShape(1, block_k)  # Per-token, per-group quantization
+
+    linear_op = W8A8BlockFp8LinearOp(
+        weight_group_shape=weight_group_shape,
+        act_quant_group_shape=act_quant_group_shape,
+        cutlass_block_fp8_supported=use_cutlass,
+        use_aiter_and_is_supported=False,
+    )
 
     def run():
-        if use_cutlass:
-            return apply_w8a8_block_fp8_linear(
-                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
-            )
-        else:
-            return apply_w8a8_block_fp8_linear(
-                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
-            )
+        return linear_op.apply(
+            input=A_ref,
+            weight=B,
+            weight_scale=Bs,
+            input_scale=None,
+            bias=None,
+        )
 
     return run
 
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
index 147eb8efc0778..c40d499662714 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
@@ -48,7 +48,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
   using ElementBlockScale = float;
 
   using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
-      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::GMMA::Major::MN, cute::GMMA::Major::K>;
 
   using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
   using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 6da136cbc8f69..ee99572f5f499 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -173,7 +173,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             layer.input_scale = None
 
         if self.strategy == QuantizationStrategy.BLOCK:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply_weights(
         self,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 83d136600b77c..cb065eb68b66b 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
             return
 
         if self.block_quant:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply(
         self,
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index c63196b893574..0c54cf4def005 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -55,17 +55,13 @@ def cutlass_scaled_mm(
     Bs: torch.Tensor,
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
-    is_hopper: bool | None = None,
 ) -> torch.Tensor:
-    if is_hopper is None:
-        is_hopper = current_platform.is_device_capability(90)
     return ops.cutlass_scaled_mm(
         A,
         B.T,
         out_dtype=output_dtype,
         scale_a=As,
-        # SM90 block FP8 requires row-major scale_b, which we do ahead of time
-        scale_b=Bs if block_size is not None and is_hopper else Bs.T,
+        scale_b=Bs.T,
     )
 
 
@@ -130,7 +126,7 @@ def _padded_cutlass(
     padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
 
     output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
+        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
     )
     return output[0 : qx.shape[0], ...]
 
@@ -303,7 +299,6 @@ class W8A8BlockFp8LinearOp:
             weight_scale,
             list(self.weight_group_shape),
             input_2d.dtype,
-            False,
         )
 
     def _run_aiter(
         self,
@@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy(
     return weight, weight_scale
 
 
-def maybe_post_process_fp8_weight_block(
-    layer: torch.nn.Module, cutlass_block_fp8_supported: bool
-):
+def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
     assert layer.weight_block_size is not None
 
     from vllm.utils.deep_gemm import (
@@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
         requant_weight_ue8m0_inplace(
             layer.weight.data, layer.weight_scale.data, block_sz
         )
-    # SM90 Block FP8 CUTLASS requires row-major weight scales
-    elif (
-        current_platform.is_device_capability(90)
-        and cutlass_block_fp8_supported
-        and not should_use_deepgemm
-    ):
-        layer.weight_scale = torch.nn.Parameter(
-            layer.weight_scale.data.T.contiguous(), requires_grad=False
-        )
 
 
 def expert_weight_is_col_major(x: torch.Tensor) -> bool:
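
Note for reviewers (not part of the patch): the snippet below condenses the updated build_w8a8_block_fp8_runner path from the benchmark diff above into a standalone sketch. It assumes a CUDA device with FP8 support and a vLLM build that includes this change; the shapes and tensor contents are illustrative only.

# Minimal sketch, condensed from the benchmark changes above (illustrative
# shapes; assumes CUDA + FP8 support and a vLLM build containing this patch).
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

M, N, K = 16, 4096, 4096
block_n, block_k = 128, 128
device = "cuda"

# bfloat16 activations; the op quantizes them per (1, block_k) group internally.
A = torch.randn(M, K, dtype=torch.bfloat16, device=device)

# Pre-quantized FP8 weight and per-(block_n, block_k) float32 scales, kept in
# their natural (n_tiles, k_tiles) layout; with this patch there is no
# .T.contiguous() special case for SM90 on the Python side.
B = torch.randn(N, K, dtype=torch.bfloat16, device=device).to(torch.float8_e4m3fn)
Bs = torch.rand(N // block_n, K // block_k, dtype=torch.float32, device=device)

linear_op = W8A8BlockFp8LinearOp(
    weight_group_shape=GroupShape(block_n, block_k),
    act_quant_group_shape=GroupShape(1, block_k),
    cutlass_block_fp8_supported=True,
    use_aiter_and_is_supported=False,
)
out = linear_op.apply(input=A, weight=B, weight_scale=Bs, input_scale=None, bias=None)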