# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
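"""Benchmark NVFP4 input quantization: vLLM vs FlashInfer.

Sweeps batch sizes over each (N, K) weight shape of the selected models and
reports quantization latency in microseconds. Pass --accuracy to first
cross-check the two backends' packed FP4 outputs. Example invocations
(the script name here is a placeholder):

    python benchmark_nvfp4_quant.py --accuracy
    python benchmark_nvfp4_quant.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
"""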

import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import triton
from vllm.utils.flashinfer import flashinfer_fp4_quantize

if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
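# For reference: FP4 E2M1 tops out at 6.0 and FP8 E4M3 (fn) at 448.0; their
# product (2688.0) is the numerator of the global scale computed below.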

PROVIDER_CFGS = {
    "vllm": dict(backend="vllm", enabled=True),
    "flashinfer": dict(backend="flashinfer", enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
    """Compute global scale for FP4 quantization."""
    amax = torch.abs(tensor).max().to(torch.float32)
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax
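

# Illustrative sketch (not used by the benchmark) of how the global scale
# feeds NVFP4's two-level scheme: each 16-element block gets an FP8 scale
# derived from its own amax, and the global scale is chosen so the largest
# block scale lands at FP8 E4M3's max (448.0): (amax / 6) * (448 * 6 / amax)
# = 448. Block size and formula follow the NVFP4 recipe as commonly
# described; treat this as a reference sketch, not the exact implementation
# of either kernel below.
def _reference_block_scales(
    tensor: torch.Tensor, global_scale: torch.Tensor, block: int = 16
) -> torch.Tensor:
    """Hypothetical helper: per-block FP8 scale factors for a 2-D input.

    Assumes K is divisible by `block`.
    """
    m, k = tensor.shape
    # Per-block amax over the last dimension, in fp32 for the scale math.
    block_amax = tensor.view(m, k // block, block).abs().amax(dim=-1).float()
    scales = (block_amax / FLOAT4_E2M1_MAX) * global_scale
    # Clamp guards against rounding nudging the largest scale past 448.0.
    return scales.clamp(max=FLOAT8_E4M3_MAX).to(torch.float8_e4m3fn)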


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="us (lower is better)",
        plot_name="NVFP4 Input Quantization Latency (us)",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    # Create input tensor
    a = torch.randn((M, K), device=device, dtype=dtype)

    # Compute global scale for activation
    a_global_scale = compute_global_scale(a)

    quantiles = [0.5, 0.2, 0.8]

    cfg = PROVIDER_CFGS[provider]
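
    # Both branches time the op with triton's do_bench_cudagraph, which
    # replays the captured kernel inside a CUDA graph so that CPU launch
    # overhead does not dominate at the small batch sizes swept above.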
    if cfg["backend"] == "vllm":
        # vLLM's FP4 quantization
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: ops.scaled_fp4_quant(a, a_global_scale),
            quantiles=quantiles,
        )
    elif cfg["backend"] == "flashinfer":
        # FlashInfer's FP4 quantization
        # Use is_sf_swizzled_layout=True to match vLLM's output format
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: flashinfer_fp4_quantize(
                a, a_global_scale, is_sf_swizzled_layout=True
            ),
            quantiles=quantiles,
        )

    # Convert ms to us for better readability at small batch sizes
    to_us = lambda t_ms: t_ms * 1000
    return to_us(ms), to_us(max_ms), to_us(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out
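
# Example (illustrative numbers): with --tp-sizes 2, a WEIGHT_SHAPES entry of
# ([4096, 6144], 1) would become [4096, 3072, "<model>"], i.e. the dimension
# at tp_dim is split across tensor-parallel ranks before benchmarking.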


def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
    """Test accuracy between vLLM and FlashInfer FP4 quantization."""
    # Create input tensor
    a = torch.randn((M, K), device=device, dtype=dtype)

    # Compute global scale
    a_global_scale = compute_global_scale(a)

    # vLLM quantization
    vllm_fp4, vllm_scale = ops.scaled_fp4_quant(a, a_global_scale)

    # FlashInfer quantization (with swizzled layout to match vLLM's output)
    flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize(
        a, a_global_scale, is_sf_swizzled_layout=True
    )
    flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn)

    # Compare the packed FP4 payloads. The scale tensors are reinterpreted as
    # float8_e4m3fn above but are not asserted against each other here.
    torch.testing.assert_close(
        vllm_fp4,
        flashinfer_fp4,
    )
    print(f"M={M}, K={K}, dtype={dtype}: PASSED")


def test_accuracy():
    """Run accuracy tests across various shapes."""
    print("\n" + "=" * 60)
    print("Running accuracy tests: vLLM vs FlashInfer")
    print("=" * 60)

    device = "cuda"
    dtype = torch.bfloat16

    # Test various batch sizes and hidden dimensions
    Ms = [1, 1024]
    Ks = [4096]

    for M in Ms:
        for K in Ks:
            _test_accuracy_once(M, K, dtype, device)

    print("\nAll accuracy tests passed!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark NVFP4 quantization: vLLM vs FlashInfer"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    parser.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Path to save benchmark results",
    )
    parser.add_argument(
        "--accuracy",
        action="store_true",
        help="Run accuracy tests",
    )
    args = parser.parse_args()

    if args.accuracy:
        test_accuracy()

    for K, N, model in prepare_shapes(args):
        print(f"\n{model}, N={N} K={K}")
        benchmark.run(
            print_data=True,
            save_path=args.save_path,
            N=N,
            K=K,
        )

    print("\nBenchmark finished!")