[NVIDIA] Add support for cudnn fp4 gemm via flashinfer (#26107)

Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Authored by Kaixi Hou on 2025-10-15 10:53:00 -07:00; committed by GitHub
parent a1063628a4
commit de92d916fe
3 changed files with 57 additions and 38 deletions
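The commit replaces the boolean VLLM_USE_TRTLLM_FP4_GEMM toggle with a VLLM_NVFP4_GEMM_BACKEND choice that also exposes FlashInfer's cudnn FP4 GEMM path. A minimal, hedged usage sketch of how a user might opt into the new backend; the model id is only a placeholder for an NVFP4-quantized checkpoint, not something this commit prescribes:

# Hedged sketch: the env var must be set before vllm reads its environment.
import os

os.environ["VLLM_NVFP4_GEMM_BACKEND"] = "flashinfer-cudnn"

from vllm import LLM

# Placeholder checkpoint; any NVFP4 (FP4 weights + FP8 block scales) model applies.
llm = LLM(model="my-org/my-nvfp4-model")
print(llm.generate("Hello"))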

vllm/envs.py

@@ -191,6 +191,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_ATTENTION: str | None = None
+    VLLM_NVFP4_GEMM_BACKEND: str | None = None
     VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
@@ -1292,11 +1293,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # If set, it means we pre-downloaded cubin files and flashinfer will
     # read the cubin files directly.
     "VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
-    # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
-    # Otherwise, uses the first available of: flashinfer cutlass GEMM,
-    # vllm cutlass GEMM, marlin GEMM.
-    "VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool(
-        int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))
+    # Supported options:
+    # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
+    # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
+    # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
+    # - <none>: automatically pick an available backend
+    "VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
+        "VLLM_NVFP4_GEMM_BACKEND",
+        None,
+        ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
     ),
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
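For readers unfamiliar with the env_with_choices helper used above: it returns a lazy reader that validates the variable against an allowed set. A rough, hedged approximation of its behavior (the real helper lives in vllm/envs.py and may differ in details such as case handling and error text):

import os
from typing import Callable, Optional

def env_with_choices(
    env_name: str,
    default: Optional[str],
    choices: list[str],
) -> Callable[[], Optional[str]]:
    """Sketch: build a lazy reader that validates the env value against choices."""

    def _read() -> Optional[str]:
        value = os.getenv(env_name)
        if value is None:
            return default
        if value not in choices:
            raise ValueError(
                f"{env_name}={value!r} is invalid; expected one of {choices}"
            )
        return value

    return _read

# e.g. with VLLM_NVFP4_GEMM_BACKEND=flashinfer-cudnn, reader() == "flashinfer-cudnn"
reader = env_with_choices(
    "VLLM_NVFP4_GEMM_BACKEND",
    None,
    ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
)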
@@ -1492,7 +1497,6 @@ def compute_hash() -> str:
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
         "VLLM_USE_DEEP_GEMM_E8M0",
-        "VLLM_USE_TRTLLM_FP4_GEMM",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP16",
         "VLLM_USE_FLASHINFER_MOE_FP8",
@@ -1524,6 +1528,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
         "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
         "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
+        "VLLM_NVFP4_GEMM_BACKEND",
         "VLLM_USE_FBGEMM",
     ]
     for key in environment_variables_to_hash:
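Swapping VLLM_USE_TRTLLM_FP4_GEMM for VLLM_NVFP4_GEMM_BACKEND in the hashed-variable list means a backend change alters the digest that, as I understand it, keys vLLM's cached compilation artifacts. A hedged sketch of the general pattern, not vLLM's exact compute_hash implementation:

import hashlib
import os

def compute_env_hash(keys: list[str]) -> str:
    # Fold the current values of selected env vars into a stable digest so a
    # different backend choice produces a different cache key.
    hasher = hashlib.sha256()
    for key in sorted(keys):
        hasher.update(f"{key}={os.getenv(key, '')}".encode())
    return hasher.hexdigest()

print(compute_env_hash(["VLLM_NVFP4_GEMM_BACKEND", "VLLM_USE_DEEP_GEMM"]))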

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

@@ -14,7 +14,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
 from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
     run_nvfp4_emulations,
 )
-from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported,
+    swizzle_blockscale,
+)
 from vllm.model_executor.parameter import (
     GroupQuantScaleParameter,
     ModelWeightParameter,
@@ -29,10 +32,12 @@ __all__ = ["CompressedTensorsW4A4Fp4"]

 class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
     def __init__(self):
-        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
-            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
-            self.backend = "flashinfer-trtllm"
-            logger.info_once("Using flashinfer-trtllm for FP4")
+        self.backend = "none"
+        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
+            if has_flashinfer():
+                self.backend = "flashinfer-cutlass"
+            elif cutlass_fp4_supported():
+                self.backend = "cutlass"
         elif envs.VLLM_USE_FBGEMM:
             self.backend = "fbgemm"
             try:
@@ -42,12 +47,17 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                     "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
                     "Please install with: pip install fbgemm-gpu-genai"
                 ) from exc
-            logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
-        elif has_flashinfer():
-            self.backend = "flashinfer-cutlass"
-            logger.info_once("Using flashinfer-cutlass for FP4")
-        else:
-            self.backend = "cutlass"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
+            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
+            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+
+        if self.backend == "none":
+            raise ValueError(
+                "No valid NVFP4 GEMM backend found. "
+                "Please check your platform capability."
+            )
+        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

         self.group_size = 16
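To make the new selection rule easier to follow, here is a standalone, hedged restatement of what the rewritten __init__ does, with the capability checks passed in as plain booleans so it can be exercised without GPUs (the fbgemm branch is omitted for brevity):

def resolve_nvfp4_backend(
    requested: str | None,
    *,
    flashinfer_available: bool,
    cutlass_fp4_ok: bool,
) -> str:
    # Sketch of the selection rule above; probes replace has_flashinfer() etc.
    backend = "none"
    if requested is None:
        # Auto mode: prefer FlashInfer's cutlass kernel, then vLLM's own cutlass.
        if flashinfer_available:
            backend = "flashinfer-cutlass"
        elif cutlass_fp4_ok:
            backend = "cutlass"
    elif requested.startswith("flashinfer-"):
        # Explicit flashinfer-{cudnn,trtllm,cutlass} request requires FlashInfer.
        backend = requested
        assert flashinfer_available, f"FlashInfer is required for {backend}"
    if backend == "none":
        raise ValueError("No valid NVFP4 GEMM backend found.")
    return backend


assert (
    resolve_nvfp4_backend(
        "flashinfer-cudnn", flashinfer_available=True, cutlass_fp4_ok=False
    )
    == "flashinfer-cudnn"
)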
@classmethod @classmethod
@@ -184,10 +194,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
             layer.alpha,
             output_dtype,
         )
-        if self.backend == "flashinfer-trtllm":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
-        elif self.backend == "flashinfer-cutlass":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
+        if self.backend.startswith("flashinfer-"):
+            backend_name = self.backend[len("flashinfer-") :]
+            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
         elif self.backend == "fbgemm":
             out = torch.ops.fbgemm.f4f4bf16(
                 x_fp4,
@@ -198,6 +207,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                 use_mx=False,
             ).to(output_dtype)
         else:
+            assert self.backend == "cutlass"
             out = cutlass_scaled_fp4_mm(*mm_args)

         if bias is not None:
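The dispatch change collapses the per-backend branches into one: the "flashinfer-" prefix is stripped and the remainder ("cudnn", "trtllm", or "cutlass") is forwarded as the backend argument of flashinfer_scaled_fp4_mm. A toy illustration of just that string handling:

def flashinfer_backend_name(backend: str) -> str:
    # "flashinfer-cudnn" -> "cudnn", "flashinfer-trtllm" -> "trtllm", etc.
    prefix = "flashinfer-"
    assert backend.startswith(prefix), backend
    return backend[len(prefix):]

assert flashinfer_backend_name("flashinfer-cudnn") == "cudnn"
assert flashinfer_backend_name("flashinfer-trtllm") == "trtllm"
assert flashinfer_backend_name("flashinfer-cutlass") == "cutlass"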

vllm/model_executor/layers/quantization/modelopt.py

@@ -926,22 +926,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         self.quant_config = quant_config
-        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
-            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
-            self.backend = "flashinfer-trtllm"
-        elif has_flashinfer():
-            self.backend = "flashinfer-cutlass"
-        elif cutlass_fp4_supported():
-            self.backend = "cutlass"
-        elif is_fp4_marlin_supported():
-            self.backend = "marlin"
-        else:
+        self.backend = "none"
+        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
+            if has_flashinfer():
+                self.backend = "flashinfer-cutlass"
+            elif cutlass_fp4_supported():
+                self.backend = "cutlass"
+            elif is_fp4_marlin_supported():
+                self.backend = "marlin"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
+            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
+            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+
+        if self.backend == "none":
             raise ValueError(
-                "Current platform does not support NVFP4"
-                " quantization. Please use Blackwell and"
-                " above."
+                "No valid NVFP4 GEMM backend found. "
+                "Please check your platform capability."
             )
+        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

     def create_weights(
         self,
         layer: torch.nn.Module,
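ModelOptNvFp4LinearMethod follows the same pattern as the compressed-tensors scheme but with a longer auto-fallback chain (marlin is also probed). A hedged sketch of that priority order, with the capability probes passed in as callables (the lambdas below are placeholders; the real checks are has_flashinfer(), cutlass_fp4_supported(), and is_fp4_marlin_supported()):

from collections.abc import Callable

def auto_select_nvfp4_backend(
    probes: list[tuple[str, Callable[[], bool]]],
) -> str | None:
    # Return the first backend whose capability probe succeeds, else None.
    for name, available in probes:
        if available():
            return name
    return None

priority = [
    ("flashinfer-cutlass", lambda: False),
    ("cutlass", lambda: False),
    ("marlin", lambda: True),
]
assert auto_select_nvfp4_backend(priority) == "marlin"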
@@ -1109,11 +1113,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
             layer.alpha,
             output_dtype,
         )
-        if self.backend == "flashinfer-trtllm":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
-        elif self.backend == "flashinfer-cutlass":
-            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
+        if self.backend.startswith("flashinfer-"):
+            backend_name = self.backend[len("flashinfer-") :]
+            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
         else:
+            assert self.backend == "cutlass"
             out = cutlass_scaled_fp4_mm(*mm_args)

         if bias is not None: