Enable linear deepgemm_swapAB

Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
This commit is contained in:
Kate Cheng 2025-11-21 15:59:27 -08:00 committed by Jhao-Ting Chen
parent 09dc7c690c
commit 3d429d63a6
6 changed files with 396 additions and 1 deletions

View File

@ -3,8 +3,10 @@
import os
# Disable DeepGEMM for this benchmark to use CUTLASS
# Disable DeepGEMM for this benchmark
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
# Enable FlashInfer FP8 linear (will be used when provider="flashinfer-block-fp8")
os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1"
import torch
@ -94,6 +96,15 @@ plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
if CUTLASS_BLOCK_FP8_SUPPORTED:
available_providers.append("w8a8-block-fp8-cutlass")
# Check if FlashInfer block GEMM is available
try:
from vllm.utils.flashinfer import has_flashinfer_block_gemm
if has_flashinfer_block_gemm():
available_providers.append("flashinfer-block-fp8")
except ImportError:
pass
@vllm_triton.testing.perf_report(
vllm_triton.testing.Benchmark(
@ -134,6 +145,14 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
lambda: run_w8a8_cutlass(), quantiles=quantiles
)
elif provider == "flashinfer-block-fp8":
# Use the same W8A8 setup as other providers for fair comparison
run_w8a8_flashinfer = build_w8a8_block_fp8_runner(
M, N, K, block_size, device, use_cutlass=False
)
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
lambda: run_w8a8_flashinfer(), quantiles=quantiles
)
else:
raise ValueError(f"Unknown provider: {provider}")

View File

@ -205,3 +205,244 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_flashinfer_block_gemm_matmul(M, N, K, block_size, out_dtype, seed):
"""
Test FlashInfer FP8 block-scale GEMM through W8A8BlockFp8LinearOp.
This tests the FP8 + FP8 BF16 path (W8A8 full quantization).
Matches TensorRT-LLM's test_fp8_block_scale_gemm behavior.
"""
import os
from vllm.utils.flashinfer import has_flashinfer_block_gemm
if not has_flashinfer_block_gemm():
pytest.skip(
"FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
)
# Skip tests for dimensions that don't have pre-compiled kernels in FlashInfer
# These cause CUDA runtime errors
if K == 3884 or N == 7748:
pytest.skip(f"FlashInfer does not have pre-compiled kernels for K={K} or N={N}")
# Enable FlashInfer backend (required for W8A8BlockFp8LinearOp to use FlashInfer)
os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1"
# Reload envs module to pick up the env var change
import importlib
from vllm import envs
importlib.reload(envs)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
torch.manual_seed(seed)
# Create BF16 inputs (normalized like TRT-LLM)
A_bf16 = torch.randn(M, K, dtype=torch.bfloat16) / K
B_bf16 = torch.randn(N, K, dtype=torch.bfloat16) / K
# Quantize weight with per-block scales
B_fp8, Bs = per_block_cast_to_fp8(B_bf16, block_size=block_size)
# Create W8A8BlockFp8LinearOp to handle input quantization
block_n, block_k = block_size[0], block_size[1]
weight_group_shape = GroupShape(block_n, block_k)
act_quant_group_shape = GroupShape(1, block_k) # Per-token quantization
linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=weight_group_shape,
act_quant_group_shape=act_quant_group_shape,
cutlass_block_fp8_supported=False, # Disable CUTLASS
use_aiter_and_is_supported=False, # Disable AITER
)
# Verify FlashInfer backend is selected
assert linear_op.w8a8_blockscale_op == linear_op._run_flashinfer, (
"FlashInfer backend not selected! "
"Make sure VLLM_USE_FLASHINFER_FP8_LINEAR=1 is set before running tests."
)
# Compute reference: BF16 × BF16 matmul (before quantization)
ref_out = torch.matmul(A_bf16, B_bf16.T)
# Run W8A8 FlashInfer GEMM (input will be quantized internally)
out = linear_op.apply(
input=A_bf16,
weight=B_fp8,
weight_scale=Bs,
input_scale=None, # Will quantize dynamically
bias=None,
)
# Compare results using TensorRT-LLM's calc_diff metric
# This measures normalized similarity: sim = 2*<x,y> / (||x||² + ||y||²)
out_fp64 = out.to(torch.float64)
ref_fp64 = ref_out.to(torch.float64)
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
diff = 1 - sim
# W8A8 threshold from TensorRT-LLM: diff < 0.001 (99.9% similarity)
assert diff < 0.001, (
f"Similarity difference {diff:.6f} exceeds threshold (similarity: {sim:.6f})"
)
@pytest.mark.parametrize(
"M,N,K,block_size,seed",
[
(1, 1024, 4096, [128, 128], 0),
(32, 4096, 512, [128, 128], 0),
(128, 1024, 4096, [128, 128], 0),
],
)
@pytest.mark.parametrize(
"input_dtype,weight_dtype",
[
(torch.bfloat16, torch.bfloat16), # BF16 + BF16 (internal quantization)
(torch.bfloat16, torch.float8_e4m3fn), # BF16 + FP8 (weight-only)
(torch.float8_e4m3fn, torch.float8_e4m3fn), # FP8 + FP8 (W8A8)
],
)
@torch.inference_mode()
def test_flashinfer_block_gemm_dtypes(
M, N, K, block_size, input_dtype, weight_dtype, seed
):
"""
Test all three supported dtype combinations for FlashInfer FP8 block-scale GEMM.
Tests:
- BF16 + BF16 BF16: Both inputs BF16, internal quantization
- BF16 + FP8 BF16: Weight-only quantization
- FP8 + FP8 BF16: W8A8 full quantization
This mirrors FlashInfer's own test_fp8_blockscale_gemm_dtypes and TRT-LLM's tests.
"""
from vllm.utils.flashinfer import has_flashinfer_block_gemm
if not has_flashinfer_block_gemm():
pytest.skip(
"FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
)
from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import (
flashinfer_block_gemm,
)
# Add debug output to verify test execution
print(f"\n{'=' * 80}")
print(f"TEST: M={M}, N={N}, K={K} | Input: {input_dtype}, Weight: {weight_dtype}")
print(f"{'=' * 80}")
torch.manual_seed(seed)
# Create BF16 data for reference (same as FlashInfer tests)
input_bf16 = torch.randn(M, K, dtype=torch.bfloat16)
weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16)
# Quantize input based on dtype
if input_dtype == torch.float8_e4m3fn:
input_tensor, input_scale = per_token_group_quant_fp8(input_bf16, block_size[1])
else:
input_tensor, input_scale = input_bf16, None
# Quantize weight based on dtype
if weight_dtype == torch.float8_e4m3fn:
weight_tensor, weight_scale = per_block_cast_to_fp8(
weight_bf16, block_size=block_size
)
else:
weight_tensor, weight_scale = weight_bf16, None
# Run FlashInfer FP8 block-scale GEMM
output = flashinfer_block_gemm(
input=input_tensor,
weight=weight_tensor,
scales_a=input_scale,
scales_b=weight_scale,
out_dtype=torch.bfloat16,
)
# Verify output properties
assert output.shape == (M, N), f"Expected shape {(M, N)}, got {output.shape}"
assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}"
# Compute reference based on dtype combination
if input_dtype == torch.float8_e4m3fn and weight_dtype == torch.float8_e4m3fn:
# W8A8: Compare against dequantized FP8 reference (tests kernel correctness)
block_n, block_k = block_size[0], block_size[1]
k_tiles = (K + block_k - 1) // block_k
n_tiles = (N + block_n - 1) // block_n
input_dequant = torch.zeros_like(input_bf16)
for i in range(M):
for k_tile in range(k_tiles):
start, end = k_tile * block_k, min((k_tile + 1) * block_k, K)
input_dequant[i, start:end] = (
input_tensor[i, start:end].to(torch.bfloat16)
* input_scale[i, k_tile]
)
weight_dequant = torch.zeros_like(weight_bf16)
for j in range(N):
for k_tile in range(k_tiles):
start, end = k_tile * block_k, min((k_tile + 1) * block_k, K)
weight_dequant[j, start:end] = (
weight_tensor[j, start:end].to(torch.bfloat16)
* weight_scale[j // block_n, k_tile]
)
reference = torch.matmul(input_dequant, weight_dequant.T)
# W8A8: Use TRT-LLM's calc_diff metric with strict threshold
out_fp64 = output.to(torch.float64)
ref_fp64 = reference.to(torch.float64)
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
diff = 1 - sim
# W8A8 achieves very high accuracy: diff < 0.001 (99.9% similarity)
assert diff < 0.001, (
f"W8A8 similarity difference {diff:.6f} too high (expected < 0.001, similarity: {sim:.6f})"
)
else:
# BF16+BF16 or BF16+FP8: Compare against original BF16 reference
reference = torch.matmul(input_bf16, weight_bf16.T)
out_fp64 = output.to(torch.float64)
ref_fp64 = reference.to(torch.float64)
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
diff = 1 - sim
if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16:
# BF16+BF16: Highest accuracy (internal quantization)
threshold = 0.001
threshold_desc = "0.1%"
elif input_dtype == torch.bfloat16 and weight_dtype == torch.float8_e4m3fn:
# BF16+FP8: Weight-only quantization, higher error
threshold = 0.01
threshold_desc = "1%"
else:
# Other combinations
threshold = 0.01
threshold_desc = "1%"
assert diff < threshold, (
f"Similarity difference {diff:.6f} too high for "
f"{input_dtype} + {weight_dtype} (expected < {threshold_desc}, similarity: {sim:.6f})"
)

View File

@ -168,6 +168,7 @@ if TYPE_CHECKING:
"relax",
] = "relax"
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_FP8_LINEAR: bool = False
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
@ -1208,6 +1209,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))
),
# Allow use of FlashInfer FP8 block-scale GEMM for linear layers.
# This uses TensorRT-LLM kernels and requires SM90+ (Hopper).
"VLLM_USE_FLASHINFER_FP8_LINEAR": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_FP8_LINEAR", "0"))
),
# Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))

View File

@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
FlashInfer FP8 Block-Scale GEMM wrapper for vLLM.
This module provides a thin wrapper around FlashInfer's FP8 block-scale GEMM
implementation, which uses TensorRT-LLM's optimized kernels for NVIDIA Hopper (SM90+).
"""
import torch
def flashinfer_block_gemm(
input: torch.Tensor,
weight: torch.Tensor,
scales_a: torch.Tensor | None,
scales_b: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
"""
Wrapper for FlashInfer's FP8 block-scale GEMM.
Computes: output = (input @ weight.T) with per-block scaling for quantization.
Supports three modes:
1. BF16 + BF16 BF16: Both inputs BF16, internal quantization (scales_a=None, scales_b used internally)
2. BF16 + FP8 BF16: Weight-only quantization (scales_a=None, scales_b for weight)
3. FP8 + FP8 BF16: W8A8 full quantization (scales_a for input, scales_b for weight)
Args:
input: Input tensor (M, K) - BF16 or FP8 e4m3
weight: Weight tensor (N, K) - BF16 or FP8 e4m3
scales_a: Input scales (M, K//block_k) or None - FP32
None: input is BF16 (will be quantized internally for BF16+BF16 or left as-is for BF16+FP8)
Provided: input is pre-quantized FP8 (W8A8 mode)
scales_b: Weight scales (N//block_n, K//block_k) - FP32
out_dtype: Output dtype (typically torch.bfloat16)
Returns:
output: Result tensor (M, N) in out_dtype
Note:
- Requires SM90+ GPU (NVIDIA Hopper)
- Uses TensorRT-LLM's optimized CUTLASS kernels via FlashInfer
- For M < 32, automatically uses SwapAB kernel optimization
"""
from flashinfer.gemm import fp8_blockscale_gemm_swapab
return fp8_blockscale_gemm_swapab(
input=input,
weight=weight,
input_scale=scales_a,
weight_scale=scales_b,
out_dtype=out_dtype,
)

View File

@ -35,6 +35,7 @@ from vllm.utils.deep_gemm import (
should_use_deepgemm_for_fp8_linear,
transform_sf_into_required_layout,
)
from vllm.utils.flashinfer import has_flashinfer_block_gemm
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@ -409,6 +410,37 @@ class W8A8BlockFp8LinearOp:
input_2d.dtype,
)
def _run_flashinfer(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Run FlashInfer FP8 block-scale GEMM.
This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
"""
from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import (
flashinfer_block_gemm,
)
# Quantize input dynamically if not pre-quantized (same as CUTLASS)
assert input_scale is None
assert self.input_quant_op is not None
q_input, input_scale = self.input_quant_op(input_2d)
# Now call FlashInfer with FP8 input + FP8 weight (W8A8)
return flashinfer_block_gemm(
input=q_input, # FP8 quantized input
weight=weight, # FP8 weight
scales_a=input_scale, # Input scales (computed dynamically)
scales_b=weight_scale, # Weight scales
out_dtype=input_2d.dtype,
)
def _dispatch_w8a8_blockscale_op(
self,
use_cutlass: bool,
@ -425,6 +457,22 @@ class W8A8BlockFp8LinearOp:
],
QuantFP8 | None,
]:
# Prefer FlashInfer on SM90+ if available (Hopper optimized)
if (
has_flashinfer_block_gemm()
and envs.VLLM_USE_FLASHINFER_FP8_LINEAR
and not use_aiter_and_is_supported
):
logger.info_once("Using FlashInfer FP8 block-scale GEMM for linear layers")
return self._run_flashinfer, (
QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False,
)
)
if use_cutlass:
return self._run_cutlass, (
QuantFP8(

View File

@ -540,6 +540,29 @@ def flashinfer_scaled_fp8_mm(
return output
@functools.cache
def has_flashinfer_block_gemm() -> bool:
"""Return `True` if FlashInfer FP8 block-scale GEMM is available."""
if not has_flashinfer():
return False
if not current_platform.is_cuda():
return False
# Only SM90+ (Hopper) supports this kernel
if not current_platform.is_device_capability(90):
return False
try:
import flashinfer
# Check if the module has the required binding
return hasattr(flashinfer, "Fp8BlockScaleGemmRunner")
except (ImportError, AttributeError):
logger.debug_once(
"FlashInfer block-scale GEMM not available: module or binding not found"
)
return False
__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
@ -562,4 +585,5 @@ __all__ = [
"use_trtllm_attention",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",
"has_flashinfer_block_gemm",
]