vllm/tests/kernels/quantization/test_block_fp8.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import itertools
import pytest
import torch
from tests.kernels.quant_utils import (
native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm,
per_token_group_quant_fp8,
w8a8_triton_block_scaled_mm,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
per_block_cast_to_fp8,
should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.import_utils import has_deep_gemm
if current_platform.get_device_capability() < (9, 0):
    pytest.skip(
        "FP8 Triton kernels require device compute capability 9.0 or higher",
        allow_module_level=True,
    )
vllm_config = VllmConfig()
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 2050]
D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 512]
M = [1, 7, 8, 83, 84, 4096]
N = [128, 512, 7168, 7748, 13824]
K = [256, 3884, 4096, 13824, 16384]
# DeepSeek-V3's intermediate size is 18432, so N is 18432 * 2 / 8 = 4608 at TP8,
# and its hidden size is 7168.
BLOCK_SIZE = [[128, 128]]
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0]
# Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda")
@pytest.fixture(autouse=True)
def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
)
@torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
torch.manual_seed(seed)
x = torch.rand(num_tokens, d, dtype=dtype)
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
out, scale = per_token_group_quant_fp8(x, group_size)
assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
assert torch.allclose(scale, ref_scale)
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(current_platform.fp8_dtype())
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
)
@torch.inference_mode()
def test_w8a8_block_fp8_cutlass_matmul():
    # Test a simple case where weight.shape % 128 != 0 (N = 576),
    # as in DeepSeek-V3's kv_a_proj_with_mqa.
M = 32
N = 576
K = 7168
block_size = [128, 128]
out_dtype = torch.bfloat16
seed = 0
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
# Hopper requires row-major format for scales
Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs
A_fp8, As = per_token_group_quant_fp8(
A_fp32, block_size[1], column_major_scales=False
)
    # CUTLASS expects column-major activation scales
    # (see the layout sketch after this test).
A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8(
A_fp32, block_size[1], column_major_scales=True
)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
out = cutlass_scaled_mm(
A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype
)
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001
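
# Illustrative sketch (an assumption about the memory layout, not a call into
# vLLM's API): "column-major" activation scales keep the logical (M, k_tiles)
# shape but make each column contiguous in memory, which is the layout the
# CUTLASS block-scaled kernel expects. The helper name is hypothetical.
def _column_major_scales_example(scales: torch.Tensor) -> torch.Tensor:
    m, k_tiles = scales.shape
    out = torch.empty(k_tiles, m, dtype=scales.dtype, device=scales.device)
    out.copy_(scales.t())
    return out.t()  # logical shape (m, k_tiles), strides (1, m)
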
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max = fp8_info.max
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
    # DeepGEMM only supports properly aligned sizes; skip everything else.
    if not should_use_deepgemm_for_fp8_linear(
        output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
    ):
        pytest.skip(f"Skipping test; unsupported size M={M}, N={N}, K={K}")
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
As = As_fp8.to(torch.float32)
Bs = Bs_fp8.to(torch.float32)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
    # Convert the scales to the column-major, TMA-aligned layout up front so the
    # test itself does not trigger the transposing kernels at GEMM time.
    As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) // 128), (
f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
)
fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001
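
# Conceptual sketch of what a (block x block) per-block FP8 weight cast does
# (illustrative only; vLLM's per_block_cast_to_fp8 is the real implementation
# and may additionally round scales, pad, or vectorize): each tile is scaled so
# its max-abs value maps to the FP8-E4M3 maximum, and the per-tile scale is kept
# in float32 for dequantization inside the block-scaled GEMM.
def _per_block_cast_example(w: torch.Tensor, block: int = 128):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    n, k = w.shape
    n_tiles = (n + block - 1) // block
    k_tiles = (k + block - 1) // block
    w_q = torch.empty(n, k, dtype=torch.float8_e4m3fn, device=w.device)
    scales = torch.empty(n_tiles, k_tiles, dtype=torch.float32, device=w.device)
    for i in range(n_tiles):
        for j in range(k_tiles):
            tile = w[i * block : (i + 1) * block, j * block : (j + 1) * block].float()
            scale = tile.abs().amax().clamp(min=1e-12) / fp8_max
            scales[i, j] = scale
            w_q[i * block : (i + 1) * block, j * block : (j + 1) * block] = (
                tile / scale
            ).to(torch.float8_e4m3fn)
    return w_q, scales
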
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_flashinfer_block_gemm_matmul(M, N, K, block_size, out_dtype, seed):
"""
Test FlashInfer FP8 block-scale GEMM through W8A8BlockFp8LinearOp.
This tests the FP8 + FP8 → BF16 path (W8A8 full quantization).
Matches TensorRT-LLM's test_fp8_block_scale_gemm behavior.
"""
import os
from vllm.utils.flashinfer import has_flashinfer_block_gemm
if not has_flashinfer_block_gemm():
pytest.skip(
"FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
)
# Skip tests for dimensions that don't have pre-compiled kernels in FlashInfer
# These cause CUDA runtime errors
if K == 3884 or N == 7748:
pytest.skip(f"FlashInfer does not have pre-compiled kernels for K={K} or N={N}")
# Enable FlashInfer backend (required for W8A8BlockFp8LinearOp to use FlashInfer)
os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1"
# Reload envs module to pick up the env var change
import importlib
from vllm import envs
importlib.reload(envs)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
torch.manual_seed(seed)
# Create BF16 inputs (normalized like TRT-LLM)
A_bf16 = torch.randn(M, K, dtype=torch.bfloat16) / K
B_bf16 = torch.randn(N, K, dtype=torch.bfloat16) / K
# Quantize weight with per-block scales
B_fp8, Bs = per_block_cast_to_fp8(B_bf16, block_size=block_size)
# Create W8A8BlockFp8LinearOp to handle input quantization
block_n, block_k = block_size[0], block_size[1]
weight_group_shape = GroupShape(block_n, block_k)
    act_quant_group_shape = GroupShape(1, block_k)  # per-token-group quantization
linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=weight_group_shape,
act_quant_group_shape=act_quant_group_shape,
cutlass_block_fp8_supported=False, # Disable CUTLASS
use_aiter_and_is_supported=False, # Disable AITER
)
# Verify FlashInfer backend is selected
assert linear_op.w8a8_blockscale_op == linear_op._run_flashinfer, (
"FlashInfer backend not selected! "
"Make sure VLLM_USE_FLASHINFER_FP8_LINEAR=1 is set before running tests."
)
# Compute reference: BF16 × BF16 matmul (before quantization)
ref_out = torch.matmul(A_bf16, B_bf16.T)
# Run W8A8 FlashInfer GEMM (input will be quantized internally)
out = linear_op.apply(
input=A_bf16,
weight=B_fp8,
weight_scale=Bs,
input_scale=None, # Will quantize dynamically
bias=None,
)
    # Compare results using TensorRT-LLM's calc_diff metric, which measures
    # normalized similarity: sim = 2*<x,y> / (||x||² + ||y||²)
    # (see the standalone helper sketch after this test).
out_fp64 = out.to(torch.float64)
ref_fp64 = ref_out.to(torch.float64)
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
diff = 1 - sim
# W8A8 threshold from TensorRT-LLM: diff < 0.001 (99.9% similarity)
assert diff < 0.001, (
f"Similarity difference {diff:.6f} exceeds threshold (similarity: {sim:.6f})"
)
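
# Standalone sketch of the TensorRT-LLM-style "calc_diff" metric used above and
# in the dtype test below (illustrative helper, not part of vLLM's API): returns
# 1 - sim with sim = 2*<x, y> / (||x||² + ||y||²); identical tensors give 0.0.
def _calc_diff_example(x: torch.Tensor, y: torch.Tensor) -> float:
    x64, y64 = x.to(torch.float64), y.to(torch.float64)
    denominator = (x64 * x64 + y64 * y64).sum()
    sim = 2 * (x64 * y64).sum() / denominator
    return float(1 - sim)
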
@pytest.mark.parametrize(
"M,N,K,block_size,seed",
[
(1, 1024, 4096, [128, 128], 0),
(32, 4096, 512, [128, 128], 0),
(128, 1024, 4096, [128, 128], 0),
],
)
@pytest.mark.parametrize(
"input_dtype,weight_dtype",
[
(torch.bfloat16, torch.bfloat16), # BF16 + BF16 (internal quantization)
(torch.bfloat16, torch.float8_e4m3fn), # BF16 + FP8 (weight-only)
(torch.float8_e4m3fn, torch.float8_e4m3fn), # FP8 + FP8 (W8A8)
],
)
@torch.inference_mode()
def test_flashinfer_block_gemm_dtypes(
M, N, K, block_size, input_dtype, weight_dtype, seed
):
"""
Test all three supported dtype combinations for FlashInfer FP8 block-scale GEMM.
Tests:
- BF16 + BF16 → BF16: Both inputs BF16, internal quantization
- BF16 + FP8 → BF16: Weight-only quantization
- FP8 + FP8 → BF16: W8A8 full quantization
This mirrors FlashInfer's own test_fp8_blockscale_gemm_dtypes and TRT-LLM's tests.
"""
from vllm.utils.flashinfer import has_flashinfer_block_gemm
if not has_flashinfer_block_gemm():
pytest.skip(
"FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
)
from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import (
flashinfer_block_gemm,
)
# Add debug output to verify test execution
print(f"\n{'=' * 80}")
print(f"TEST: M={M}, N={N}, K={K} | Input: {input_dtype}, Weight: {weight_dtype}")
print(f"{'=' * 80}")
torch.manual_seed(seed)
# Create BF16 data for reference (same as FlashInfer tests)
input_bf16 = torch.randn(M, K, dtype=torch.bfloat16)
weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16)
# Quantize input based on dtype
if input_dtype == torch.float8_e4m3fn:
input_tensor, input_scale = per_token_group_quant_fp8(input_bf16, block_size[1])
else:
input_tensor, input_scale = input_bf16, None
# Quantize weight based on dtype
if weight_dtype == torch.float8_e4m3fn:
weight_tensor, weight_scale = per_block_cast_to_fp8(
weight_bf16, block_size=block_size
)
else:
weight_tensor, weight_scale = weight_bf16, None
# Run FlashInfer FP8 block-scale GEMM
output = flashinfer_block_gemm(
input=input_tensor,
weight=weight_tensor,
scales_a=input_scale,
scales_b=weight_scale,
out_dtype=torch.bfloat16,
)
# Verify output properties
assert output.shape == (M, N), f"Expected shape {(M, N)}, got {output.shape}"
assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}"
# Compute reference based on dtype combination
if input_dtype == torch.float8_e4m3fn and weight_dtype == torch.float8_e4m3fn:
# W8A8: Compare against dequantized FP8 reference (tests kernel correctness)
block_n, block_k = block_size[0], block_size[1]
k_tiles = (K + block_k - 1) // block_k
n_tiles = (N + block_n - 1) // block_n
input_dequant = torch.zeros_like(input_bf16)
for i in range(M):
for k_tile in range(k_tiles):
start, end = k_tile * block_k, min((k_tile + 1) * block_k, K)
input_dequant[i, start:end] = (
input_tensor[i, start:end].to(torch.bfloat16)
* input_scale[i, k_tile]
)
weight_dequant = torch.zeros_like(weight_bf16)
for j in range(N):
for k_tile in range(k_tiles):
start, end = k_tile * block_k, min((k_tile + 1) * block_k, K)
weight_dequant[j, start:end] = (
weight_tensor[j, start:end].to(torch.bfloat16)
* weight_scale[j // block_n, k_tile]
)
reference = torch.matmul(input_dequant, weight_dequant.T)
# W8A8: Use TRT-LLM's calc_diff metric with strict threshold
out_fp64 = output.to(torch.float64)
ref_fp64 = reference.to(torch.float64)
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
diff = 1 - sim
# W8A8 achieves very high accuracy: diff < 0.001 (99.9% similarity)
assert diff < 0.001, (
f"W8A8 similarity difference {diff:.6f} too high (expected < 0.001, similarity: {sim:.6f})"
)
else:
# BF16+BF16 or BF16+FP8: Compare against original BF16 reference
reference = torch.matmul(input_bf16, weight_bf16.T)
out_fp64 = output.to(torch.float64)
ref_fp64 = reference.to(torch.float64)
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
diff = 1 - sim
if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16:
# BF16+BF16: Highest accuracy (internal quantization)
threshold = 0.001
threshold_desc = "0.1%"
elif input_dtype == torch.bfloat16 and weight_dtype == torch.float8_e4m3fn:
# BF16+FP8: Weight-only quantization, higher error
threshold = 0.01
threshold_desc = "1%"
else:
# Other combinations
threshold = 0.01
threshold_desc = "1%"
assert diff < threshold, (
f"Similarity difference {diff:.6f} too high for "
f"{input_dtype} + {weight_dtype} (expected < {threshold_desc}, similarity: {sim:.6f})"
)