diff --git a/benchmarks/kernels/bench_nvfp4_gemm.py b/benchmarks/kernels/bench_nvfp4_gemm.py
index 9e832c9faa8e8..6b19eb113f3e7 100644
--- a/benchmarks/kernels/bench_nvfp4_gemm.py
+++ b/benchmarks/kernels/bench_nvfp4_gemm.py
@@ -3,6 +3,7 @@
 import argparse
 import copy
 import itertools
+import os
 
 import torch
 from weight_shapes import WEIGHT_SHAPES
@@ -23,21 +24,45 @@ PROVIDER_CFGS = {
     "torch-bf16": dict(enabled=True),
     "nvfp4": dict(no_a_quant=False, enabled=True),
     "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
+    "fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True),
+    "fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True),
 }
 
+_needs_fbgemm = any(
+    v.get("fbgemm", False) for v in PROVIDER_CFGS.values() if v.get("enabled", False)
+)
+if _needs_fbgemm:
+    try:
+        from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
+            triton_scale_nvfp4_quant,
+        )
+    except ImportError:
+        print(
+            "WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. "
+            "These providers will be skipped. Please install fbgemm_gpu with: "
+            "'pip install fbgemm-gpu-genai' to run them."
+        )
+        # Disable FBGEMM providers so the benchmark can run.
+        for cfg in PROVIDER_CFGS.values():
+            if cfg.get("fbgemm"):
+                cfg["enabled"] = False
+
 _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
 
 
-def _quant_weight_nvfp4(b: torch.Tensor, device: str):
+def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg):
     # Compute global scale for weight
     b_amax = torch.abs(b).max().to(torch.float32)
     b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
-    b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
+    if "fbgemm" in cfg and cfg["fbgemm"]:
+        b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale)
+    else:
+        b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
     return b_fp4, scale_b_fp4, b_global_scale
 
 
 def build_nvfp4_runner(cfg, a, b, dtype, device):
-    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
+    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg)
 
     # Compute global scale for activation
     # NOTE: This is generally provided ahead-of-time by the model checkpoint.
@@ -46,6 +71,35 @@ def build_nvfp4_runner(cfg, a, b, dtype, device):
 
     # Alpha for the GEMM operation
    alpha = 1.0 / (a_global_scale * b_global_scale)
+    if "fbgemm" in cfg and cfg["fbgemm"]:
+        if cfg["no_a_quant"]:
+            a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
+
+            def run():
+                return torch.ops.fbgemm.f4f4bf16(
+                    a_fp4,
+                    b_fp4,
+                    scale_a_fp4,
+                    scale_b_fp4,
+                    global_scale=alpha,
+                    use_mx=False,
+                )
+
+            return run
+        else:
+
+            def run():
+                a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
+                return torch.ops.fbgemm.f4f4bf16(
+                    a_fp4,
+                    b_fp4,
+                    scale_a_fp4,
+                    scale_b_fp4,
+                    global_scale=alpha,
+                    use_mx=False,
+                )
+
+            return run
 
     if cfg["no_a_quant"]:
         # Pre-quantize activation
@@ -130,10 +184,13 @@ if __name__ == "__main__":
     for K, N, model in prepare_shapes(args):
         print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
 
+        save_dir = f"bench_nvfp4_res_n{N}_k{K}"
+        os.makedirs(save_dir, exist_ok=True)
+
         benchmark.run(
             print_data=True,
             show_plots=True,
-            save_path=f"bench_nvfp4_res_n{N}_k{K}",
+            save_path=save_dir,
             N=N,
             K=K,
         )
diff --git a/vllm/envs.py b/vllm/envs.py
index 5d622c0675290..b8af770d05f60 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -201,6 +201,7 @@ if TYPE_CHECKING:
     VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
     VLLM_USE_NCCL_SYMM_MEM: bool = False
     VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
+    VLLM_USE_FBGEMM: bool = False
 
 
 def get_default_cache_root():
@@ -1452,7 +1453,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # NCCL header path
     "VLLM_NCCL_INCLUDE_PATH": lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
-
+    # Flag to enable FBGEMM kernels during model execution
+    "VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
@@ -1548,6 +1550,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
         "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
         "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
+        "VLLM_USE_FBGEMM",
     ]
     for key in environment_variables_to_hash:
         # if this goes out of sync with environment_variables,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
index dedd681f15ded..d472427756d46 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -30,8 +30,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
         if envs.VLLM_USE_TRTLLM_FP4_GEMM:
             assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
             self.backend = "flashinfer-trtllm"
+            logger.info_once("Using flashinfer-trtllm for FP4")
+        elif envs.VLLM_USE_FBGEMM:
+            self.backend = "fbgemm"
+            try:
+                import fbgemm_gpu  # noqa: F401
+            except ImportError as exc:
+                raise ImportError(
+                    "Backend fbgemm requires the fbgemm.f4f4bf16 operator. "
+                    "Please install it with: pip install fbgemm-gpu-genai"
+                ) from exc
+            logger.info_once("Using FBGEMM-GPU-GENAI for FP4")
         elif has_flashinfer():
             self.backend = "flashinfer-cutlass"
+            logger.info_once("Using flashinfer-cutlass for FP4")
         else:
             self.backend = "cutlass"
         self.group_size = 16
@@ -116,6 +128,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
             layer.weight_packed = Parameter(weight, requires_grad=False)
         else:
             swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
+            if self.backend == "fbgemm":
+                swizzled_weight_scale = swizzled_weight_scale.view(-1).view(
+                    torch.uint8)
             layer.weight_scale = Parameter(swizzled_weight_scale,
                                            requires_grad=False)
             layer.weight_packed = Parameter(layer.weight_packed.data,
@@ -153,6 +168,15 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
             out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
         elif self.backend == "flashinfer-cutlass":
             out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
+        elif self.backend == "fbgemm":
+            out = torch.ops.fbgemm.f4f4bf16(
+                x_fp4,
+                layer.weight_packed,
+                x_blockscale.view(-1).view(torch.uint8),
+                layer.weight_scale,
+                layer.alpha,
+                use_mx=False,
+            ).to(output_dtype)
         else:
             out = cutlass_scaled_fp4_mm(*mm_args)
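For context, here is a minimal standalone sketch of the FBGEMM NVFP4 path the changes above wire up: both operands are quantized with `triton_scale_nvfp4_quant` using the same global-scale formula as `_quant_weight_nvfp4`, and the GEMM runs through `torch.ops.fbgemm.f4f4bf16`. The tensor shapes, dtype, and constant definitions below are illustrative assumptions rather than values from the PR, and the snippet requires `fbgemm-gpu-genai` on an NVFP4-capable GPU. In vLLM itself, setting `VLLM_USE_FBGEMM=1` selects the new `fbgemm` backend in `CompressedTensorsW4A4Fp4`, which issues the equivalent `f4f4bf16` call when the layer is applied.

```python
# Minimal sketch of the FBGEMM NVFP4 path exercised above (not part of the diff).
# Assumes fbgemm-gpu-genai is installed and the GPU supports NVFP4 GEMMs;
# shapes, dtype, and the constants below are illustrative.
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    triton_scale_nvfp4_quant,
)

FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3 value
FLOAT4_E2M1_MAX = 6.0    # largest finite float4_e2m1 value


def nvfp4_global_scale(t: torch.Tensor) -> torch.Tensor:
    # Same per-tensor global-scale formula used by _quant_weight_nvfp4 above.
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.abs(t).max().to(torch.float32)


device = "cuda"
a = torch.randn(128, 4096, dtype=torch.bfloat16, device=device)   # activations
b = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)  # weight

a_global_scale = nvfp4_global_scale(a)
b_global_scale = nvfp4_global_scale(b)
alpha = 1.0 / (a_global_scale * b_global_scale)

# Quantize both operands to NVFP4 plus blockwise scales.
a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale)

# FP4 x FP4 -> BF16 GEMM, the same call the benchmark's fbgemm providers make.
out = torch.ops.fbgemm.f4f4bf16(
    a_fp4,
    b_fp4,
    scale_a_fp4,
    scale_b_fp4,
    global_scale=alpha,
    use_mx=False,
)
print(out.shape, out.dtype)
```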