diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 01789abaa44e0..bd4a737ca6300 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -211,6 +211,10 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 +@pytest.mark.skipif( + current_platform.is_fp8_fnuz(), + reason="This platform supports e4m3fnuz, not e4m3fn.", +) @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), @@ -239,13 +243,6 @@ def test_w8a8_block_fp8_flashinfer_matmul(M, N, K, block_size, out_dtype, seed): Bs = Bs_fp8.to(torch.float32) ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - As_fp8 = get_col_major_tma_aligned_tensor(As_fp8) - - # Transpose earlier so that the testing will not trigger transposing kernels - - assert As_fp8.shape == (M, (K + 127) // 128), ( - f"{As_fp8.shape} != {(M, (K + 127) // 128)}" - ) out = flashinfer_fp8_blockscale_gemm( input=A_bf16, diff --git a/vllm/envs.py b/vllm/envs.py index 9a9140aa0270d..9f595df66b2c5 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -168,7 +168,7 @@ if TYPE_CHECKING: "relax", ] = "relax" VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True - VLLM_USE_FLASHINFER_FP8_LINEAR: bool = False + VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER: bool = False VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False @@ -1211,8 +1211,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ), # Allow use of FlashInfer FP8 block-scale GEMM for linear layers. # This uses TensorRT-LLM kernels and requires SM90+ (Hopper). 
- "VLLM_USE_FLASHINFER_FP8_LINEAR": lambda: bool( - int(os.getenv("VLLM_USE_FLASHINFER_FP8_LINEAR", "0")) + "VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool( + int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "0")) ), # Allow use of FlashInfer MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 5ebb7395a3cc7..73f4a793503d7 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -38,7 +38,7 @@ from vllm.utils.deep_gemm import ( from vllm.utils.flashinfer import ( flashinfer_fp8_blockscale_gemm, is_flashinfer_fp8_blockscale_gemm_supported, - should_use_flashinfer_for_block_scale_fp8_linear, + should_use_flashinfer_for_blockscale_fp8_gemm, ) from vllm.utils.torch_utils import direct_register_custom_op @@ -238,7 +238,7 @@ def _flashinfer_fp8_blockscale_gemm_impl( group_size: int, use_deep_gemm_e8m0: bool, ) -> torch.Tensor: - def use_flashinfer( + def use_flashinfer_deepgemm_swapAB( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, @@ -274,11 +274,18 @@ def _flashinfer_fp8_blockscale_gemm_impl( ) return output + # there is no benefit to using FlashInfer DeepGEMM for higher batch sizes since + # the swapAB optimization is only effective for small batch sizes. + # there is slight accuracy loss when using FlashInfer blockscale gemm for all batch + # sizes for DeepSeek-V3. 
condition = input.shape[0] < 32 - # Pass all required variables through operands + # torch.cond for torch compile compatibility return torch.cond( - condition, use_flashinfer, use_deepgemm, (input, weight, weight_scale) + condition, + use_flashinfer_deepgemm_swapAB, + use_deepgemm, + (input, weight, weight_scale), ) @@ -357,7 +364,7 @@ class W8A8BlockFp8LinearOp: output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_flashinfer_for_block_scale_fp8_linear( + if should_use_flashinfer_for_blockscale_fp8_gemm( self.is_flashinfer_supported, output_dtype, input_2d, weight ): output = self._run_flashinfer(input_2d, weight, weight_scale) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 1b01c39cc68a5..0804add2343de 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -548,18 +548,23 @@ flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper( @functools.cache def has_flashinfer_fp8_blockscale_gemm() -> bool: """Return `True` if FlashInfer block-scale FP8 GEMM is available.""" - return has_flashinfer() and hasattr( - _get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90" + return ( + has_flashinfer() + and current_platform.is_device_capability(90) + and hasattr(_get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90") ) @functools.cache def is_flashinfer_fp8_blockscale_gemm_supported() -> bool: """Return `True` if FlashInfer block-scale FP8 GEMM is supported.""" - return envs.VLLM_USE_FLASHINFER_FP8_LINEAR and has_flashinfer_fp8_blockscale_gemm() + return ( + envs.VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER + and has_flashinfer_fp8_blockscale_gemm() + ) -def should_use_flashinfer_for_block_scale_fp8_linear( +def should_use_flashinfer_for_blockscale_fp8_gemm( is_flashinfer_supported: bool, output_dtype: torch.dtype, input: torch.Tensor, @@ -612,6 +617,6 @@ __all__ = [ "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", "flashinfer_fp8_blockscale_gemm", - 
"should_use_flashinfer_for_block_scale_fp8_linear", + "should_use_flashinfer_for_blockscale_fp8_gemm", "is_flashinfer_fp8_blockscale_gemm_supported", ]