diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py index 11e3ac7f0c1fa..0b3903a0c0bc5 100644 --- a/benchmarks/kernels/bench_block_fp8_gemm.py +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -3,8 +3,10 @@ import os -# Disable DeepGEMM for this benchmark to use CUTLASS +# Disable DeepGEMM for this benchmark os.environ["VLLM_USE_DEEP_GEMM"] = "0" +# Enable FlashInfer FP8 linear (will be used when provider="flashinfer-block-fp8") +os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1" import torch @@ -94,6 +96,15 @@ plot_title = "BF16 vs W8A8 Block FP8 GEMMs" if CUTLASS_BLOCK_FP8_SUPPORTED: available_providers.append("w8a8-block-fp8-cutlass") +# Check if FlashInfer block GEMM is available +try: + from vllm.utils.flashinfer import has_flashinfer_block_gemm + + if has_flashinfer_block_gemm(): + available_providers.append("flashinfer-block-fp8") +except ImportError: + pass + @vllm_triton.testing.perf_report( vllm_triton.testing.Benchmark( @@ -134,6 +145,14 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( lambda: run_w8a8_cutlass(), quantiles=quantiles ) + elif provider == "flashinfer-block-fp8": + # Use the same W8A8 setup as other providers for fair comparison + run_w8a8_flashinfer = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=False + ) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_flashinfer(), quantiles=quantiles + ) else: raise ValueError(f"Unknown provider: {provider}") diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 32c77b9a01ece..3005f9715f404 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -205,3 +205,244 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): torch.abs(out.to(torch.float32) - 
def _calc_diff(
    actual: torch.Tensor, expected: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(diff, sim)`` for two tensors of identical shape.

    ``sim = 2 * <x, y> / (||x||^2 + ||y||^2)`` and ``diff = 1 - sim``; both are
    accumulated in float64 so the metric itself adds no rounding noise.
    """
    x = actual.to(torch.float64)
    y = expected.to(torch.float64)
    sim = 2 * (x * y).sum() / (x * x + y * y).sum()
    return 1 - sim, sim


def _dequant_block_fp8(
    q: torch.Tensor,
    scales: torch.Tensor,
    block_rows: int,
    block_cols: int,
) -> torch.Tensor:
    """Dequantize a 2-D block-scaled FP8 tensor to BF16.

    ``scales[r, c]`` is applied to the tile
    ``q[r * block_rows:(r + 1) * block_rows, c * block_cols:(c + 1) * block_cols]``.
    The scales are expanded with ``repeat_interleave`` and trimmed to ``q``'s
    shape, so ragged edge tiles are handled. The product is formed in float32
    and rounded to BF16 once.
    """
    rows, cols = q.shape
    expanded = (
        scales.to(torch.float32)
        .repeat_interleave(block_rows, dim=0)
        .repeat_interleave(block_cols, dim=1)
    )
    return (q.to(torch.float32) * expanded[:rows, :cols]).to(torch.bfloat16)


@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_flashinfer_block_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    """
    Test FlashInfer FP8 block-scale GEMM through W8A8BlockFp8LinearOp.

    Exercises the FP8 + FP8 -> BF16 path (W8A8): the BF16 activation is
    quantized by the linear op, the weight is pre-quantized per block, and the
    result is compared against the unquantized BF16 matmul.
    """
    import importlib
    import os

    from vllm.utils.flashinfer import has_flashinfer_block_gemm

    if not has_flashinfer_block_gemm():
        pytest.skip(
            "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
        )

    # Skip tests for dimensions that don't have pre-compiled kernels in FlashInfer
    # These cause CUDA runtime errors
    if K == 3884 or N == 7748:
        pytest.skip(f"FlashInfer does not have pre-compiled kernels for K={K} or N={N}")

    from vllm import envs

    env_name = "VLLM_USE_FLASHINFER_FP8_LINEAR"
    previous = os.environ.get(env_name)

    # Enable FlashInfer backend (required for W8A8BlockFp8LinearOp to use
    # FlashInfer) and reload envs so the change is picked up.
    os.environ[env_name] = "1"
    importlib.reload(envs)
    try:
        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
            W8A8BlockFp8LinearOp,
        )
        from vllm.model_executor.layers.quantization.utils.quant_utils import (
            GroupShape,
        )

        torch.manual_seed(seed)

        # Create BF16 inputs (normalized like TRT-LLM)
        A_bf16 = torch.randn(M, K, dtype=torch.bfloat16) / K
        B_bf16 = torch.randn(N, K, dtype=torch.bfloat16) / K

        # Quantize weight with per-block scales
        B_fp8, Bs = per_block_cast_to_fp8(B_bf16, block_size=block_size)

        block_n, block_k = block_size[0], block_size[1]
        linear_op = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),  # Per-token groups
            cutlass_block_fp8_supported=False,  # Disable CUTLASS
            use_aiter_and_is_supported=False,  # Disable AITER
        )

        # Verify FlashInfer backend is selected
        assert linear_op.w8a8_blockscale_op == linear_op._run_flashinfer, (
            "FlashInfer backend not selected! "
            "Make sure VLLM_USE_FLASHINFER_FP8_LINEAR=1 is set before running tests."
        )

        # Reference: BF16 x BF16 matmul (before quantization)
        ref_out = torch.matmul(A_bf16, B_bf16.T)

        # Run W8A8 FlashInfer GEMM (input will be quantized internally)
        out = linear_op.apply(
            input=A_bf16,
            weight=B_fp8,
            weight_scale=Bs,
            input_scale=None,  # Will quantize dynamically
            bias=None,
        )

        diff, sim = _calc_diff(out, ref_out)

        # W8A8 threshold: diff < 0.001 (99.9% similarity)
        assert diff < 0.001, (
            f"Similarity difference {diff:.6f} exceeds threshold "
            f"(similarity: {sim:.6f})"
        )
    finally:
        # Do not leak the backend override into tests that run afterwards.
        if previous is None:
            os.environ.pop(env_name, None)
        else:
            os.environ[env_name] = previous
        importlib.reload(envs)


@pytest.mark.parametrize(
    "M,N,K,block_size,seed",
    [
        (1, 1024, 4096, [128, 128], 0),
        (32, 4096, 512, [128, 128], 0),
        (128, 1024, 4096, [128, 128], 0),
    ],
)
@pytest.mark.parametrize(
    "input_dtype,weight_dtype",
    [
        (torch.bfloat16, torch.bfloat16),  # BF16 + BF16 (internal quantization)
        (torch.bfloat16, torch.float8_e4m3fn),  # BF16 + FP8 (weight-only)
        (torch.float8_e4m3fn, torch.float8_e4m3fn),  # FP8 + FP8 (W8A8)
    ],
)
@torch.inference_mode()
def test_flashinfer_block_gemm_dtypes(
    M, N, K, block_size, input_dtype, weight_dtype, seed
):
    """
    Test the three dtype combinations of the FlashInfer FP8 block-scale GEMM.

    - BF16 + BF16 -> BF16: both operands BF16, compared with the BF16 matmul
    - BF16 + FP8  -> BF16: weight-only quantization, compared with the BF16 matmul
    - FP8  + FP8  -> BF16: W8A8, compared with the matmul of the dequantized
      operands so that only kernel error (not quantization error) is measured
    """
    from vllm.utils.flashinfer import has_flashinfer_block_gemm

    if not has_flashinfer_block_gemm():
        pytest.skip(
            "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
        )

    from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import (
        flashinfer_block_gemm,
    )

    torch.manual_seed(seed)

    fp8 = torch.float8_e4m3fn
    block_n, block_k = block_size[0], block_size[1]

    # BF16 source data; also the reference operands for the non-W8A8 cases.
    input_bf16 = torch.randn(M, K, dtype=torch.bfloat16)
    weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16)

    if input_dtype == fp8:
        input_tensor, input_scale = per_token_group_quant_fp8(input_bf16, block_k)
    else:
        input_tensor, input_scale = input_bf16, None

    if weight_dtype == fp8:
        weight_tensor, weight_scale = per_block_cast_to_fp8(
            weight_bf16, block_size=block_size
        )
    else:
        weight_tensor, weight_scale = weight_bf16, None

    output = flashinfer_block_gemm(
        input=input_tensor,
        weight=weight_tensor,
        scales_a=input_scale,
        scales_b=weight_scale,
        out_dtype=torch.bfloat16,
    )

    assert output.shape == (M, N), f"Expected shape {(M, N)}, got {output.shape}"
    assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}"

    if input_dtype == fp8 and weight_dtype == fp8:
        # Activation scales are per (row, K-tile); weight scales per (N-tile, K-tile).
        input_dequant = _dequant_block_fp8(input_tensor, input_scale, 1, block_k)
        weight_dequant = _dequant_block_fp8(
            weight_tensor, weight_scale, block_n, block_k
        )
        reference = torch.matmul(input_dequant, weight_dequant.T)
        threshold = 0.001
    else:
        reference = torch.matmul(input_bf16, weight_bf16.T)
        both_bf16 = input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16
        # Weight-only quantization is compared with the unquantized reference,
        # so it is allowed a looser bound than the BF16 + BF16 case.
        threshold = 0.001 if both_bf16 else 0.01

    diff, sim = _calc_diff(output, reference)
    assert diff < threshold, (
        f"Similarity difference {diff:.6f} too high for "
        f"{input_dtype} + {weight_dtype} "
        f"(expected < {threshold}, similarity: {sim:.6f})"
    )
+""" + +import torch + + +def flashinfer_block_gemm( + input: torch.Tensor, + weight: torch.Tensor, + scales_a: torch.Tensor | None, + scales_b: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + """ + Wrapper for FlashInfer's FP8 block-scale GEMM. + + Computes: output = (input @ weight.T) with per-block scaling for quantization. + + Supports three modes: + 1. BF16 + BF16 → BF16: Both inputs BF16, internal quantization (scales_a=None, scales_b used internally) + 2. BF16 + FP8 → BF16: Weight-only quantization (scales_a=None, scales_b for weight) + 3. FP8 + FP8 → BF16: W8A8 full quantization (scales_a for input, scales_b for weight) + + Args: + input: Input tensor (M, K) - BF16 or FP8 e4m3 + weight: Weight tensor (N, K) - BF16 or FP8 e4m3 + scales_a: Input scales (M, K//block_k) or None - FP32 + None: input is BF16 (will be quantized internally for BF16+BF16 or left as-is for BF16+FP8) + Provided: input is pre-quantized FP8 (W8A8 mode) + scales_b: Weight scales (N//block_n, K//block_k) - FP32 + out_dtype: Output dtype (typically torch.bfloat16) + + Returns: + output: Result tensor (M, N) in out_dtype + + Note: + - Requires SM90+ GPU (NVIDIA Hopper) + - Uses TensorRT-LLM's optimized CUTLASS kernels via FlashInfer + - For M < 32, automatically uses SwapAB kernel optimization + """ + from flashinfer.gemm import fp8_blockscale_gemm_swapab + + return fp8_blockscale_gemm_swapab( + input=input, + weight=weight, + input_scale=scales_a, + weight_scale=scales_b, + out_dtype=out_dtype, + ) + diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index de6a1e8c1aa7d..f893d7e53cd16 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -35,6 +35,7 @@ from vllm.utils.deep_gemm import ( should_use_deepgemm_for_fp8_linear, transform_sf_into_required_layout, ) +from vllm.utils.flashinfer import 
@functools.cache
def has_flashinfer_block_gemm() -> bool:
    """Return `True` if FlashInfer FP8 block-scale GEMM is available.

    The result is cached for the lifetime of the process. Availability needs:
    FlashInfer installed, a CUDA platform, device capability 90, and the
    `flashinfer.gemm.fp8_blockscale_gemm_swapab` entry point - the exact
    symbol the vLLM wrapper imports - so a positive answer here cannot be
    followed by an ImportError at call time.
    """
    if not has_flashinfer():
        return False
    if not current_platform.is_cuda():
        return False
    # NOTE(review): the original comment said "SM90+", but
    # `is_device_capability(90)` looks like an exact-match test - confirm
    # whether `has_device_capability(90)` was intended.
    if not current_platform.is_device_capability(90):
        return False

    try:
        from flashinfer import gemm as flashinfer_gemm
    except ImportError:
        logger.debug_once(
            "FlashInfer block-scale GEMM not available: module or binding not found"
        )
        return False

    # Probe the function that is actually called, not an unrelated class.
    if not hasattr(flashinfer_gemm, "fp8_blockscale_gemm_swapab"):
        logger.debug_once(
            "FlashInfer block-scale GEMM not available: module or binding not found"
        )
        return False
    return True