Enable Fbgemm NVFP4 on Dense models (#25609)
Signed-off-by: Saman Keon <samanamp@outlook.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent 517a857166
commit 12c21d28c1
@@ -3,6 +3,7 @@
 import argparse
 import copy
 import itertools
+import os
 
 import torch
 from weight_shapes import WEIGHT_SHAPES
@@ -23,21 +24,45 @@ PROVIDER_CFGS = {
     "torch-bf16": dict(enabled=True),
     "nvfp4": dict(no_a_quant=False, enabled=True),
     "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
+    "fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True),
+    "fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True),
 }
 
+_needs_fbgemm = any(
+    v.get("fbgemm", False) for v in PROVIDER_CFGS.values() if v.get("enabled", False)
+)
+if _needs_fbgemm:
+    try:
+        from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
+            triton_scale_nvfp4_quant,
+        )
+    except ImportError:
+        print(
+            "WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. "
+            "These providers will be skipped. Please install fbgemm_gpu with: "
+            "'pip install fbgemm-gpu-genai' to run them."
+        )
+        # Disable FBGEMM providers so the benchmark can run.
+        for cfg in PROVIDER_CFGS.values():
+            if cfg.get("fbgemm"):
+                cfg["enabled"] = False
+
 _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
 
 
-def _quant_weight_nvfp4(b: torch.Tensor, device: str):
+def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg):
     # Compute global scale for weight
     b_amax = torch.abs(b).max().to(torch.float32)
     b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
-    b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
+    if "fbgemm" in cfg and cfg["fbgemm"]:
+        b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale)
+    else:
+        b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
     return b_fp4, scale_b_fp4, b_global_scale
 
 
 def build_nvfp4_runner(cfg, a, b, dtype, device):
-    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
+    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg)
 
     # Compute global scale for activation
     # NOTE: This is generally provided ahead-of-time by the model checkpoint.
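For readers unfamiliar with NVFP4's two-level scaling, the snippet below works through the b_global_scale formula from the hunk above in isolation. It is a standalone illustration, not part of the change, and only assumes the constant values the benchmark imports (FLOAT8_E4M3_MAX = 448.0, FLOAT4_E2M1_MAX = 6.0):

import torch

# Largest normal FP8 E4M3 value and largest FP4 E2M1 value, matching the
# constants the benchmark imports.
FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0

b = torch.randn(4096, 4096, dtype=torch.bfloat16)
b_amax = torch.abs(b).max().to(torch.float32)

# Global scale chosen so that, after dividing by the largest representable
# per-block E4M3 scale, the tensor's absolute max still fits in E2M1 range.
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax

# Sanity check: the two-level scaling maps amax exactly onto the FP4 max.
assert torch.isclose(b_amax * b_global_scale / FLOAT8_E4M3_MAX,
                     torch.tensor(FLOAT4_E2M1_MAX))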
@@ -46,6 +71,35 @@ def build_nvfp4_runner(cfg, a, b, dtype, device):
 
     # Alpha for the GEMM operation
     alpha = 1.0 / (a_global_scale * b_global_scale)
+    if "fbgemm" in cfg and cfg["fbgemm"]:
+        if cfg["no_a_quant"]:
+            a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
+
+            def run():
+                return torch.ops.fbgemm.f4f4bf16(
+                    a_fp4,
+                    b_fp4,
+                    scale_a_fp4,
+                    scale_b_fp4,
+                    global_scale=alpha,
+                    use_mx=False,
+                )
+
+            return run
+        else:
+
+            def run():
+                a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
+                return torch.ops.fbgemm.f4f4bf16(
+                    a_fp4,
+                    b_fp4,
+                    scale_a_fp4,
+                    scale_b_fp4,
+                    global_scale=alpha,
+                    use_mx=False,
+                )
+
+            return run
 
     if cfg["no_a_quant"]:
         # Pre-quantize activation
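The two fbgemm branches above differ only in where activation quantization happens relative to the timed closure. The sketch below restates that pattern with placeholder callables; build_runner, quantize, and gemm are illustrative names standing in for build_nvfp4_runner, triton_scale_nvfp4_quant, and torch.ops.fbgemm.f4f4bf16, not benchmark code:

def build_runner(a, b_q, no_a_quant, quantize, gemm):
    if no_a_quant:
        # "-noquant" provider: the activation is quantized once, outside the
        # closure, so only the GEMM itself lands in the measured region.
        a_q = quantize(a)

        def run():
            return gemm(a_q, b_q)
    else:
        # Default provider: quantization runs inside the closure, so its cost
        # is included in every timed iteration.
        def run():
            return gemm(quantize(a), b_q)

    return run

# Trivial stand-ins just to show the shape of the pattern:
run = build_runner(a=3.0, b_q=2.0, no_a_quant=True,
                   quantize=lambda x: x, gemm=lambda x, y: x * y)
print(run())  # 6.0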
@@ -130,10 +184,13 @@ if __name__ == "__main__":
 
     for K, N, model in prepare_shapes(args):
         print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
+        save_dir = f"bench_nvfp4_res_n{N}_k{K}"
+        os.makedirs(save_dir, exist_ok=True)
+
         benchmark.run(
             print_data=True,
             show_plots=True,
-            save_path=f"bench_nvfp4_res_n{N}_k{K}",
+            save_path=save_dir,
             N=N,
             K=K,
         )
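The save_path change also creates the output directory up front, so repeated runs can reuse it. A minimal sketch of the pattern, using an example directory name in the same format:

import os

save_dir = "bench_nvfp4_res_n8192_k8192"  # example name; the script derives it from N and K
os.makedirs(save_dir, exist_ok=True)      # idempotent: no error if the directory already exists
# benchmark.run(..., save_path=save_dir) can now write its result files there.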
@@ -201,6 +201,7 @@ if TYPE_CHECKING:
     VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
     VLLM_USE_NCCL_SYMM_MEM: bool = False
     VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
+    VLLM_USE_FBGEMM: bool = False
 
 
 def get_default_cache_root():
@@ -1452,7 +1453,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # NCCL header path
     "VLLM_NCCL_INCLUDE_PATH":
     lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
-
+    # Flag to enable FBGemm kernels on model execution
+    "VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
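The new entry follows the existing envs.py convention: the value is read lazily through a lambda and parsed as an integer flag. A small standalone sketch of that parsing behaviour (not vLLM code):

import os

def read_flag() -> bool:
    # Same parsing as the lambda above: unset -> "0" -> False, "1" -> True.
    # Note that non-integer values such as "true" would raise ValueError.
    return bool(int(os.getenv("VLLM_USE_FBGEMM", "0")))

os.environ.pop("VLLM_USE_FBGEMM", None)
print(read_flag())  # False

os.environ["VLLM_USE_FBGEMM"] = "1"
print(read_flag())  # True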
@@ -1548,6 +1550,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
         "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
         "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
+        "VLLM_USE_FBGEMM",
     ]
     for key in environment_variables_to_hash:
         # if this goes out of sync with environment_variables,
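Listing VLLM_USE_FBGEMM in environment_variables_to_hash makes the flag part of vLLM's cache/compilation key, so toggling the backend cannot silently reuse artifacts built under the other one. The snippet below is only a schematic of that idea, with an illustrative schematic_hash helper; it is not vLLM's actual compute_hash implementation:

import hashlib
import os

def schematic_hash(keys: list[str]) -> str:
    # Fold the current values of the tracked environment variables into one
    # digest; flipping any tracked flag changes the result.
    payload = "|".join(f"{k}={os.getenv(k, '')}" for k in sorted(keys))
    return hashlib.sha256(payload.encode()).hexdigest()

print(schematic_hash(["VLLM_USE_FBGEMM", "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE"]))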
@@ -30,8 +30,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
         if envs.VLLM_USE_TRTLLM_FP4_GEMM:
             assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
             self.backend = "flashinfer-trtllm"
             logger.info_once("Using flashinfer-trtllm for FP4")
+        elif envs.VLLM_USE_FBGEMM:
+            self.backend = "fbgemm"
+            try:
+                import fbgemm_gpu  # noqa: F401
+            except ImportError as exc:
+                raise ImportError(
+                    "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
+                    "Please install with: pip install fbgemm-gpu-genai"
+                ) from exc
+            logger.info_once("Using FBGEMM-GPU-GENAI for FP4")
         elif has_flashinfer():
             self.backend = "flashinfer-cutlass"
             logger.info_once("Using flashinfer-cutlass for FP4")
         else:
             self.backend = "cutlass"
         self.group_size = 16
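The constructor now resolves the FP4 GEMM backend in a fixed priority order. The helper below restates that order as a plain function for readability; select_fp4_backend and its boolean parameters are illustrative stand-ins for the env flags and has_flashinfer(), not code from the change:

def select_fp4_backend(use_trtllm: bool, use_fbgemm: bool,
                       flashinfer_available: bool) -> str:
    # Mirrors the if/elif chain above: TRT-LLM if requested (it additionally
    # requires FlashInfer), then FBGEMM if requested, then FlashInfer CUTLASS
    # when available, and vLLM's CUTLASS kernel as the fallback.
    if use_trtllm:
        return "flashinfer-trtllm"
    if use_fbgemm:
        return "fbgemm"
    if flashinfer_available:
        return "flashinfer-cutlass"
    return "cutlass"

print(select_fp4_backend(False, True, True))  # fbgemm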
@@ -116,6 +128,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
             layer.weight_packed = Parameter(weight, requires_grad=False)
         else:
             swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
+            if self.backend == "fbgemm":
+                swizzled_weight_scale = swizzled_weight_scale.view(-1).view(
+                    torch.uint8)
             layer.weight_scale = Parameter(swizzled_weight_scale,
                                            requires_grad=False)
             layer.weight_packed = Parameter(layer.weight_packed.data,
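The extra .view(-1).view(torch.uint8) flattens the swizzled block-scale tensor and reinterprets its storage as raw bytes, which is the form the fbgemm kernel expects; since FP8 and uint8 are both one byte per element, no data is copied or converted. A small sketch of the mechanics, assuming a PyTorch build with float8_e4m3fn support:

import torch

scales = torch.randn(4, 16).to(torch.float8_e4m3fn)

# Flatten, then reinterpret the same storage as uint8 (same element size,
# so this is a zero-copy view, not a cast).
raw = scales.view(-1).view(torch.uint8)

assert raw.dtype == torch.uint8
assert raw.numel() == scales.numel()
assert raw.data_ptr() == scales.data_ptr()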
@@ -153,6 +168,15 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
             out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
         elif self.backend == "flashinfer-cutlass":
             out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
+        elif self.backend == "fbgemm":
+            out = torch.ops.fbgemm.f4f4bf16(
+                x_fp4,
+                layer.weight_packed,
+                x_blockscale.view(-1).view(torch.uint8),
+                layer.weight_scale,
+                layer.alpha,
+                use_mx=False,
+            ).to(output_dtype)
         else:
             out = cutlass_scaled_fp4_mm(*mm_args)
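Taken together, the feature is opt-in: nothing changes unless VLLM_USE_FBGEMM=1 is set and fbgemm-gpu-genai is installed, at which point NVFP4 dense layers dispatch to torch.ops.fbgemm.f4f4bf16. A rough usage sketch; the model id is a placeholder for any NVFP4 (compressed-tensors W4A4) checkpoint:

import os

# Set the flag before vLLM reads its environment variables.
os.environ["VLLM_USE_FBGEMM"] = "1"

from vllm import LLM

# Placeholder model id: substitute a real NVFP4 compressed-tensors checkpoint.
llm = LLM(model="org/some-nvfp4-model")
outputs = llm.generate("Hello")
print(outputs[0].outputs[0].text)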