Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:04:58 +08:00)
Benchmark script for fp8 vs bf16 gemm (#17126)

Signed-off-by: mgoin <mgoin64@gmail.com>

parent 2dbe8c0774
commit f49239cb45

benchmarks/kernels/bench_fp8_gemm.py (new file, 222 lines)
@@ -0,0 +1,222 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import copy
import itertools

import torch
import triton
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

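# Provider names encode the quantization scheme as
# "fp8-<weight granularity>-w-<activation granularity>-a"; the "-noquant"
# variants pre-quantize the activations so that only the GEMM itself is timed.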
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "torch-bf16",
            # "fp8-tensor-w-token-a",
            "fp8-tensor-w-tensor-a",
            "fp8-channel-w-token-a",
            # "fp8-channel-w-tensor-a",
            # "fp8-tensor-w-token-a-noquant",
            "fp8-tensor-w-tensor-a-noquant",
            "fp8-channel-w-token-a-noquant",
            # "fp8-channel-w-tensor-a-noquant",
        ],
        line_names=[
            "torch-bf16",
            # "fp8-tensor-w-token-a",
            "fp8-tensor-w-tensor-a",
            "fp8-channel-w-token-a",
            # "fp8-channel-w-tensor-a",
            # "fp8-tensor-w-token-a-noquant",
            "fp8-tensor-w-tensor-a-noquant",
            "fp8-channel-w-token-a-noquant",
            # "fp8-channel-w-tensor-a-noquant",
        ],
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs FP8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    # Create input tensors
    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if "torch-bf16" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )

    elif "fp8" in provider:
        # Weights are always quantized ahead of time
        if "noquant" in provider:
            # Activations are pre-quantized too, so only the GEMM is measured
            if "tensor-w-token-a" in provider:
                # Dynamic per-token quant for A, per-tensor quant for B
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
                assert scale_b_fp8.numel() == 1
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                    a, use_per_token_if_dynamic=True
                )

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "tensor-w-tensor-a" in provider:
                # Static per-tensor quantization with fixed scales
                # for both A and B
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                assert scale_b_fp8.numel() == 1
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-token-a" in provider:
                # Static per-channel quantization for weights, per-token
                # quant for A
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                    a, use_per_token_if_dynamic=True
                )

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-tensor-a" in provider:
                # Static per-channel quantization for weights, per-tensor
                # quant for A
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        else:
            # In these cases, we quantize the activations during the GEMM call
            if "tensor-w-token-a" in provider:
                # Dynamic per-token quant for A, per-tensor quant for B
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
                assert scale_b_fp8.numel() == 1

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                        a, use_per_token_if_dynamic=True
                    )
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "tensor-w-tensor-a" in provider:
                # Static per-tensor quantization with fixed scales
                # for both A and B
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                assert scale_b_fp8.numel() == 1

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-token-a" in provider:
                # Static per-channel quantization for weights, per-token
                # quant for A
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                        a, use_per_token_if_dynamic=True
                    )
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-tensor-a" in provider:
                # Static per-channel quantization for weights, per-tensor
                # quant for A
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        b_fp8 = b_fp8.t()

        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    # Calculate TFLOP/s, two flops per multiply-add
    tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
    return tflops(ms), tflops(max_ms), tflops(min_ms)
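# A worked example of the TFLOP/s formula above (assumed shapes, not a
# measured result): a 1 ms run at M = N = K = 4096 performs 2 * 4096**3
# flops, i.e. 2 * 4096**3 * 1e-12 / 1e-3 ≈ 137 TFLOP/s.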

def prepare_shapes(args):
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names
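# Example of the returned layout: with --tp-sizes 2, the Llama-3.1-8B entry
# ([4096, 6144], 1) becomes [4096, 3072, "meta-llama/Llama-3.1-8B-Instruct"],
# which the main loop below unpacks as K=4096, N=3072.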

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=[*WEIGHT_SHAPES.keys()],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_fp8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
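For reference, a typical invocation of the new script looks like this (the flag values are example choices, not defaults):

python benchmarks/kernels/bench_fp8_gemm.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1 2

For each resulting (N, K) shape, the script prints a TFLOP/s table for every enabled provider and saves the plot and data under bench_fp8_res_n{N}_k{K}.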
benchmarks/kernels/weight_shapes.py
@@ -48,4 +48,50 @@ WEIGHT_SHAPES = {
        ([16384, 106496], 1),
        ([53248, 16384], 0),
    ],
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
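Each WEIGHT_SHAPES entry pairs a layer's weight shape with the dimension that tensor parallelism splits: index 0 is K (the GEMM's reduction dim) and index 1 is N (the output dim). A minimal sketch of how prepare_shapes consumes one entry, with an assumed TP size of 2:

shape, tp_split_dim = [4096, 28672], 1  # a Llama-3.1-8B entry from the list above
shape[tp_split_dim] //= 2               # assumed tensor-parallel size of 2
K, N = shape                            # K=4096, N=14336 per GPU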