From de92d916fe8a897b00a8adb0aab9ed9ec99f2b6c Mon Sep 17 00:00:00 2001
From: Kaixi Hou
Date: Wed, 15 Oct 2025 10:53:00 -0700
Subject: [PATCH] [NVIDIA] Add support for cudnn fp4 gemm via flashinfer
 (#26107)

Signed-off-by: kaixih
Signed-off-by: mgoin
Co-authored-by: mgoin
---
 vllm/envs.py                                  | 17 +++++---
 .../schemes/compressed_tensors_w4a4_nvfp4.py  | 40 ++++++++++++-------
 .../layers/quantization/modelopt.py           | 38 ++++++++++--------
 3 files changed, 57 insertions(+), 38 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index b5c7f325f670d..cb3dab51eff4d 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -191,6 +191,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_ATTENTION: str | None = None
+    VLLM_NVFP4_GEMM_BACKEND: str | None = None
     VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
@@ -1292,11 +1293,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # If set, it means we pre-downloaded cubin files and flashinfer will
     # read the cubin files directly.
     "VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
-    # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
-    # Otherwise, uses the first available of: flashinfer cutlass GEMM,
-    # vllm cutlass GEMM, marlin GEMM.
-    "VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool(
-        int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))
+    # Supported options:
+    # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
+    # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
+    # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
+    # - unset: automatically pick an available backend
+    "VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
+        "VLLM_NVFP4_GEMM_BACKEND",
+        None,
+        ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
     ),
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
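
Usage sketch (not part of the diff above): the new knob is read lazily through vllm.envs, so setting the process environment before the value is first accessed is enough. This assumes a vLLM build that already contains this patch; "flashinfer-cudnn" is just one of the three documented options.

    import os

    # Pick the cuDNN-backed FlashInfer FP4 GEMM explicitly; leaving the
    # variable unset keeps the automatic backend selection.
    os.environ["VLLM_NVFP4_GEMM_BACKEND"] = "flashinfer-cudnn"

    import vllm.envs as envs

    # The value is restricted to the options registered via env_with_choices
    # in vllm/envs.py.
    print(envs.VLLM_NVFP4_GEMM_BACKEND)  # -> "flashinfer-cudnn"
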
@@ -1492,7 +1497,6 @@ def compute_hash() -> str:
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
         "VLLM_USE_DEEP_GEMM_E8M0",
-        "VLLM_USE_TRTLLM_FP4_GEMM",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP16",
         "VLLM_USE_FLASHINFER_MOE_FP8",
@@ -1524,6 +1528,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
         "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
         "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
+        "VLLM_NVFP4_GEMM_BACKEND",
         "VLLM_USE_FBGEMM",
     ]
     for key in environment_variables_to_hash:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
index 192661c5b7ece..4127cd2d574bd 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -14,7 +14,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
 from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
     run_nvfp4_emulations,
 )
-from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported,
+    swizzle_blockscale,
+)
 from vllm.model_executor.parameter import (
     GroupQuantScaleParameter,
     ModelWeightParameter,
@@ -29,10 +32,12 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
 
 class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
     def __init__(self):
-        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
-            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
-            self.backend = "flashinfer-trtllm"
-            logger.info_once("Using flashinfer-trtllm for FP4")
+        self.backend = "none"
+        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
+            if has_flashinfer():
+                self.backend = "flashinfer-cutlass"
+            elif cutlass_fp4_supported():
+                self.backend = "cutlass"
         elif envs.VLLM_USE_FBGEMM:
             self.backend = "fbgemm"
             try:
@@ -42,12 +47,17 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                     "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
                     "Please install with: pip install fbgemm-gpu-genai"
                 ) from exc
-            logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
-        elif has_flashinfer():
-            self.backend = "flashinfer-cutlass"
-            logger.info_once("Using flashinfer-cutlass for FP4")
-        else:
-            self.backend = "cutlass"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
+            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
+            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+
+        if self.backend == "none":
+            raise ValueError(
+                "No valid NVFP4 GEMM backend found. "
+                "Please check your platform capability."
+            )
+
+        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
         self.group_size = 16
 
     @classmethod
@@ -184,10 +194,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
             layer.alpha,
             output_dtype,
         )
-        if self.backend == "flashinfer-trtllm":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
-        elif self.backend == "flashinfer-cutlass":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
+        if self.backend.startswith("flashinfer-"):
+            backend_name = self.backend[len("flashinfer-") :]
+            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
         elif self.backend == "fbgemm":
             out = torch.ops.fbgemm.f4f4bf16(
                 x_fp4,
@@ -198,6 +207,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                 use_mx=False,
             ).to(output_dtype)
         else:
+            assert self.backend == "cutlass"
             out = cutlass_scaled_fp4_mm(*mm_args)
 
         if bias is not None:
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 79bf8109b8fd2..41f82de4ff0a6 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -926,22 +926,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         self.quant_config = quant_config
-        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
-            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
-            self.backend = "flashinfer-trtllm"
-        elif has_flashinfer():
-            self.backend = "flashinfer-cutlass"
-        elif cutlass_fp4_supported():
-            self.backend = "cutlass"
-        elif is_fp4_marlin_supported():
-            self.backend = "marlin"
-        else:
+        self.backend = "none"
+        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
+            if has_flashinfer():
+                self.backend = "flashinfer-cutlass"
+            elif cutlass_fp4_supported():
+                self.backend = "cutlass"
+            elif is_fp4_marlin_supported():
+                self.backend = "marlin"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
+            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
+            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+
+        if self.backend == "none":
             raise ValueError(
-                "Current platform does not support NVFP4"
-                " quantization. Please use Blackwell and"
-                " above."
+                "No valid NVFP4 GEMM backend found. "
+                "Please check your platform capability."
             )
 
+        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -1109,11 +1113,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
             layer.alpha,
             output_dtype,
         )
-        if self.backend == "flashinfer-trtllm":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
-        elif self.backend == "flashinfer-cutlass":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
+        if self.backend.startswith("flashinfer-"):
+            backend_name = self.backend[len("flashinfer-") :]
+            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
         else:
+            assert self.backend == "cutlass"
             out = cutlass_scaled_fp4_mm(*mm_args)
 
         if bias is not None:
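
A condensed, self-contained sketch of the backend resolution that ModelOptNvFp4LinearMethod now performs (CompressedTensorsW4A4Fp4 follows the same pattern, with an additional fbgemm path and no marlin fallback). The function name resolve_nvfp4_backend and the boolean capability arguments are illustrative stand-ins for has_flashinfer(), cutlass_fp4_supported(), and is_fp4_marlin_supported(); this is not code from the patch.

    def resolve_nvfp4_backend(
        requested: str | None,
        flashinfer_available: bool,
        cutlass_ok: bool,
        marlin_ok: bool,
    ) -> str:
        backend = "none"
        if requested is None:
            # Unset: auto-pick, preferring flashinfer-cutlass, then vLLM
            # cutlass, then marlin.
            if flashinfer_available:
                backend = "flashinfer-cutlass"
            elif cutlass_ok:
                backend = "cutlass"
            elif marlin_ok:
                backend = "marlin"
        elif requested.startswith("flashinfer-"):
            # Explicit "flashinfer-cudnn", "flashinfer-trtllm", or
            # "flashinfer-cutlass" all require FlashInfer to be installed.
            assert flashinfer_available, f"FlashInfer is required for {requested}"
            backend = requested
        if backend == "none":
            raise ValueError("No valid NVFP4 GEMM backend found.")
        return backend

    # At matmul time the "flashinfer-" prefix is stripped and the remainder is
    # passed to flashinfer_scaled_fp4_mm(..., backend=...), which is how
    # "flashinfer-cudnn" reaches the new cuDNN path.
    assert resolve_nvfp4_backend("flashinfer-cudnn", True, True, True) == "flashinfer-cudnn"
    assert resolve_nvfp4_backend(None, False, True, False) == "cutlass"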