mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-25 00:14:34 +08:00
Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (#28431)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
287bbbeb06
commit
f9a4087182
@ -1,10 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
|
||||
# Disable DeepGEMM for this benchmark to use CUTLASS
|
||||
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_w8a8_block_fp8_linear,
|
||||
W8A8BlockFp8LinearOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
# Create random FP8 tensors
|
||||
# Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
|
||||
A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||
|
||||
# Create quantized weight tensor
|
||||
B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||
B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
# Create scales
|
||||
# Create weight scales
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
||||
* factor_for_scale
|
||||
)
|
||||
|
||||
# SM90 CUTLASS requires row-major format for scales
|
||||
if use_cutlass and current_platform.is_device_capability(90):
|
||||
Bs = Bs.T.contiguous()
|
||||
# Create W8A8BlockFp8LinearOp instance
|
||||
weight_group_shape = GroupShape(block_n, block_k)
|
||||
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
|
||||
|
||||
linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=weight_group_shape,
|
||||
act_quant_group_shape=act_quant_group_shape,
|
||||
cutlass_block_fp8_supported=use_cutlass,
|
||||
use_aiter_and_is_supported=False,
|
||||
)
|
||||
|
||||
def run():
|
||||
if use_cutlass:
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
|
||||
)
|
||||
else:
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
|
||||
)
|
||||
return linear_op.apply(
|
||||
input=A_ref,
|
||||
weight=B,
|
||||
weight_scale=Bs,
|
||||
input_scale=None,
|
||||
bias=None,
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@ -48,7 +48,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||
cute::GMMA::Major::MN, cute::GMMA::Major::K>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
@ -173,7 +173,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
layer.input_scale = None
|
||||
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
|
||||
maybe_post_process_fp8_weight_block(layer)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
|
||||
@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return
|
||||
|
||||
if self.block_quant:
|
||||
maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
|
||||
maybe_post_process_fp8_weight_block(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -55,17 +55,13 @@ def cutlass_scaled_mm(
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
is_hopper: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
if is_hopper is None:
|
||||
is_hopper = current_platform.is_device_capability(90)
|
||||
return ops.cutlass_scaled_mm(
|
||||
A,
|
||||
B.T,
|
||||
out_dtype=output_dtype,
|
||||
scale_a=As,
|
||||
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
|
||||
scale_b=Bs if block_size is not None and is_hopper else Bs.T,
|
||||
scale_b=Bs.T,
|
||||
)
|
||||
|
||||
|
||||
@ -130,7 +126,7 @@ def _padded_cutlass(
|
||||
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
|
||||
|
||||
output = cutlass_scaled_mm(
|
||||
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
|
||||
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
|
||||
)
|
||||
return output[0 : qx.shape[0], ...]
|
||||
|
||||
@ -303,7 +299,6 @@ class W8A8BlockFp8LinearOp:
|
||||
weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
input_2d.dtype,
|
||||
False,
|
||||
)
|
||||
|
||||
def _run_aiter(
|
||||
@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy(
|
||||
return weight, weight_scale
|
||||
|
||||
|
||||
def maybe_post_process_fp8_weight_block(
|
||||
layer: torch.nn.Module, cutlass_block_fp8_supported: bool
|
||||
):
|
||||
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
|
||||
assert layer.weight_block_size is not None
|
||||
|
||||
from vllm.utils.deep_gemm import (
|
||||
@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block(
|
||||
requant_weight_ue8m0_inplace(
|
||||
layer.weight.data, layer.weight_scale.data, block_sz
|
||||
)
|
||||
# SM90 Block FP8 CUTLASS requires row-major weight scales
|
||||
elif (
|
||||
current_platform.is_device_capability(90)
|
||||
and cutlass_block_fp8_supported
|
||||
and not should_use_deepgemm
|
||||
):
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data.T.contiguous(), requires_grad=False
|
||||
)
|
||||
|
||||
|
||||
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user