diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py
index 0b3903a0c0bc5..11e3ac7f0c1fa 100644
--- a/benchmarks/kernels/bench_block_fp8_gemm.py
+++ b/benchmarks/kernels/bench_block_fp8_gemm.py
@@ -3,10 +3,8 @@
 import os
 
-# Disable DeepGEMM for this benchmark
+# Disable DeepGEMM for this benchmark to use CUTLASS
 os.environ["VLLM_USE_DEEP_GEMM"] = "0"
-# Enable FlashInfer FP8 linear (will be used when provider="flashinfer-block-fp8")
-os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1"
 
 import torch
 
@@ -96,15 +94,6 @@ plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
 if CUTLASS_BLOCK_FP8_SUPPORTED:
     available_providers.append("w8a8-block-fp8-cutlass")
 
-# Check if FlashInfer block GEMM is available
-try:
-    from vllm.utils.flashinfer import has_flashinfer_block_gemm
-
-    if has_flashinfer_block_gemm():
-        available_providers.append("flashinfer-block-fp8")
-except ImportError:
-    pass
-
 
 @vllm_triton.testing.perf_report(
     vllm_triton.testing.Benchmark(
@@ -145,14 +134,6 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
         ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
             lambda: run_w8a8_cutlass(), quantiles=quantiles
         )
-    elif provider == "flashinfer-block-fp8":
-        # Use the same W8A8 setup as other providers for fair comparison
-        run_w8a8_flashinfer = build_w8a8_block_fp8_runner(
-            M, N, K, block_size, device, use_cutlass=False
-        )
-        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
-            lambda: run_w8a8_flashinfer(), quantiles=quantiles
-        )
     else:
         raise ValueError(f"Unknown provider: {provider}")
 
- """ - import os - - from vllm.utils.flashinfer import has_flashinfer_block_gemm - - if not has_flashinfer_block_gemm(): +def test_w8a8_block_fp8_flashinfer_matmul(M, N, K, block_size, out_dtype, seed): + if not has_flashinfer_fp8_blockscale_gemm(): pytest.skip( "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)" ) - - # Skip tests for dimensions that don't have pre-compiled kernels in FlashInfer - # These cause CUDA runtime errors - if K == 3884 or N == 7748: - pytest.skip(f"FlashInfer does not have pre-compiled kernels for K={K} or N={N}") - - # Enable FlashInfer backend (required for W8A8BlockFp8LinearOp to use FlashInfer) - os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1" - # Reload envs module to pick up the env var change - import importlib - - from vllm import envs - - importlib.reload(envs) - - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, - ) - from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, - ) + # only aligned sizes + if K % 128 != 0 or N % 64 != 0: + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") torch.manual_seed(seed) + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max = fp8_info.max - # Create BF16 inputs (normalized like TRT-LLM) - A_bf16 = torch.randn(M, K, dtype=torch.bfloat16) / K - B_bf16 = torch.randn(N, K, dtype=torch.bfloat16) / K + A_bf16 = (torch.rand(M, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + B_bf16 = (torch.rand(N, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - # Quantize weight with per-block scales - B_fp8, Bs = per_block_cast_to_fp8(B_bf16, block_size=block_size) + A_fp8, As_fp8 = per_token_group_quant_fp8(A_bf16, block_size[1], use_ue8m0=False) + B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_bf16, block_size, use_ue8m0=False) - # Create W8A8BlockFp8LinearOp to handle input quantization - block_n, block_k = block_size[0], block_size[1] - weight_group_shape = GroupShape(block_n, block_k) - act_quant_group_shape = GroupShape(1, block_k) # Per-token quantization + As = As_fp8.to(torch.float32) + Bs = Bs_fp8.to(torch.float32) - linear_op = W8A8BlockFp8LinearOp( - weight_group_shape=weight_group_shape, - act_quant_group_shape=act_quant_group_shape, - cutlass_block_fp8_supported=False, # Disable CUTLASS - use_aiter_and_is_supported=False, # Disable AITER + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + As_fp8 = get_col_major_tma_aligned_tensor(As_fp8) + + # Transpose earlier so that the testing will not trigger transposing kernels + + assert As_fp8.shape == (M, (K + 127) // 128), ( + f"{As_fp8.shape} != {(M, (K + 127) // 128)}" ) - # Verify FlashInfer backend is selected - assert linear_op.w8a8_blockscale_op == linear_op._run_flashinfer, ( - "FlashInfer backend not selected! " - "Make sure VLLM_USE_FLASHINFER_FP8_LINEAR=1 is set before running tests." 
- ) - - # Compute reference: BF16 × BF16 matmul (before quantization) - ref_out = torch.matmul(A_bf16, B_bf16.T) - - # Run W8A8 FlashInfer GEMM (input will be quantized internally) - out = linear_op.apply( + out = flashinfer_fp8_blockscale_gemm( input=A_bf16, weight=B_fp8, + input_scale=None, weight_scale=Bs, - input_scale=None, # Will quantize dynamically - bias=None, + out_dtype=out_dtype, ) - # Compare results using TensorRT-LLM's calc_diff metric - # This measures normalized similarity: sim = 2* / (||x||² + ||y||²) - out_fp64 = out.to(torch.float64) - ref_fp64 = ref_out.to(torch.float64) - denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum() - sim = 2 * (out_fp64 * ref_fp64).sum() / denominator - diff = 1 - sim - - # W8A8 threshold from TensorRT-LLM: diff < 0.001 (99.9% similarity) - assert diff < 0.001, ( - f"Similarity difference {diff:.6f} exceeds threshold (similarity: {sim:.6f})" - ) - - -@pytest.mark.parametrize( - "M,N,K,block_size,seed", - [ - (1, 1024, 4096, [128, 128], 0), - (32, 4096, 512, [128, 128], 0), - (128, 1024, 4096, [128, 128], 0), - ], -) -@pytest.mark.parametrize( - "input_dtype,weight_dtype", - [ - (torch.bfloat16, torch.bfloat16), # BF16 + BF16 (internal quantization) - (torch.bfloat16, torch.float8_e4m3fn), # BF16 + FP8 (weight-only) - (torch.float8_e4m3fn, torch.float8_e4m3fn), # FP8 + FP8 (W8A8) - ], -) -@torch.inference_mode() -def test_flashinfer_block_gemm_dtypes( - M, N, K, block_size, input_dtype, weight_dtype, seed -): - """ - Test all three supported dtype combinations for FlashInfer FP8 block-scale GEMM. - - Tests: - - BF16 + BF16 → BF16: Both inputs BF16, internal quantization - - BF16 + FP8 → BF16: Weight-only quantization - - FP8 + FP8 → BF16: W8A8 full quantization - - This mirrors FlashInfer's own test_fp8_blockscale_gemm_dtypes and TRT-LLM's tests. 
- """ - from vllm.utils.flashinfer import has_flashinfer_block_gemm - - if not has_flashinfer_block_gemm(): - pytest.skip( - "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)" - ) - - from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import ( - flashinfer_block_gemm, - ) - - # Add debug output to verify test execution - print(f"\n{'=' * 80}") - print(f"TEST: M={M}, N={N}, K={K} | Input: {input_dtype}, Weight: {weight_dtype}") - print(f"{'=' * 80}") - - torch.manual_seed(seed) - - # Create BF16 data for reference (same as FlashInfer tests) - input_bf16 = torch.randn(M, K, dtype=torch.bfloat16) - weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16) - - # Quantize input based on dtype - if input_dtype == torch.float8_e4m3fn: - input_tensor, input_scale = per_token_group_quant_fp8(input_bf16, block_size[1]) - else: - input_tensor, input_scale = input_bf16, None - - # Quantize weight based on dtype - if weight_dtype == torch.float8_e4m3fn: - weight_tensor, weight_scale = per_block_cast_to_fp8( - weight_bf16, block_size=block_size - ) - else: - weight_tensor, weight_scale = weight_bf16, None - - # Run FlashInfer FP8 block-scale GEMM - output = flashinfer_block_gemm( - input=input_tensor, - weight=weight_tensor, - scales_a=input_scale, - scales_b=weight_scale, - out_dtype=torch.bfloat16, - ) - - # Verify output properties - assert output.shape == (M, N), f"Expected shape {(M, N)}, got {output.shape}" - assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" - - # Compute reference based on dtype combination - if input_dtype == torch.float8_e4m3fn and weight_dtype == torch.float8_e4m3fn: - # W8A8: Compare against dequantized FP8 reference (tests kernel correctness) - block_n, block_k = block_size[0], block_size[1] - k_tiles = (K + block_k - 1) // block_k - n_tiles = (N + block_n - 1) // block_n - - input_dequant = torch.zeros_like(input_bf16) - for i in range(M): - for k_tile in range(k_tiles): - start, end = k_tile * block_k, min((k_tile + 1) * block_k, K) - input_dequant[i, start:end] = ( - input_tensor[i, start:end].to(torch.bfloat16) - * input_scale[i, k_tile] - ) - - weight_dequant = torch.zeros_like(weight_bf16) - for j in range(N): - for k_tile in range(k_tiles): - start, end = k_tile * block_k, min((k_tile + 1) * block_k, K) - weight_dequant[j, start:end] = ( - weight_tensor[j, start:end].to(torch.bfloat16) - * weight_scale[j // block_n, k_tile] - ) - - reference = torch.matmul(input_dequant, weight_dequant.T) - - # W8A8: Use TRT-LLM's calc_diff metric with strict threshold - out_fp64 = output.to(torch.float64) - ref_fp64 = reference.to(torch.float64) - denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum() - sim = 2 * (out_fp64 * ref_fp64).sum() / denominator - diff = 1 - sim - - # W8A8 achieves very high accuracy: diff < 0.001 (99.9% similarity) - assert diff < 0.001, ( - f"W8A8 similarity difference {diff:.6f} too high (expected < 0.001, similarity: {sim:.6f})" - ) - else: - # BF16+BF16 or BF16+FP8: Compare against original BF16 reference - reference = torch.matmul(input_bf16, weight_bf16.T) - - out_fp64 = output.to(torch.float64) - ref_fp64 = reference.to(torch.float64) - denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum() - sim = 2 * (out_fp64 * ref_fp64).sum() / denominator - diff = 1 - sim - - if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16: - # BF16+BF16: Highest accuracy (internal quantization) - threshold = 0.001 - threshold_desc = "0.1%" - elif input_dtype == 
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py b/vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py
deleted file mode 100644
index 04fb00ef5c8b8..0000000000000
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_block_gemm.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-FlashInfer FP8 Block-Scale GEMM wrapper for vLLM.
-
-This module provides a thin wrapper around FlashInfer's FP8 block-scale GEMM
-implementation, which uses TensorRT-LLM's optimized kernels for NVIDIA Hopper (SM90+).
-"""
-
-import torch
-
-
-def flashinfer_block_gemm(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    scales_a: torch.Tensor | None,
-    scales_b: torch.Tensor,
-    out_dtype: torch.dtype,
-) -> torch.Tensor:
-    """
-    Wrapper for FlashInfer's FP8 block-scale GEMM.
-
-    Computes: output = (input @ weight.T) with per-block scaling for quantization.
-
-    Supports three modes:
-    1. BF16 + BF16 → BF16: Both inputs BF16, internal quantization (scales_a=None, scales_b used internally)
-    2. BF16 + FP8 → BF16: Weight-only quantization (scales_a=None, scales_b for weight)
-    3. FP8 + FP8 → BF16: W8A8 full quantization (scales_a for input, scales_b for weight)
-
-    Args:
-        input: Input tensor (M, K) - BF16 or FP8 e4m3
-        weight: Weight tensor (N, K) - BF16 or FP8 e4m3
-        scales_a: Input scales (M, K//block_k) or None - FP32
-            None: input is BF16 (will be quantized internally for BF16+BF16 or left as-is for BF16+FP8)
-            Provided: input is pre-quantized FP8 (W8A8 mode)
-        scales_b: Weight scales (N//block_n, K//block_k) - FP32
-        out_dtype: Output dtype (typically torch.bfloat16)
-
-    Returns:
-        output: Result tensor (M, N) in out_dtype
-
-    Note:
-        - Requires SM90+ GPU (NVIDIA Hopper)
-        - Uses TensorRT-LLM's optimized CUTLASS kernels via FlashInfer
-        - For M < 32, automatically uses SwapAB kernel optimization
-    """
-    from flashinfer.gemm import fp8_blockscale_gemm_swapab
-
-    return fp8_blockscale_gemm_swapab(
-        input=input,
-        weight=weight,
-        input_scale=scales_a,
-        weight_scale=scales_b,
-        out_dtype=out_dtype,
-    )
-
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index f893d7e53cd16..5ebb7395a3cc7 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -35,7 +35,11 @@ from vllm.utils.deep_gemm import (
     should_use_deepgemm_for_fp8_linear,
     transform_sf_into_required_layout,
 )
-from vllm.utils.flashinfer import has_flashinfer_block_gemm
+from vllm.utils.flashinfer import (
+    flashinfer_fp8_blockscale_gemm,
+    is_flashinfer_fp8_blockscale_gemm_supported,
+    should_use_flashinfer_for_block_scale_fp8_linear,
+)
 from vllm.utils.torch_utils import direct_register_custom_op
 
 logger = init_logger(__name__)
@@ -227,6 +231,76 @@ direct_register_custom_op(
 )
 
 
+def _flashinfer_fp8_blockscale_gemm_impl(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
+    use_deep_gemm_e8m0: bool,
+) -> torch.Tensor:
+    def use_flashinfer(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        return flashinfer_fp8_blockscale_gemm(
+            input=input,
+            weight=weight,
+            weight_scale=weight_scale,
+            out_dtype=torch.bfloat16,
+        )
+
+    def use_deepgemm(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        q_input, input_scale = per_token_group_quant_fp8(
+            input,
+            group_size=group_size,
+            column_major_scales=True,
+            use_ue8m0=use_deep_gemm_e8m0,
+        )
+        output = torch.empty(
+            (q_input.shape[0], weight.shape[0]),
+            dtype=torch.bfloat16,
+            device=q_input.device,
+        )
+        fp8_gemm_nt(
+            (q_input, input_scale),
+            (weight, weight_scale),
+            output,
+            is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
+        )
+        return output
+
+    condition = input.shape[0] < 32
+
+    # Pass all required variables through operands
+    return torch.cond(
+        condition, use_flashinfer, use_deepgemm, (input, weight, weight_scale)
+    )
+
+
+def _flashinfer_fp8_blockscale_gemm_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
+    use_deep_gemm_e8m0: bool,
+) -> torch.Tensor:
+    return torch.empty(
+        input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
+    )
+
+
+direct_register_custom_op(
+    "flashinfer_fp8_blockscale_gemm",
+    _flashinfer_fp8_blockscale_gemm_impl,
+    fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
+)
+
+
 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
 class W8A8BlockFp8LinearOp:
@@ -247,6 +321,7 @@ class W8A8BlockFp8LinearOp:
         self.is_deep_gemm_supported = is_deep_gemm_supported()
         self.is_hopper = current_platform.is_device_capability(90)
         self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
+        self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
 
         # Get the correct blockscale mul and input quant operations.
         # We can't use _dispatch_w8a8_blockscale_op to figure out if we want
@@ -282,7 +357,12 @@ class W8A8BlockFp8LinearOp:
         output_shape = [*input.shape[:-1], weight.shape[0]]
         output_dtype = input.dtype
 
-        if should_use_deepgemm_for_fp8_linear(
+        if should_use_flashinfer_for_block_scale_fp8_linear(
+            self.is_flashinfer_supported, output_dtype, input_2d, weight
+        ):
+            output = self._run_flashinfer(input_2d, weight, weight_scale)
+
+        elif should_use_deepgemm_for_fp8_linear(
             output_dtype, weight, self.is_deep_gemm_supported
         ):
             output = self._run_deepgemm(input_2d, weight, weight_scale)
@@ -415,7 +495,6 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
-        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """
         Run FlashInfer FP8 block-scale GEMM.
@@ -423,23 +502,16 @@ class W8A8BlockFp8LinearOp:
         This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
         and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
         """
-        from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import (
-            flashinfer_block_gemm,
-        )
-
-        # Quantize input dynamically if not pre-quantized (same as CUTLASS)
-        assert input_scale is None
-        assert self.input_quant_op is not None
-        q_input, input_scale = self.input_quant_op(input_2d)
-
-        # Now call FlashInfer with FP8 input + FP8 weight (W8A8)
-        return flashinfer_block_gemm(
-            input=q_input,  # FP8 quantized input
+        # Call FlashInfer with a BF16 input and an FP8 weight; the input is
+        # quantized inside the FlashInfer kernel (W8A8)
+        output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
+            input=input_2d,  # BF16 input
             weight=weight,  # FP8 weight
-            scales_a=input_scale,  # Input scales (computed dynamically)
-            scales_b=weight_scale,  # Weight scales
-            out_dtype=input_2d.dtype,
+            weight_scale=weight_scale,  # Weight scales
+            group_size=self.act_quant_group_shape.col,
+            use_deep_gemm_e8m0=self.use_deep_gemm_e8m0,
         )
+        return output
 
     def _dispatch_w8a8_blockscale_op(
         self,
@@ -457,22 +529,6 @@ class W8A8BlockFp8LinearOp:
         ],
         QuantFP8 | None,
     ]:
-        # Prefer FlashInfer on SM90+ if available (Hopper optimized)
-        if (
-            has_flashinfer_block_gemm()
-            and envs.VLLM_USE_FLASHINFER_FP8_LINEAR
-            and not use_aiter_and_is_supported
-        ):
-            logger.info_once("Using FlashInfer FP8 block-scale GEMM for linear layers")
-            return self._run_flashinfer, (
-                QuantFP8(
-                    False,
-                    self.act_quant_group_shape,
-                    column_major_scales=False,
-                    use_ue8m0=False,
-                )
-            )
-
         if use_cutlass:
             return self._run_cutlass, (
                 QuantFP8(
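Reviewer note: the custom op above routes small-M GEMMs to FlashInfer and everything else to DeepGEMM via torch.cond, so both branches stay in the compiled graph. Below is a self-contained sketch of that dispatch pattern with stand-in branch bodies (the M < 32 cutoff matches the SwapAB note in the deleted wrapper's docstring); it should run on any recent PyTorch without a GPU.

```python
import torch


def flashinfer_branch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the FlashInfer block-scale GEMM (taken when M < 32).
    return x @ w.T


def deepgemm_branch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for quantize-then-fp8_gemm_nt (taken when M >= 32).
    return x @ w.T


def dispatch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Branch functions receive their inputs through the operands tuple,
    # exactly as _flashinfer_fp8_blockscale_gemm_impl does above.
    return torch.cond(x.shape[0] < 32, flashinfer_branch, deepgemm_branch, (x, w))


x = torch.randn(8, 128)
w = torch.randn(64, 128)
assert dispatch(x, w).shape == (8, 64)
```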
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index df8cc2468c605..1b01c39cc68a5 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -540,27 +540,52 @@ def flashinfer_scaled_fp8_mm(
     return output
 
 
+flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper(
+    "flashinfer.gemm", "fp8_blockscale_gemm_sm90"
+)
+
+
 @functools.cache
-def has_flashinfer_block_gemm() -> bool:
-    """Return `True` if FlashInfer FP8 block-scale GEMM is available."""
-    if not has_flashinfer():
-        return False
-    if not current_platform.is_cuda():
-        return False
-    # Only SM90+ (Hopper) supports this kernel
-    if not current_platform.is_device_capability(90):
+def has_flashinfer_fp8_blockscale_gemm() -> bool:
+    """Return `True` if FlashInfer block-scale FP8 GEMM is available."""
+    return has_flashinfer() and hasattr(
+        _get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90"
+    )
+
+
+@functools.cache
+def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
+    """Return `True` if FlashInfer block-scale FP8 GEMM is supported."""
+    return envs.VLLM_USE_FLASHINFER_FP8_LINEAR and has_flashinfer_fp8_blockscale_gemm()
+
+
+def should_use_flashinfer_for_block_scale_fp8_linear(
+    is_flashinfer_supported: bool,
+    output_dtype: torch.dtype,
+    input: torch.Tensor,
+    weight: torch.Tensor,
+) -> bool:
+    if not is_flashinfer_supported:
         return False
-    try:
-        import flashinfer
+    # Verify the N/K dims requirements (shared with DeepGEMM)
+    # NOTE: Also synchronized with the test_w8a8_block_fp8_deep_gemm_matmul
+    # test inside kernels/quantization/test_block_fp8.py
+    N_MULTIPLE = 64
+    K_MULTIPLE = 128
 
-        # Check if the module has the required binding
-        return hasattr(flashinfer, "Fp8BlockScaleGemmRunner")
-    except (ImportError, AttributeError):
-        logger.debug_once(
-            "FlashInfer block-scale GEMM not available: module or binding not found"
-        )
-        return False
+    weight_dtype = weight.dtype
+    input_dtype = input.dtype
+
+    should_use_flashinfer = (
+        output_dtype == torch.bfloat16
+        and input_dtype == torch.bfloat16
+        and weight_dtype == torch.float8_e4m3fn
+        and weight.shape[0] % N_MULTIPLE == 0
+        and weight.shape[1] % K_MULTIPLE == 0
+    )
+
+    return should_use_flashinfer
 
 
 __all__ = [
@@ -579,11 +604,14 @@ __all__ = [
     "has_flashinfer_all2all",
     "has_flashinfer_cutlass_fused_moe",
     "has_flashinfer_cutedsl_grouped_gemm_nt_masked",
+    "has_flashinfer_fp8_blockscale_gemm",
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
     "can_use_trtllm_attention",
     "use_trtllm_attention",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
-    "has_flashinfer_block_gemm",
+    "flashinfer_fp8_blockscale_gemm",
+    "should_use_flashinfer_for_block_scale_fp8_linear",
+    "is_flashinfer_fp8_blockscale_gemm_supported",
 ]
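Reviewer note: a sketch of how the three new helpers compose at dispatch time. The dtype and alignment conditions are copied from should_use_flashinfer_for_block_scale_fp8_linear above; a CUDA device and VLLM_USE_FLASHINFER_FP8_LINEAR=1 are assumed.

```python
import torch

from vllm.utils.flashinfer import (
    is_flashinfer_fp8_blockscale_gemm_supported,  # binding present + env flag on
    should_use_flashinfer_for_block_scale_fp8_linear,  # per-call shape/dtype gate
)

supported = is_flashinfer_fp8_blockscale_gemm_supported()

x = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")  # K=512, 128-aligned
w = torch.zeros(256, 512, device="cuda").to(torch.float8_e4m3fn)  # N=256, 64-aligned

# True only when the gate is on, input/output are BF16, the weight is FP8,
# and N/K meet the 64/128 multiples; otherwise W8A8BlockFp8LinearOp.apply
# falls through to the DeepGEMM / CUTLASS / Triton paths.
if should_use_flashinfer_for_block_scale_fp8_linear(supported, torch.bfloat16, x, w):
    print("FlashInfer block-scale FP8 path selected")
```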