Enable Fbgemm NVFP4 on Dense models (#25609)

Signed-off-by: Saman Keon <samanamp@outlook.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Saman A. Pour 2025-09-24 21:12:53 -07:00 committed by yewentao256
parent 517a857166
commit 12c21d28c1
3 changed files with 89 additions and 5 deletions

View File

@ -3,6 +3,7 @@
import argparse
import copy
import itertools
import os
import torch
from weight_shapes import WEIGHT_SHAPES
@ -23,21 +24,45 @@ PROVIDER_CFGS = {
"torch-bf16": dict(enabled=True),
"nvfp4": dict(no_a_quant=False, enabled=True),
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
"fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True),
"fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True),
}
_needs_fbgemm = any(
v.get("fbgemm", False) for v in PROVIDER_CFGS.values() if v.get("enabled", False)
)
if _needs_fbgemm:
try:
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
triton_scale_nvfp4_quant,
)
except ImportError:
print(
"WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. "
"These providers will be skipped. Please install fbgemm_gpu with: "
"'pip install fbgemm-gpu-genai' to run them."
)
# Disable FBGEMM providers so the benchmark can run.
for cfg in PROVIDER_CFGS.values():
if cfg.get("fbgemm"):
cfg["enabled"] = False
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg):
# Compute global scale for weight
b_amax = torch.abs(b).max().to(torch.float32)
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
if "fbgemm" in cfg and cfg["fbgemm"]:
b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale)
else:
b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
return b_fp4, scale_b_fp4, b_global_scale
def build_nvfp4_runner(cfg, a, b, dtype, device):
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg)
# Compute global scale for activation
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
@ -46,6 +71,35 @@ def build_nvfp4_runner(cfg, a, b, dtype, device):
# Alpha for the GEMM operation
alpha = 1.0 / (a_global_scale * b_global_scale)
if "fbgemm" in cfg and cfg["fbgemm"]:
if cfg["no_a_quant"]:
a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
def run():
return torch.ops.fbgemm.f4f4bf16(
a_fp4,
b_fp4,
scale_a_fp4,
scale_b_fp4,
global_scale=alpha,
use_mx=False,
)
return run
else:
def run():
a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
return torch.ops.fbgemm.f4f4bf16(
a_fp4,
b_fp4,
scale_a_fp4,
scale_b_fp4,
global_scale=alpha,
use_mx=False,
)
return run
if cfg["no_a_quant"]:
# Pre-quantize activation
@ -130,10 +184,13 @@ if __name__ == "__main__":
for K, N, model in prepare_shapes(args):
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
save_dir = f"bench_nvfp4_res_n{N}_k{K}"
os.makedirs(save_dir, exist_ok=True)
benchmark.run(
print_data=True,
show_plots=True,
save_path=f"bench_nvfp4_res_n{N}_k{K}",
save_path=save_dir,
N=N,
K=K,
)

View File

@ -201,6 +201,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
VLLM_USE_NCCL_SYMM_MEM: bool = False
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
VLLM_USE_FBGEMM: bool = False
def get_default_cache_root():
@ -1452,7 +1453,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# NCCL header path
"VLLM_NCCL_INCLUDE_PATH":
lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
# Flag to enable FBGemm kernels on model execution
"VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))),
}
# --8<-- [end:env-vars-definition]
@ -1548,6 +1550,7 @@ def compute_hash() -> str:
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
"VLLM_USE_FBGEMM",
]
for key in environment_variables_to_hash:
# if this goes out of sync with environment_variables,

View File

@ -30,8 +30,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
self.backend = "flashinfer-trtllm"
logger.info_once("Using flashinfer-trtllm for FP4")
elif envs.VLLM_USE_FBGEMM:
self.backend = "fbgemm"
try:
import fbgemm_gpu # noqa: F401
except ImportError as exc:
raise ImportError(
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai"
) from exc
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
elif has_flashinfer():
self.backend = "flashinfer-cutlass"
logger.info_once("Using flashinfer-cutlass for FP4")
else:
self.backend = "cutlass"
self.group_size = 16
@ -116,6 +128,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
if self.backend == "fbgemm":
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(
torch.uint8)
layer.weight_scale = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight_packed = Parameter(layer.weight_packed.data,
@ -153,6 +168,15 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
elif self.backend == "fbgemm":
out = torch.ops.fbgemm.f4f4bf16(
x_fp4,
layer.weight_packed,
x_blockscale.view(-1).view(torch.uint8),
layer.weight_scale,
layer.alpha,
use_mx=False,
).to(output_dtype)
else:
out = cutlass_scaled_fp4_mm(*mm_args)