From 3d429d63a658f29fe26e3448250bc8d4d52908b0 Mon Sep 17 00:00:00 2001 From: Kate Cheng Date: Fri, 21 Nov 2025 15:59:27 -0800 Subject: [PATCH 1/3] Enable linear deepgemm_swapAB Signed-off-by: Kate Cheng --- benchmarks/kernels/bench_block_fp8_gemm.py | 21 +- tests/kernels/quantization/test_block_fp8.py | 241 ++++++++++++++++++ vllm/envs.py | 6 + .../utils/flashinfer_block_gemm.py | 57 +++++ .../layers/quantization/utils/fp8_utils.py | 48 ++++ vllm/utils/flashinfer.py | 24 ++ 6 files changed, 396 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py index 11e3ac7f0c1fa..0b3903a0c0bc5 100644 --- a/benchmarks/kernels/bench_block_fp8_gemm.py +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -3,8 +3,10 @@ import os -# Disable DeepGEMM for this benchmark to use CUTLASS +# Disable DeepGEMM for this benchmark os.environ["VLLM_USE_DEEP_GEMM"] = "0" +# Enable FlashInfer FP8 linear (will be used when provider="flashinfer-block-fp8") +os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1" import torch @@ -94,6 +96,15 @@ plot_title = "BF16 vs W8A8 Block FP8 GEMMs" if CUTLASS_BLOCK_FP8_SUPPORTED: available_providers.append("w8a8-block-fp8-cutlass") +# Check if FlashInfer block GEMM is available +try: + from vllm.utils.flashinfer import has_flashinfer_block_gemm + + if has_flashinfer_block_gemm(): + available_providers.append("flashinfer-block-fp8") +except ImportError: + pass + @vllm_triton.testing.perf_report( vllm_triton.testing.Benchmark( @@ -134,6 +145,14 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( lambda: run_w8a8_cutlass(), quantiles=quantiles ) + elif provider == "flashinfer-block-fp8": + # Use the same W8A8 setup as other providers for fair comparison + run_w8a8_flashinfer = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=False + ) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_flashinfer(), quantiles=quantiles + ) else: raise ValueError(f"Unknown provider: {provider}") diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 32c77b9a01ece..3005f9715f404 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -205,3 +205,244 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 + + +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), +) +@torch.inference_mode() +def test_flashinfer_block_gemm_matmul(M, N, K, block_size, out_dtype, seed): + """ + Test FlashInfer FP8 block-scale GEMM through W8A8BlockFp8LinearOp. + + This tests the FP8 + FP8 → BF16 path (W8A8 full quantization). + Matches TensorRT-LLM's test_fp8_block_scale_gemm behavior. + """ + import os + + from vllm.utils.flashinfer import has_flashinfer_block_gemm + + if not has_flashinfer_block_gemm(): + pytest.skip( + "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)" + ) + + # Skip tests for dimensions that don't have pre-compiled kernels in FlashInfer + # These cause CUDA runtime errors + if K == 3884 or N == 7748: + pytest.skip(f"FlashInfer does not have pre-compiled kernels for K={K} or N={N}") + + # Enable FlashInfer backend (required for W8A8BlockFp8LinearOp to use FlashInfer) + os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1" + # Reload envs module to pick up the env var change + import importlib + + from vllm import envs + + importlib.reload(envs) + + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp, + ) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + ) + + torch.manual_seed(seed) + + # Create BF16 inputs (normalized like TRT-LLM) + A_bf16 = torch.randn(M, K, dtype=torch.bfloat16) / K + B_bf16 = torch.randn(N, K, dtype=torch.bfloat16) / K + + # Quantize weight with per-block scales + B_fp8, Bs = per_block_cast_to_fp8(B_bf16, block_size=block_size) + + # Create W8A8BlockFp8LinearOp to handle input quantization + block_n, block_k = block_size[0], block_size[1] + weight_group_shape = GroupShape(block_n, block_k) + act_quant_group_shape = GroupShape(1, block_k) # Per-token quantization + + linear_op = W8A8BlockFp8LinearOp( + weight_group_shape=weight_group_shape, + act_quant_group_shape=act_quant_group_shape, + cutlass_block_fp8_supported=False, # Disable CUTLASS + use_aiter_and_is_supported=False, # Disable AITER + ) + + # Verify FlashInfer backend is selected + assert linear_op.w8a8_blockscale_op == linear_op._run_flashinfer, ( + "FlashInfer backend not selected! " + "Make sure VLLM_USE_FLASHINFER_FP8_LINEAR=1 is set before running tests." + ) + + # Compute reference: BF16 × BF16 matmul (before quantization) + ref_out = torch.matmul(A_bf16, B_bf16.T) + + # Run W8A8 FlashInfer GEMM (input will be quantized internally) + out = linear_op.apply( + input=A_bf16, + weight=B_fp8, + weight_scale=Bs, + input_scale=None, # Will quantize dynamically + bias=None, + ) + + # Compare results using TensorRT-LLM's calc_diff metric + # This measures normalized similarity: sim = 2* / (||x||² + ||y||²) + out_fp64 = out.to(torch.float64) + ref_fp64 = ref_out.to(torch.float64) + denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum() + sim = 2 * (out_fp64 * ref_fp64).sum() / denominator + diff = 1 - sim + + # W8A8 threshold from TensorRT-LLM: diff < 0.001 (99.9% similarity) + assert diff < 0.001, ( + f"Similarity difference {diff:.6f} exceeds threshold (similarity: {sim:.6f})" + ) + + +@pytest.mark.parametrize( + "M,N,K,block_size,seed", + [ + (1, 1024, 4096, [128, 128], 0), + (32, 4096, 512, [128, 128], 0), + (128, 1024, 4096, [128, 128], 0), + ], +) +@pytest.mark.parametrize( + "input_dtype,weight_dtype", + [ + (torch.bfloat16, torch.bfloat16), # BF16 + BF16 (internal quantization) + (torch.bfloat16, torch.float8_e4m3fn), # BF16 + FP8 (weight-only) + (torch.float8_e4m3fn, torch.float8_e4m3fn), # FP8 + FP8 (W8A8) + ], +) +@torch.inference_mode() +def test_flashinfer_block_gemm_dtypes( + M, N, K, block_size, input_dtype, weight_dtype, seed +): + """ + Test all three supported dtype combinations for FlashInfer FP8 block-scale GEMM. + + Tests: + - BF16 + BF16 → BF16: Both inputs BF16, internal quantization + - BF16 + FP8 → BF16: Weight-only quantization + - FP8 + FP8 → BF16: W8A8 full quantization + + This mirrors FlashInfer's own test_fp8_blockscale_gemm_dtypes and TRT-LLM's tests. + """ + from vllm.utils.flashinfer import has_flashinfer_block_gemm + + if not has_flashinfer_block_gemm(): + pytest.skip( + "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)" + ) + + from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import ( + flashinfer_block_gemm, + ) + + # Add debug output to verify test execution + print(f"\n{'=' * 80}") + print(f"TEST: M={M}, N={N}, K={K} | Input: {input_dtype}, Weight: {weight_dtype}") + print(f"{'=' * 80}") + + torch.manual_seed(seed) + + # Create BF16 data for reference (same as FlashInfer tests) + input_bf16 = torch.randn(M, K, dtype=torch.bfloat16) + weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16) + + # Quantize input based on dtype + if input_dtype == torch.float8_e4m3fn: + input_tensor, input_scale = per_token_group_quant_fp8(input_bf16, block_size[1]) + else: + input_tensor, input_scale = input_bf16, None + + # Quantize weight based on dtype + if weight_dtype == torch.float8_e4m3fn: + weight_tensor, weight_scale = per_block_cast_to_fp8( + weight_bf16, block_size=block_size + ) + else: + weight_tensor, weight_scale = weight_bf16, None + + # Run FlashInfer FP8 block-scale GEMM + output = flashinfer_block_gemm( + input=input_tensor, + weight=weight_tensor, + scales_a=input_scale, + scales_b=weight_scale, + out_dtype=torch.bfloat16, + ) + + # Verify output properties + assert output.shape == (M, N), f"Expected shape {(M, N)}, got {output.shape}" + assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" + + # Compute reference based on dtype combination + if input_dtype == torch.float8_e4m3fn and weight_dtype == torch.float8_e4m3fn: + # W8A8: Compare against dequantized FP8 reference (tests kernel correctness) + block_n, block_k = block_size[0], block_size[1] + k_tiles = (K + block_k - 1) // block_k + n_tiles = (N + block_n - 1) // block_n + + input_dequant = torch.zeros_like(input_bf16) + for i in range(M): + for k_tile in range(k_tiles): + start, end = k_tile * block_k, min((k_tile + 1) * block_k, K) + input_dequant[i, start:end] = ( + input_tensor[i, start:end].to(torch.bfloat16) + * input_scale[i, k_tile] + ) + + weight_dequant = torch.zeros_like(weight_bf16) + for j in range(N): + for k_tile in range(k_tiles): + start, end = k_tile * block_k, min((k_tile + 1) * block_k, K) + weight_dequant[j, start:end] = ( + weight_tensor[j, start:end].to(torch.bfloat16) + * weight_scale[j // block_n, k_tile] + ) + + reference = torch.matmul(input_dequant, weight_dequant.T) + + # W8A8: Use TRT-LLM's calc_diff metric with strict threshold + out_fp64 = output.to(torch.float64) + ref_fp64 = reference.to(torch.float64) + denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum() + sim = 2 * (out_fp64 * ref_fp64).sum() / denominator + diff = 1 - sim + + # W8A8 achieves very high accuracy: diff < 0.001 (99.9% similarity) + assert diff < 0.001, ( + f"W8A8 similarity difference {diff:.6f} too high (expected < 0.001, similarity: {sim:.6f})" + ) + else: + # BF16+BF16 or BF16+FP8: Compare against original BF16 reference + reference = torch.matmul(input_bf16, weight_bf16.T) + + out_fp64 = output.to(torch.float64) + ref_fp64 = reference.to(torch.float64) + denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum() + sim = 2 * (out_fp64 * ref_fp64).sum() / denominator + diff = 1 - sim + + if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16: + # BF16+BF16: Highest accuracy (internal quantization) + threshold = 0.001 + threshold_desc = "0.1%" + elif input_dtype == torch.bfloat16 and weight_dtype == torch.float8_e4m3fn: + # BF16+FP8: Weight-only quantization, higher error + threshold = 0.01 + threshold_desc = "1%" + else: + # Other combinations + threshold = 0.01 + threshold_desc = "1%" + + assert diff < threshold, ( + f"Similarity difference {diff:.6f} too high for " + f"{input_dtype} + {weight_dtype} (expected < {threshold_desc}, similarity: {sim:.6f})" + ) diff --git a/vllm/envs.py b/vllm/envs.py index 1d4128d74b95c..9a9140aa0270d 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -168,6 +168,7 @@ if TYPE_CHECKING: "relax", ] = "relax" VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True + VLLM_USE_FLASHINFER_FP8_LINEAR: bool = False VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False @@ -1208,6 +1209,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool( int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1")) ), + # Allow use of FlashInfer FP8 block-scale GEMM for linear layers. + # This uses TensorRT-LLM kernels and requires SM90+ (Hopper). + "VLLM_USE_FLASHINFER_FP8_LINEAR": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_FP8_LINEAR", "0")) + ), # Allow use of FlashInfer MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0")) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py b/vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py new file mode 100644 index 0000000000000..04fb00ef5c8b8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +FlashInfer FP8 Block-Scale GEMM wrapper for vLLM. + +This module provides a thin wrapper around FlashInfer's FP8 block-scale GEMM +implementation, which uses TensorRT-LLM's optimized kernels for NVIDIA Hopper (SM90+). +""" + +import torch + + +def flashinfer_block_gemm( + input: torch.Tensor, + weight: torch.Tensor, + scales_a: torch.Tensor | None, + scales_b: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + """ + Wrapper for FlashInfer's FP8 block-scale GEMM. + + Computes: output = (input @ weight.T) with per-block scaling for quantization. + + Supports three modes: + 1. BF16 + BF16 → BF16: Both inputs BF16, internal quantization (scales_a=None, scales_b used internally) + 2. BF16 + FP8 → BF16: Weight-only quantization (scales_a=None, scales_b for weight) + 3. FP8 + FP8 → BF16: W8A8 full quantization (scales_a for input, scales_b for weight) + + Args: + input: Input tensor (M, K) - BF16 or FP8 e4m3 + weight: Weight tensor (N, K) - BF16 or FP8 e4m3 + scales_a: Input scales (M, K//block_k) or None - FP32 + None: input is BF16 (will be quantized internally for BF16+BF16 or left as-is for BF16+FP8) + Provided: input is pre-quantized FP8 (W8A8 mode) + scales_b: Weight scales (N//block_n, K//block_k) - FP32 + out_dtype: Output dtype (typically torch.bfloat16) + + Returns: + output: Result tensor (M, N) in out_dtype + + Note: + - Requires SM90+ GPU (NVIDIA Hopper) + - Uses TensorRT-LLM's optimized CUTLASS kernels via FlashInfer + - For M < 32, automatically uses SwapAB kernel optimization + """ + from flashinfer.gemm import fp8_blockscale_gemm_swapab + + return fp8_blockscale_gemm_swapab( + input=input, + weight=weight, + input_scale=scales_a, + weight_scale=scales_b, + out_dtype=out_dtype, + ) + diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index de6a1e8c1aa7d..f893d7e53cd16 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -35,6 +35,7 @@ from vllm.utils.deep_gemm import ( should_use_deepgemm_for_fp8_linear, transform_sf_into_required_layout, ) +from vllm.utils.flashinfer import has_flashinfer_block_gemm from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -409,6 +410,37 @@ class W8A8BlockFp8LinearOp: input_2d.dtype, ) + def _run_flashinfer( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Run FlashInfer FP8 block-scale GEMM. + + This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels + and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper). + """ + from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import ( + flashinfer_block_gemm, + ) + + # Quantize input dynamically if not pre-quantized (same as CUTLASS) + assert input_scale is None + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + + # Now call FlashInfer with FP8 input + FP8 weight (W8A8) + return flashinfer_block_gemm( + input=q_input, # FP8 quantized input + weight=weight, # FP8 weight + scales_a=input_scale, # Input scales (computed dynamically) + scales_b=weight_scale, # Weight scales + out_dtype=input_2d.dtype, + ) + def _dispatch_w8a8_blockscale_op( self, use_cutlass: bool, @@ -425,6 +457,22 @@ class W8A8BlockFp8LinearOp: ], QuantFP8 | None, ]: + # Prefer FlashInfer on SM90+ if available (Hopper optimized) + if ( + has_flashinfer_block_gemm() + and envs.VLLM_USE_FLASHINFER_FP8_LINEAR + and not use_aiter_and_is_supported + ): + logger.info_once("Using FlashInfer FP8 block-scale GEMM for linear layers") + return self._run_flashinfer, ( + QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=False, + use_ue8m0=False, + ) + ) + if use_cutlass: return self._run_cutlass, ( QuantFP8( diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 6bbe02348eaf1..df8cc2468c605 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -540,6 +540,29 @@ def flashinfer_scaled_fp8_mm( return output +@functools.cache +def has_flashinfer_block_gemm() -> bool: + """Return `True` if FlashInfer FP8 block-scale GEMM is available.""" + if not has_flashinfer(): + return False + if not current_platform.is_cuda(): + return False + # Only SM90+ (Hopper) supports this kernel + if not current_platform.is_device_capability(90): + return False + + try: + import flashinfer + + # Check if the module has the required binding + return hasattr(flashinfer, "Fp8BlockScaleGemmRunner") + except (ImportError, AttributeError): + logger.debug_once( + "FlashInfer block-scale GEMM not available: module or binding not found" + ) + return False + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -562,4 +585,5 @@ __all__ = [ "use_trtllm_attention", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", + "has_flashinfer_block_gemm", ] From 5a5506c6617136bceee3ce3d81277d08fd0bdb71 Mon Sep 17 00:00:00 2001 From: Jhao-Ting Chen Date: Mon, 22 Dec 2025 09:16:49 -0800 Subject: [PATCH 2/3] enable DeepGEMM swapAB from FlashInfer for M<32 linear gemms Signed-off-by: Jhao-Ting Chen --- benchmarks/kernels/bench_block_fp8_gemm.py | 21 +- tests/kernels/quantization/test_block_fp8.py | 249 +++--------------- .../utils/flashinfer_block_gemm.py | 57 ---- .../layers/quantization/utils/fp8_utils.py | 124 ++++++--- vllm/utils/flashinfer.py | 64 +++-- 5 files changed, 168 insertions(+), 347 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py index 0b3903a0c0bc5..11e3ac7f0c1fa 100644 --- a/benchmarks/kernels/bench_block_fp8_gemm.py +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -3,10 +3,8 @@ import os -# Disable DeepGEMM for this benchmark +# Disable DeepGEMM for this benchmark to use CUTLASS os.environ["VLLM_USE_DEEP_GEMM"] = "0" -# Enable FlashInfer FP8 linear (will be used when provider="flashinfer-block-fp8") -os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1" import torch @@ -96,15 +94,6 @@ plot_title = "BF16 vs W8A8 Block FP8 GEMMs" if CUTLASS_BLOCK_FP8_SUPPORTED: available_providers.append("w8a8-block-fp8-cutlass") -# Check if FlashInfer block GEMM is available -try: - from vllm.utils.flashinfer import has_flashinfer_block_gemm - - if has_flashinfer_block_gemm(): - available_providers.append("flashinfer-block-fp8") -except ImportError: - pass - @vllm_triton.testing.perf_report( vllm_triton.testing.Benchmark( @@ -145,14 +134,6 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( lambda: run_w8a8_cutlass(), quantiles=quantiles ) - elif provider == "flashinfer-block-fp8": - # Use the same W8A8 setup as other providers for fair comparison - run_w8a8_flashinfer = build_w8a8_block_fp8_runner( - M, N, K, block_size, device, use_cutlass=False - ) - ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( - lambda: run_w8a8_flashinfer(), quantiles=quantiles - ) else: raise ValueError(f"Unknown provider: {provider}") diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 3005f9715f404..01789abaa44e0 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -24,6 +24,10 @@ from vllm.utils.deep_gemm import ( per_block_cast_to_fp8, should_use_deepgemm_for_fp8_linear, ) +from vllm.utils.flashinfer import ( + flashinfer_fp8_blockscale_gemm, + has_flashinfer_fp8_blockscale_gemm, +) from vllm.utils.import_utils import has_deep_gemm if current_platform.get_device_capability() < (9, 0): @@ -212,237 +216,46 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), ) @torch.inference_mode() -def test_flashinfer_block_gemm_matmul(M, N, K, block_size, out_dtype, seed): - """ - Test FlashInfer FP8 block-scale GEMM through W8A8BlockFp8LinearOp. - - This tests the FP8 + FP8 → BF16 path (W8A8 full quantization). - Matches TensorRT-LLM's test_fp8_block_scale_gemm behavior. - """ - import os - - from vllm.utils.flashinfer import has_flashinfer_block_gemm - - if not has_flashinfer_block_gemm(): +def test_w8a8_block_fp8_flashinfer_matmul(M, N, K, block_size, out_dtype, seed): + if not has_flashinfer_fp8_blockscale_gemm(): pytest.skip( "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)" ) - - # Skip tests for dimensions that don't have pre-compiled kernels in FlashInfer - # These cause CUDA runtime errors - if K == 3884 or N == 7748: - pytest.skip(f"FlashInfer does not have pre-compiled kernels for K={K} or N={N}") - - # Enable FlashInfer backend (required for W8A8BlockFp8LinearOp to use FlashInfer) - os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1" - # Reload envs module to pick up the env var change - import importlib - - from vllm import envs - - importlib.reload(envs) - - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, - ) - from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, - ) + # only aligned sizes + if K % 128 != 0 or N % 64 != 0: + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") torch.manual_seed(seed) + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max = fp8_info.max - # Create BF16 inputs (normalized like TRT-LLM) - A_bf16 = torch.randn(M, K, dtype=torch.bfloat16) / K - B_bf16 = torch.randn(N, K, dtype=torch.bfloat16) / K + A_bf16 = (torch.rand(M, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + B_bf16 = (torch.rand(N, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - # Quantize weight with per-block scales - B_fp8, Bs = per_block_cast_to_fp8(B_bf16, block_size=block_size) + A_fp8, As_fp8 = per_token_group_quant_fp8(A_bf16, block_size[1], use_ue8m0=False) + B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_bf16, block_size, use_ue8m0=False) - # Create W8A8BlockFp8LinearOp to handle input quantization - block_n, block_k = block_size[0], block_size[1] - weight_group_shape = GroupShape(block_n, block_k) - act_quant_group_shape = GroupShape(1, block_k) # Per-token quantization + As = As_fp8.to(torch.float32) + Bs = Bs_fp8.to(torch.float32) - linear_op = W8A8BlockFp8LinearOp( - weight_group_shape=weight_group_shape, - act_quant_group_shape=act_quant_group_shape, - cutlass_block_fp8_supported=False, # Disable CUTLASS - use_aiter_and_is_supported=False, # Disable AITER + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + As_fp8 = get_col_major_tma_aligned_tensor(As_fp8) + + # Transpose earlier so that the testing will not trigger transposing kernels + + assert As_fp8.shape == (M, (K + 127) // 128), ( + f"{As_fp8.shape} != {(M, (K + 127) // 128)}" ) - # Verify FlashInfer backend is selected - assert linear_op.w8a8_blockscale_op == linear_op._run_flashinfer, ( - "FlashInfer backend not selected! " - "Make sure VLLM_USE_FLASHINFER_FP8_LINEAR=1 is set before running tests." - ) - - # Compute reference: BF16 × BF16 matmul (before quantization) - ref_out = torch.matmul(A_bf16, B_bf16.T) - - # Run W8A8 FlashInfer GEMM (input will be quantized internally) - out = linear_op.apply( + out = flashinfer_fp8_blockscale_gemm( input=A_bf16, weight=B_fp8, + input_scale=None, weight_scale=Bs, - input_scale=None, # Will quantize dynamically - bias=None, + out_dtype=out_dtype, ) - # Compare results using TensorRT-LLM's calc_diff metric - # This measures normalized similarity: sim = 2* / (||x||² + ||y||²) - out_fp64 = out.to(torch.float64) - ref_fp64 = ref_out.to(torch.float64) - denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum() - sim = 2 * (out_fp64 * ref_fp64).sum() / denominator - diff = 1 - sim - - # W8A8 threshold from TensorRT-LLM: diff < 0.001 (99.9% similarity) - assert diff < 0.001, ( - f"Similarity difference {diff:.6f} exceeds threshold (similarity: {sim:.6f})" - ) - - -@pytest.mark.parametrize( - "M,N,K,block_size,seed", - [ - (1, 1024, 4096, [128, 128], 0), - (32, 4096, 512, [128, 128], 0), - (128, 1024, 4096, [128, 128], 0), - ], -) -@pytest.mark.parametrize( - "input_dtype,weight_dtype", - [ - (torch.bfloat16, torch.bfloat16), # BF16 + BF16 (internal quantization) - (torch.bfloat16, torch.float8_e4m3fn), # BF16 + FP8 (weight-only) - (torch.float8_e4m3fn, torch.float8_e4m3fn), # FP8 + FP8 (W8A8) - ], -) -@torch.inference_mode() -def test_flashinfer_block_gemm_dtypes( - M, N, K, block_size, input_dtype, weight_dtype, seed -): - """ - Test all three supported dtype combinations for FlashInfer FP8 block-scale GEMM. - - Tests: - - BF16 + BF16 → BF16: Both inputs BF16, internal quantization - - BF16 + FP8 → BF16: Weight-only quantization - - FP8 + FP8 → BF16: W8A8 full quantization - - This mirrors FlashInfer's own test_fp8_blockscale_gemm_dtypes and TRT-LLM's tests. - """ - from vllm.utils.flashinfer import has_flashinfer_block_gemm - - if not has_flashinfer_block_gemm(): - pytest.skip( - "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)" - ) - - from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import ( - flashinfer_block_gemm, - ) - - # Add debug output to verify test execution - print(f"\n{'=' * 80}") - print(f"TEST: M={M}, N={N}, K={K} | Input: {input_dtype}, Weight: {weight_dtype}") - print(f"{'=' * 80}") - - torch.manual_seed(seed) - - # Create BF16 data for reference (same as FlashInfer tests) - input_bf16 = torch.randn(M, K, dtype=torch.bfloat16) - weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16) - - # Quantize input based on dtype - if input_dtype == torch.float8_e4m3fn: - input_tensor, input_scale = per_token_group_quant_fp8(input_bf16, block_size[1]) - else: - input_tensor, input_scale = input_bf16, None - - # Quantize weight based on dtype - if weight_dtype == torch.float8_e4m3fn: - weight_tensor, weight_scale = per_block_cast_to_fp8( - weight_bf16, block_size=block_size - ) - else: - weight_tensor, weight_scale = weight_bf16, None - - # Run FlashInfer FP8 block-scale GEMM - output = flashinfer_block_gemm( - input=input_tensor, - weight=weight_tensor, - scales_a=input_scale, - scales_b=weight_scale, - out_dtype=torch.bfloat16, - ) - - # Verify output properties - assert output.shape == (M, N), f"Expected shape {(M, N)}, got {output.shape}" - assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" - - # Compute reference based on dtype combination - if input_dtype == torch.float8_e4m3fn and weight_dtype == torch.float8_e4m3fn: - # W8A8: Compare against dequantized FP8 reference (tests kernel correctness) - block_n, block_k = block_size[0], block_size[1] - k_tiles = (K + block_k - 1) // block_k - n_tiles = (N + block_n - 1) // block_n - - input_dequant = torch.zeros_like(input_bf16) - for i in range(M): - for k_tile in range(k_tiles): - start, end = k_tile * block_k, min((k_tile + 1) * block_k, K) - input_dequant[i, start:end] = ( - input_tensor[i, start:end].to(torch.bfloat16) - * input_scale[i, k_tile] - ) - - weight_dequant = torch.zeros_like(weight_bf16) - for j in range(N): - for k_tile in range(k_tiles): - start, end = k_tile * block_k, min((k_tile + 1) * block_k, K) - weight_dequant[j, start:end] = ( - weight_tensor[j, start:end].to(torch.bfloat16) - * weight_scale[j // block_n, k_tile] - ) - - reference = torch.matmul(input_dequant, weight_dequant.T) - - # W8A8: Use TRT-LLM's calc_diff metric with strict threshold - out_fp64 = output.to(torch.float64) - ref_fp64 = reference.to(torch.float64) - denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum() - sim = 2 * (out_fp64 * ref_fp64).sum() / denominator - diff = 1 - sim - - # W8A8 achieves very high accuracy: diff < 0.001 (99.9% similarity) - assert diff < 0.001, ( - f"W8A8 similarity difference {diff:.6f} too high (expected < 0.001, similarity: {sim:.6f})" - ) - else: - # BF16+BF16 or BF16+FP8: Compare against original BF16 reference - reference = torch.matmul(input_bf16, weight_bf16.T) - - out_fp64 = output.to(torch.float64) - ref_fp64 = reference.to(torch.float64) - denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum() - sim = 2 * (out_fp64 * ref_fp64).sum() / denominator - diff = 1 - sim - - if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16: - # BF16+BF16: Highest accuracy (internal quantization) - threshold = 0.001 - threshold_desc = "0.1%" - elif input_dtype == torch.bfloat16 and weight_dtype == torch.float8_e4m3fn: - # BF16+FP8: Weight-only quantization, higher error - threshold = 0.01 - threshold_desc = "1%" - else: - # Other combinations - threshold = 0.01 - threshold_desc = "1%" - - assert diff < threshold, ( - f"Similarity difference {diff:.6f} too high for " - f"{input_dtype} + {weight_dtype} (expected < {threshold_desc}, similarity: {sim:.6f})" - ) + rel_diff = torch.mean( + torch.abs(out.to(torch.bfloat16) - ref_out.to(torch.bfloat16)) + ) / torch.mean(torch.abs(ref_out.to(torch.bfloat16))) + assert rel_diff < 0.001 diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py b/vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py deleted file mode 100644 index 04fb00ef5c8b8..0000000000000 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -""" -FlashInfer FP8 Block-Scale GEMM wrapper for vLLM. - -This module provides a thin wrapper around FlashInfer's FP8 block-scale GEMM -implementation, which uses TensorRT-LLM's optimized kernels for NVIDIA Hopper (SM90+). -""" - -import torch - - -def flashinfer_block_gemm( - input: torch.Tensor, - weight: torch.Tensor, - scales_a: torch.Tensor | None, - scales_b: torch.Tensor, - out_dtype: torch.dtype, -) -> torch.Tensor: - """ - Wrapper for FlashInfer's FP8 block-scale GEMM. - - Computes: output = (input @ weight.T) with per-block scaling for quantization. - - Supports three modes: - 1. BF16 + BF16 → BF16: Both inputs BF16, internal quantization (scales_a=None, scales_b used internally) - 2. BF16 + FP8 → BF16: Weight-only quantization (scales_a=None, scales_b for weight) - 3. FP8 + FP8 → BF16: W8A8 full quantization (scales_a for input, scales_b for weight) - - Args: - input: Input tensor (M, K) - BF16 or FP8 e4m3 - weight: Weight tensor (N, K) - BF16 or FP8 e4m3 - scales_a: Input scales (M, K//block_k) or None - FP32 - None: input is BF16 (will be quantized internally for BF16+BF16 or left as-is for BF16+FP8) - Provided: input is pre-quantized FP8 (W8A8 mode) - scales_b: Weight scales (N//block_n, K//block_k) - FP32 - out_dtype: Output dtype (typically torch.bfloat16) - - Returns: - output: Result tensor (M, N) in out_dtype - - Note: - - Requires SM90+ GPU (NVIDIA Hopper) - - Uses TensorRT-LLM's optimized CUTLASS kernels via FlashInfer - - For M < 32, automatically uses SwapAB kernel optimization - """ - from flashinfer.gemm import fp8_blockscale_gemm_swapab - - return fp8_blockscale_gemm_swapab( - input=input, - weight=weight, - input_scale=scales_a, - weight_scale=scales_b, - out_dtype=out_dtype, - ) - diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index f893d7e53cd16..5ebb7395a3cc7 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -35,7 +35,11 @@ from vllm.utils.deep_gemm import ( should_use_deepgemm_for_fp8_linear, transform_sf_into_required_layout, ) -from vllm.utils.flashinfer import has_flashinfer_block_gemm +from vllm.utils.flashinfer import ( + flashinfer_fp8_blockscale_gemm, + is_flashinfer_fp8_blockscale_gemm_supported, + should_use_flashinfer_for_block_scale_fp8_linear, +) from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -227,6 +231,76 @@ direct_register_custom_op( ) +def _flashinfer_fp8_blockscale_gemm_impl( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + group_size: int, + use_deep_gemm_e8m0: bool, +) -> torch.Tensor: + def use_flashinfer( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + return flashinfer_fp8_blockscale_gemm( + input=input, + weight=weight, + weight_scale=weight_scale, + out_dtype=torch.bfloat16, + ) + + def use_deepgemm( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + q_input, input_scale = per_token_group_quant_fp8( + input, + group_size=group_size, + column_major_scales=True, + use_ue8m0=use_deep_gemm_e8m0, + ) + output = torch.empty( + (q_input.shape[0], weight.shape[0]), + dtype=torch.bfloat16, + device=q_input.device, + ) + fp8_gemm_nt( + (q_input, input_scale), + (weight, weight_scale), + output, + is_deep_gemm_e8m0_used=use_deep_gemm_e8m0, + ) + return output + + condition = input.shape[0] < 32 + + # Pass all required variables through operands + return torch.cond( + condition, use_flashinfer, use_deepgemm, (input, weight, weight_scale) + ) + + +def _flashinfer_fp8_blockscale_gemm_fake( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + group_size: int, + use_deep_gemm_e8m0: bool, +) -> torch.Tensor: + return torch.empty( + input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device + ) + + +direct_register_custom_op( + "flashinfer_fp8_blockscale_gemm", + _flashinfer_fp8_blockscale_gemm_impl, + fake_impl=_flashinfer_fp8_blockscale_gemm_fake, +) + + # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 class W8A8BlockFp8LinearOp: @@ -247,6 +321,7 @@ class W8A8BlockFp8LinearOp: self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_hopper = current_platform.is_device_capability(90) self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() + self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported() # Get the correct blockscale mul and input quant operations. # We can't use _dispatch_w8a8_blockscale_op to figure out if we want @@ -282,7 +357,12 @@ class W8A8BlockFp8LinearOp: output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_deepgemm_for_fp8_linear( + if should_use_flashinfer_for_block_scale_fp8_linear( + self.is_flashinfer_supported, output_dtype, input_2d, weight + ): + output = self._run_flashinfer(input_2d, weight, weight_scale) + + elif should_use_deepgemm_for_fp8_linear( output_dtype, weight, self.is_deep_gemm_supported ): output = self._run_deepgemm(input_2d, weight, weight_scale) @@ -415,7 +495,6 @@ class W8A8BlockFp8LinearOp: input_2d: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - input_scale: torch.Tensor | None = None, ) -> torch.Tensor: """ Run FlashInfer FP8 block-scale GEMM. @@ -423,23 +502,16 @@ class W8A8BlockFp8LinearOp: This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper). """ - from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import ( - flashinfer_block_gemm, - ) - - # Quantize input dynamically if not pre-quantized (same as CUTLASS) - assert input_scale is None - assert self.input_quant_op is not None - q_input, input_scale = self.input_quant_op(input_2d) - - # Now call FlashInfer with FP8 input + FP8 weight (W8A8) - return flashinfer_block_gemm( - input=q_input, # FP8 quantized input + # Now call FlashInfer with BF16 input + FP8 weight, input will be + # quantized with FlashInfer kernel (W8A8) + output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm( + input=input_2d, # BF16 input weight=weight, # FP8 weight - scales_a=input_scale, # Input scales (computed dynamically) - scales_b=weight_scale, # Weight scales - out_dtype=input_2d.dtype, + weight_scale=weight_scale, # Weight scales + group_size=self.act_quant_group_shape.col, + use_deep_gemm_e8m0=self.use_deep_gemm_e8m0, ) + return output def _dispatch_w8a8_blockscale_op( self, @@ -457,22 +529,6 @@ class W8A8BlockFp8LinearOp: ], QuantFP8 | None, ]: - # Prefer FlashInfer on SM90+ if available (Hopper optimized) - if ( - has_flashinfer_block_gemm() - and envs.VLLM_USE_FLASHINFER_FP8_LINEAR - and not use_aiter_and_is_supported - ): - logger.info_once("Using FlashInfer FP8 block-scale GEMM for linear layers") - return self._run_flashinfer, ( - QuantFP8( - False, - self.act_quant_group_shape, - column_major_scales=False, - use_ue8m0=False, - ) - ) - if use_cutlass: return self._run_cutlass, ( QuantFP8( diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index df8cc2468c605..1b01c39cc68a5 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -540,27 +540,52 @@ def flashinfer_scaled_fp8_mm( return output +flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper( + "flashinfer.gemm", "fp8_blockscale_gemm_sm90" +) + + @functools.cache -def has_flashinfer_block_gemm() -> bool: - """Return `True` if FlashInfer FP8 block-scale GEMM is available.""" - if not has_flashinfer(): - return False - if not current_platform.is_cuda(): - return False - # Only SM90+ (Hopper) supports this kernel - if not current_platform.is_device_capability(90): +def has_flashinfer_fp8_blockscale_gemm() -> bool: + """Return `True` if FlashInfer block-scale FP8 GEMM is available.""" + return has_flashinfer() and hasattr( + _get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90" + ) + + +@functools.cache +def is_flashinfer_fp8_blockscale_gemm_supported() -> bool: + """Return `True` if FlashInfer block-scale FP8 GEMM is supported.""" + return envs.VLLM_USE_FLASHINFER_FP8_LINEAR and has_flashinfer_fp8_blockscale_gemm() + + +def should_use_flashinfer_for_block_scale_fp8_linear( + is_flashinfer_supported: bool, + output_dtype: torch.dtype, + input: torch.Tensor, + weight: torch.Tensor, +): + if not is_flashinfer_supported: return False - try: - import flashinfer + # Verify DeepGEMM N/K dims requirements + # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul + # test inside kernels/quatization/test_block_fp8.py + N_MULTIPLE = 64 + K_MULTIPLE = 128 - # Check if the module has the required binding - return hasattr(flashinfer, "Fp8BlockScaleGemmRunner") - except (ImportError, AttributeError): - logger.debug_once( - "FlashInfer block-scale GEMM not available: module or binding not found" - ) - return False + weight_dtype = weight.dtype + input_dtype = input.dtype + + should_use_flashinfer = ( + output_dtype == torch.bfloat16 + and input_dtype == torch.bfloat16 + and weight_dtype == torch.float8_e4m3fn + and weight.shape[0] % N_MULTIPLE == 0 + and weight.shape[1] % K_MULTIPLE == 0 + ) + + return should_use_flashinfer __all__ = [ @@ -579,11 +604,14 @@ __all__ = [ "has_flashinfer_all2all", "has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutedsl_grouped_gemm_nt_masked", + "has_flashinfer_fp8_blockscale_gemm", "has_nvidia_artifactory", "supports_trtllm_attention", "can_use_trtllm_attention", "use_trtllm_attention", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", - "has_flashinfer_block_gemm", + "flashinfer_fp8_blockscale_gemm", + "should_use_flashinfer_for_block_scale_fp8_linear", + "is_flashinfer_fp8_blockscale_gemm_supported", ] From e019391cd85bc474d38b4a590a3bbe297cdcddeb Mon Sep 17 00:00:00 2001 From: Jhao-Ting Chen Date: Wed, 24 Dec 2025 10:05:35 -0800 Subject: [PATCH 3/3] refine commit, polish PR Signed-off-by: Jhao-Ting Chen --- tests/kernels/quantization/test_block_fp8.py | 11 ++++------- vllm/envs.py | 6 +++--- .../layers/quantization/utils/fp8_utils.py | 17 ++++++++++++----- vllm/utils/flashinfer.py | 15 ++++++++++----- 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 01789abaa44e0..bd4a737ca6300 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -211,6 +211,10 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 +@pytest.mark.skipif( + current_platform.is_fp8_fnuz(), + reason="This platform supports e4m3fnuz, not e4m3fn.", +) @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), @@ -239,13 +243,6 @@ def test_w8a8_block_fp8_flashinfer_matmul(M, N, K, block_size, out_dtype, seed): Bs = Bs_fp8.to(torch.float32) ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - As_fp8 = get_col_major_tma_aligned_tensor(As_fp8) - - # Transpose earlier so that the testing will not trigger transposing kernels - - assert As_fp8.shape == (M, (K + 127) // 128), ( - f"{As_fp8.shape} != {(M, (K + 127) // 128)}" - ) out = flashinfer_fp8_blockscale_gemm( input=A_bf16, diff --git a/vllm/envs.py b/vllm/envs.py index 9a9140aa0270d..9f595df66b2c5 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -168,7 +168,7 @@ if TYPE_CHECKING: "relax", ] = "relax" VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True - VLLM_USE_FLASHINFER_FP8_LINEAR: bool = False + VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER: bool = False VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False @@ -1211,8 +1211,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ), # Allow use of FlashInfer FP8 block-scale GEMM for linear layers. # This uses TensorRT-LLM kernels and requires SM90+ (Hopper). - "VLLM_USE_FLASHINFER_FP8_LINEAR": lambda: bool( - int(os.getenv("VLLM_USE_FLASHINFER_FP8_LINEAR", "0")) + "VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool( + int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "0")) ), # Allow use of FlashInfer MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 5ebb7395a3cc7..73f4a793503d7 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -38,7 +38,7 @@ from vllm.utils.deep_gemm import ( from vllm.utils.flashinfer import ( flashinfer_fp8_blockscale_gemm, is_flashinfer_fp8_blockscale_gemm_supported, - should_use_flashinfer_for_block_scale_fp8_linear, + should_use_flashinfer_for_blockscale_fp8_gemm, ) from vllm.utils.torch_utils import direct_register_custom_op @@ -238,7 +238,7 @@ def _flashinfer_fp8_blockscale_gemm_impl( group_size: int, use_deep_gemm_e8m0: bool, ) -> torch.Tensor: - def use_flashinfer( + def use_flashinfer_deepgemm_swapAB( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, @@ -274,11 +274,18 @@ def _flashinfer_fp8_blockscale_gemm_impl( ) return output + # there is only no benefit of using FlashInfer DeepGEMM for higher batch sizes since + # the swapAB optimization is only effective for small batch sizes. + # there is slight accuracy loss when using FlashInfer blockscale gemm for all batch + # sizes for DeepSeek-V3. condition = input.shape[0] < 32 - # Pass all required variables through operands + # torch.cond for torch compile compatibility return torch.cond( - condition, use_flashinfer, use_deepgemm, (input, weight, weight_scale) + condition, + use_flashinfer_deepgemm_swapAB, + use_deepgemm, + (input, weight, weight_scale), ) @@ -357,7 +364,7 @@ class W8A8BlockFp8LinearOp: output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_flashinfer_for_block_scale_fp8_linear( + if should_use_flashinfer_for_blockscale_fp8_gemm( self.is_flashinfer_supported, output_dtype, input_2d, weight ): output = self._run_flashinfer(input_2d, weight, weight_scale) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 1b01c39cc68a5..0804add2343de 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -548,18 +548,23 @@ flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper( @functools.cache def has_flashinfer_fp8_blockscale_gemm() -> bool: """Return `True` if FlashInfer block-scale FP8 GEMM is available.""" - return has_flashinfer() and hasattr( - _get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90" + return ( + has_flashinfer() + and current_platform.is_device_capability(90) + and hasattr(_get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90") ) @functools.cache def is_flashinfer_fp8_blockscale_gemm_supported() -> bool: """Return `True` if FlashInfer block-scale FP8 GEMM is supported.""" - return envs.VLLM_USE_FLASHINFER_FP8_LINEAR and has_flashinfer_fp8_blockscale_gemm() + return ( + envs.VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER + and has_flashinfer_fp8_blockscale_gemm() + ) -def should_use_flashinfer_for_block_scale_fp8_linear( +def should_use_flashinfer_for_blockscale_fp8_gemm( is_flashinfer_supported: bool, output_dtype: torch.dtype, input: torch.Tensor, @@ -612,6 +617,6 @@ __all__ = [ "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", "flashinfer_fp8_blockscale_gemm", - "should_use_flashinfer_for_block_scale_fp8_linear", + "should_use_flashinfer_for_blockscale_fp8_gemm", "is_flashinfer_fp8_blockscale_gemm_supported", ]