mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 04:47:03 +08:00
Enable Fbgemm NVFP4 on Dense models (#25609)
Signed-off-by: Saman Keon <samanamp@outlook.com>
This commit is contained in:
parent
4492e3a554
commit
90b139cfff
@ -3,6 +3,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from weight_shapes import WEIGHT_SHAPES
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
@ -23,21 +24,45 @@ PROVIDER_CFGS = {
|
|||||||
"torch-bf16": dict(enabled=True),
|
"torch-bf16": dict(enabled=True),
|
||||||
"nvfp4": dict(no_a_quant=False, enabled=True),
|
"nvfp4": dict(no_a_quant=False, enabled=True),
|
||||||
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||||
|
"fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True),
|
||||||
|
"fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Only try to import the FBGEMM quantizer when an enabled provider asks for it.
_needs_fbgemm = any(
    v.get("enabled", False) and v.get("fbgemm", False)
    for v in PROVIDER_CFGS.values()
)
if _needs_fbgemm:
    try:
        from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
            triton_scale_nvfp4_quant,
        )
    except ImportError:
        _fbgemm_missing_msg = (
            "WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. "
            "These providers will be skipped. Please install fbgemm_gpu with: "
            "'pip install fbgemm-gpu-genai' to run them."
        )
        print(_fbgemm_missing_msg)
        # Switch every FBGEMM-backed provider off so the remaining ones
        # can still be benchmarked.
        for cfg in PROVIDER_CFGS.values():
            if not cfg.get("fbgemm"):
                continue
            cfg["enabled"] = False

# Names of the providers that are still enabled after the check above.
_enabled = [name for name, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||||
|
|
||||||
|
|
||||||
def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg):
    """Quantize the weight tensor ``b`` to NVFP4 for the provider in ``cfg``.

    Args:
        b: Weight tensor to quantize.
        device: Not used by this function; kept so the call signature
            stays unchanged.
        cfg: Provider config mapping. A truthy ``"fbgemm"`` entry selects
            ``triton_scale_nvfp4_quant`` (imported from fbgemm_gpu);
            otherwise ``ops.scaled_fp4_quant`` is used.

    Returns:
        Tuple ``(b_fp4, scale_b_fp4, b_global_scale)``: the two values
        returned by the selected quantizer, followed by the global scale
        that was passed to it.
    """
    # Compute global scale for weight
    b_amax = torch.abs(b).max().to(torch.float32)
    b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax

    # Use the same `.get` lookup as the module-level provider checks. Only
    # the selected name is evaluated, so a missing fbgemm import does not
    # affect the non-FBGEMM providers.
    quantize = (
        triton_scale_nvfp4_quant if cfg.get("fbgemm") else ops.scaled_fp4_quant
    )
    b_fp4, scale_b_fp4 = quantize(b, b_global_scale)
    return b_fp4, scale_b_fp4, b_global_scale
|
||||||
|
|
||||||
|
|
||||||
def build_nvfp4_runner(cfg, a, b, dtype, device):
|
def build_nvfp4_runner(cfg, a, b, dtype, device):
|
||||||
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
|
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg)
|
||||||
|
|
||||||
# Compute global scale for activation
|
# Compute global scale for activation
|
||||||
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
|
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
|
||||||
@ -46,6 +71,35 @@ def build_nvfp4_runner(cfg, a, b, dtype, device):
|
|||||||
|
|
||||||
# Alpha for the GEMM operation
|
# Alpha for the GEMM operation
|
||||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||||
|
if "fbgemm" in cfg and cfg["fbgemm"]:
|
||||||
|
if cfg["no_a_quant"]:
|
||||||
|
a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
return torch.ops.fbgemm.f4f4bf16(
|
||||||
|
a_fp4,
|
||||||
|
b_fp4,
|
||||||
|
scale_a_fp4,
|
||||||
|
scale_b_fp4,
|
||||||
|
global_scale=alpha,
|
||||||
|
use_mx=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
else:
|
||||||
|
|
||||||
|
def run():
|
||||||
|
a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
|
||||||
|
return torch.ops.fbgemm.f4f4bf16(
|
||||||
|
a_fp4,
|
||||||
|
b_fp4,
|
||||||
|
scale_a_fp4,
|
||||||
|
scale_b_fp4,
|
||||||
|
global_scale=alpha,
|
||||||
|
use_mx=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
if cfg["no_a_quant"]:
|
if cfg["no_a_quant"]:
|
||||||
# Pre-quantize activation
|
# Pre-quantize activation
|
||||||
@ -130,10 +184,13 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
for K, N, model in prepare_shapes(args):
|
for K, N, model in prepare_shapes(args):
|
||||||
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
||||||
|
save_dir = f"bench_nvfp4_res_n{N}_k{K}"
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
benchmark.run(
|
benchmark.run(
|
||||||
print_data=True,
|
print_data=True,
|
||||||
show_plots=True,
|
show_plots=True,
|
||||||
save_path=f"bench_nvfp4_res_n{N}_k{K}",
|
save_path=save_dir,
|
||||||
N=N,
|
N=N,
|
||||||
K=K,
|
K=K,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -201,6 +201,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
|
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
|
||||||
VLLM_USE_NCCL_SYMM_MEM: bool = False
|
VLLM_USE_NCCL_SYMM_MEM: bool = False
|
||||||
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
|
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
|
||||||
|
VLLM_USE_FBGEMM: bool = False
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -1452,7 +1453,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# NCCL header path
|
# NCCL header path
|
||||||
"VLLM_NCCL_INCLUDE_PATH":
|
"VLLM_NCCL_INCLUDE_PATH":
|
||||||
lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
|
lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
|
||||||
|
# Flag to enable FBGemm kernels on model execution
|
||||||
|
"VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))),
|
||||||
}
|
}
|
||||||
|
|
||||||
# --8<-- [end:env-vars-definition]
|
# --8<-- [end:env-vars-definition]
|
||||||
@ -1548,6 +1550,7 @@ def compute_hash() -> str:
|
|||||||
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
||||||
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
|
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
|
||||||
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
||||||
|
"VLLM_USE_FBGEMM",
|
||||||
]
|
]
|
||||||
for key in environment_variables_to_hash:
|
for key in environment_variables_to_hash:
|
||||||
# if this goes out of sync with environment_variables,
|
# if this goes out of sync with environment_variables,
|
||||||
|
|||||||
@ -30,8 +30,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
||||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
||||||
self.backend = "flashinfer-trtllm"
|
self.backend = "flashinfer-trtllm"
|
||||||
|
logger.info_once("Using flashinfer-trtllm for FP4")
|
||||||
|
elif envs.VLLM_USE_FBGEMM:
|
||||||
|
self.backend = "fbgemm"
|
||||||
|
try:
|
||||||
|
import fbgemm_gpu # noqa: F401
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
|
||||||
|
"Please install with: pip install fbgemm-gpu-genai"
|
||||||
|
) from exc
|
||||||
|
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
|
||||||
elif has_flashinfer():
|
elif has_flashinfer():
|
||||||
self.backend = "flashinfer-cutlass"
|
self.backend = "flashinfer-cutlass"
|
||||||
|
logger.info_once("Using flashinfer-cutlass for FP4")
|
||||||
else:
|
else:
|
||||||
self.backend = "cutlass"
|
self.backend = "cutlass"
|
||||||
self.group_size = 16
|
self.group_size = 16
|
||||||
@ -116,6 +128,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
layer.weight_packed = Parameter(weight, requires_grad=False)
|
layer.weight_packed = Parameter(weight, requires_grad=False)
|
||||||
else:
|
else:
|
||||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||||
|
if self.backend == "fbgemm":
|
||||||
|
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(
|
||||||
|
torch.uint8)
|
||||||
layer.weight_scale = Parameter(swizzled_weight_scale,
|
layer.weight_scale = Parameter(swizzled_weight_scale,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.weight_packed = Parameter(layer.weight_packed.data,
|
layer.weight_packed = Parameter(layer.weight_packed.data,
|
||||||
@ -153,6 +168,15 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||||
elif self.backend == "flashinfer-cutlass":
|
elif self.backend == "flashinfer-cutlass":
|
||||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||||
|
elif self.backend == "fbgemm":
|
||||||
|
out = torch.ops.fbgemm.f4f4bf16(
|
||||||
|
x_fp4,
|
||||||
|
layer.weight_packed,
|
||||||
|
x_blockscale.view(-1).view(torch.uint8),
|
||||||
|
layer.weight_scale,
|
||||||
|
layer.alpha,
|
||||||
|
use_mx=False,
|
||||||
|
).to(output_dtype)
|
||||||
else:
|
else:
|
||||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user