Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (#28431)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin, 2025-11-11 09:46:04 -07:00, committed by GitHub
parent 287bbbeb06
commit f9a4087182
5 changed files with 36 additions and 36 deletions
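
For context on the title: with block FP8, an (N, K) weight carries one float32 scale per (block_n, block_k) tile, and before this commit the SM90 CUTLASS path required that scale tensor to be transposed into a contiguous copy ahead of time, while every other path passed a plain transposed view. A minimal illustrative sketch (shapes follow the benchmark below; nothing here is vLLM source):

import torch

N, K, block_n, block_k = 4096, 4096, 128, 128
# One scale per (block_n, block_k) tile, stored as (n_tiles, k_tiles)
Bs = torch.rand(N // block_n, K // block_k, dtype=torch.float32)

# Old SM90-only special case (removed below): materialize a row-major copy
#   weight_scale = Bs.T.contiguous()
# New, architecture-independent convention inside cutlass_scaled_mm:
scale_b = Bs.T  # shape (k_tiles, n_tiles), a view over the same storage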

View File

@@ -1,10 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import os
+
+# Disable DeepGEMM for this benchmark to use CUTLASS
+os.environ["VLLM_USE_DEEP_GEMM"] = "0"
+
 import torch
 
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_w8a8_block_fp8_linear,
+    W8A8BlockFp8LinearOp,
 )
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED,
@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
-    # Create random FP8 tensors
+    # Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
     A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
 
+    # Create quantized weight tensor
     B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
     B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
-    # Create scales
+    # Create weight scales
     block_n, block_k = block_size[0], block_size[1]
     n_tiles = (N + block_n - 1) // block_n
     k_tiles = (K + block_k - 1) // block_k
@@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
         * factor_for_scale
     )
 
-    # SM90 CUTLASS requires row-major format for scales
-    if use_cutlass and current_platform.is_device_capability(90):
-        Bs = Bs.T.contiguous()
+    # Create W8A8BlockFp8LinearOp instance
+    weight_group_shape = GroupShape(block_n, block_k)
+    act_quant_group_shape = GroupShape(1, block_k)  # Per-token, per-group quantization
+
+    linear_op = W8A8BlockFp8LinearOp(
+        weight_group_shape=weight_group_shape,
+        act_quant_group_shape=act_quant_group_shape,
+        cutlass_block_fp8_supported=use_cutlass,
+        use_aiter_and_is_supported=False,
+    )
 
     def run():
-        if use_cutlass:
-            return apply_w8a8_block_fp8_linear(
-                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
-            )
-        else:
-            return apply_w8a8_block_fp8_linear(
-                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
-            )
+        return linear_op.apply(
+            input=A_ref,
+            weight=B,
+            weight_scale=Bs,
+            input_scale=None,
+            bias=None,
+        )
 
     return run
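
As a usage note, the closure returned by build_w8a8_block_fp8_runner can be timed with any harness. A sketch using torch.utils.benchmark; the problem sizes are arbitrary and assume a CUDA device with CUTLASS block FP8 support:

import torch.utils.benchmark as benchmark

run = build_w8a8_block_fp8_runner(
    M=64, N=4096, K=4096, block_size=[128, 128], device="cuda", use_cutlass=True
)
timer = benchmark.Timer(stmt="run()", globals={"run": run})
print(timer.timeit(100))  # timing statistics over 100 invocations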

View File

@@ -48,7 +48,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
   using ElementBlockScale = float;
 
   using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
-      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::GMMA::Major::MN, cute::GMMA::Major::K>;
 
   using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
   using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
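
The two template arguments added here appear to pin the majorness of the scale-factor layouts (MN-major for the activation scales, K-major for the weight scales), which is what lets the Python side stop transposing the weight scale ahead of time: stored as (n_tiles, k_tiles) row-major, the weight scale is already K-major from the kernel's point of view. An illustrative stride check in Python (not part of this diff):

import torch

n_tiles, k_tiles = 32, 16
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32)

scale_b = Bs.T                      # shape (k_tiles, n_tiles), same storage
assert scale_b.stride() == (1, k_tiles)
assert not scale_b.is_contiguous()  # no ahead-of-time .contiguous() copy is required anymore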

View File

@@ -173,7 +173,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             layer.input_scale = None
 
         if self.strategy == QuantizationStrategy.BLOCK:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply_weights(
         self,

View File

@@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
             return
 
         if self.block_quant:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply(
         self,

View File

@@ -55,17 +55,13 @@ def cutlass_scaled_mm(
     Bs: torch.Tensor,
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
-    is_hopper: bool | None = None,
 ) -> torch.Tensor:
-    if is_hopper is None:
-        is_hopper = current_platform.is_device_capability(90)
     return ops.cutlass_scaled_mm(
         A,
         B.T,
         out_dtype=output_dtype,
         scale_a=As,
-        # SM90 block FP8 requires row-major scale_b, which we do ahead of time
-        scale_b=Bs if block_size is not None and is_hopper else Bs.T,
+        scale_b=Bs.T,
     )
@@ -130,7 +126,7 @@ def _padded_cutlass(
     padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
     output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
+        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
     )
 
     return output[0 : qx.shape[0], ...]
@@ -303,7 +299,6 @@ class W8A8BlockFp8LinearOp:
             weight_scale,
             list(self.weight_group_shape),
             input_2d.dtype,
-            False,
         )
 
     def _run_aiter(
@@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy(
     return weight, weight_scale
 
 
-def maybe_post_process_fp8_weight_block(
-    layer: torch.nn.Module, cutlass_block_fp8_supported: bool
-):
+def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
     assert layer.weight_block_size is not None
 
     from vllm.utils.deep_gemm import (
@@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block(
         requant_weight_ue8m0_inplace(
             layer.weight.data, layer.weight_scale.data, block_sz
         )
-    # SM90 Block FP8 CUTLASS requires row-major weight scales
-    elif (
-        current_platform.is_device_capability(90)
-        and cutlass_block_fp8_supported
-        and not should_use_deepgemm
-    ):
-        layer.weight_scale = torch.nn.Parameter(
-            layer.weight_scale.data.T.contiguous(), requires_grad=False
-        )
 
 
 def expert_weight_is_col_major(x: torch.Tensor) -> bool: