mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 12:49:08 +08:00
Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (#28431)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
287bbbeb06
commit
f9a4087182
@ -1,10 +1,18 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Disable DeepGEMM for this benchmark to use CUTLASS
|
||||||
|
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
apply_w8a8_block_fp8_linear,
|
W8A8BlockFp8LinearOp,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||||
@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
|||||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
# Create random FP8 tensors
|
# Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
|
||||||
A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||||
|
|
||||||
|
# Create quantized weight tensor
|
||||||
B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||||
B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
# Create scales
|
# Create weight scales
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
n_tiles = (N + block_n - 1) // block_n
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
k_tiles = (K + block_k - 1) // block_k
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
|||||||
* factor_for_scale
|
* factor_for_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
# SM90 CUTLASS requires row-major format for scales
|
# Create W8A8BlockFp8LinearOp instance
|
||||||
if use_cutlass and current_platform.is_device_capability(90):
|
weight_group_shape = GroupShape(block_n, block_k)
|
||||||
Bs = Bs.T.contiguous()
|
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
|
||||||
|
|
||||||
|
linear_op = W8A8BlockFp8LinearOp(
|
||||||
|
weight_group_shape=weight_group_shape,
|
||||||
|
act_quant_group_shape=act_quant_group_shape,
|
||||||
|
cutlass_block_fp8_supported=use_cutlass,
|
||||||
|
use_aiter_and_is_supported=False,
|
||||||
|
)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
if use_cutlass:
|
return linear_op.apply(
|
||||||
return apply_w8a8_block_fp8_linear(
|
input=A_ref,
|
||||||
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
|
weight=B,
|
||||||
)
|
weight_scale=Bs,
|
||||||
else:
|
input_scale=None,
|
||||||
return apply_w8a8_block_fp8_linear(
|
bias=None,
|
||||||
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return run
|
return run
|
||||||
|
|
||||||
|
|||||||
@ -48,7 +48,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
|||||||
using ElementBlockScale = float;
|
using ElementBlockScale = float;
|
||||||
|
|
||||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
|
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
|
||||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||||
|
cute::GMMA::Major::MN, cute::GMMA::Major::K>;
|
||||||
|
|
||||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||||
|
|||||||
@ -173,7 +173,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
layer.input_scale = None
|
layer.input_scale = None
|
||||||
|
|
||||||
if self.strategy == QuantizationStrategy.BLOCK:
|
if self.strategy == QuantizationStrategy.BLOCK:
|
||||||
maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
|
maybe_post_process_fp8_weight_block(layer)
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
|
maybe_post_process_fp8_weight_block(layer)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -55,17 +55,13 @@ def cutlass_scaled_mm(
|
|||||||
Bs: torch.Tensor,
|
Bs: torch.Tensor,
|
||||||
block_size: list[int],
|
block_size: list[int],
|
||||||
output_dtype: torch.dtype = torch.float16,
|
output_dtype: torch.dtype = torch.float16,
|
||||||
is_hopper: bool | None = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if is_hopper is None:
|
|
||||||
is_hopper = current_platform.is_device_capability(90)
|
|
||||||
return ops.cutlass_scaled_mm(
|
return ops.cutlass_scaled_mm(
|
||||||
A,
|
A,
|
||||||
B.T,
|
B.T,
|
||||||
out_dtype=output_dtype,
|
out_dtype=output_dtype,
|
||||||
scale_a=As,
|
scale_a=As,
|
||||||
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
|
scale_b=Bs.T,
|
||||||
scale_b=Bs if block_size is not None and is_hopper else Bs.T,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -130,7 +126,7 @@ def _padded_cutlass(
|
|||||||
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
|
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
|
||||||
|
|
||||||
output = cutlass_scaled_mm(
|
output = cutlass_scaled_mm(
|
||||||
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
|
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
|
||||||
)
|
)
|
||||||
return output[0 : qx.shape[0], ...]
|
return output[0 : qx.shape[0], ...]
|
||||||
|
|
||||||
@ -303,7 +299,6 @@ class W8A8BlockFp8LinearOp:
|
|||||||
weight_scale,
|
weight_scale,
|
||||||
list(self.weight_group_shape),
|
list(self.weight_group_shape),
|
||||||
input_2d.dtype,
|
input_2d.dtype,
|
||||||
False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_aiter(
|
def _run_aiter(
|
||||||
@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy(
|
|||||||
return weight, weight_scale
|
return weight, weight_scale
|
||||||
|
|
||||||
|
|
||||||
def maybe_post_process_fp8_weight_block(
|
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
|
||||||
layer: torch.nn.Module, cutlass_block_fp8_supported: bool
|
|
||||||
):
|
|
||||||
assert layer.weight_block_size is not None
|
assert layer.weight_block_size is not None
|
||||||
|
|
||||||
from vllm.utils.deep_gemm import (
|
from vllm.utils.deep_gemm import (
|
||||||
@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block(
|
|||||||
requant_weight_ue8m0_inplace(
|
requant_weight_ue8m0_inplace(
|
||||||
layer.weight.data, layer.weight_scale.data, block_sz
|
layer.weight.data, layer.weight_scale.data, block_sz
|
||||||
)
|
)
|
||||||
# SM90 Block FP8 CUTLASS requires row-major weight scales
|
|
||||||
elif (
|
|
||||||
current_platform.is_device_capability(90)
|
|
||||||
and cutlass_block_fp8_supported
|
|
||||||
and not should_use_deepgemm
|
|
||||||
):
|
|
||||||
layer.weight_scale = torch.nn.Parameter(
|
|
||||||
layer.weight_scale.data.T.contiguous(), requires_grad=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user