From 3d330c4c095b78b3e6226d99f4d4a7a0965f3758 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Sun, 15 Jun 2025 03:15:37 -0400
Subject: [PATCH] [Benchmark] Refactor benchmark script for fp8 & int8 (#19627)

Signed-off-by: yewentao256
---
 benchmarks/kernels/bench_fp8_gemm.py  | 249 ++++++++++----------
 benchmarks/kernels/bench_int8_gemm.py | 215 ++++++++++------------
 2 files changed, 184 insertions(+), 280 deletions(-)

diff --git a/benchmarks/kernels/bench_fp8_gemm.py b/benchmarks/kernels/bench_fp8_gemm.py
index b964ed242edf8..d17443871cf66 100644
--- a/benchmarks/kernels/bench_fp8_gemm.py
+++ b/benchmarks/kernels/bench_fp8_gemm.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import argparse
 import copy
 import itertools
@@ -11,6 +10,80 @@ from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
 from vllm.triton_utils import triton
 
+PROVIDER_CFGS = {
+    "torch-bf16": dict(enabled=True),
+    "fp8-tensor-w-token-a": dict(
+        w="tensor", a="token", no_a_quant=False, enabled=False
+    ),
+    "fp8-tensor-w-tensor-a": dict(
+        w="tensor", a="tensor", no_a_quant=False, enabled=True
+    ),
+    "fp8-channel-w-token-a": dict(
+        w="channel", a="token", no_a_quant=False, enabled=True
+    ),
+    "fp8-channel-w-tensor-a": dict(
+        w="channel", a="tensor", no_a_quant=False, enabled=False
+    ),
+    "fp8-tensor-w-token-a-noquant": dict(
+        w="tensor", a="token", no_a_quant=True, enabled=False
+    ),
+    "fp8-tensor-w-tensor-a-noquant": dict(
+        w="tensor", a="tensor", no_a_quant=True, enabled=True
+    ),
+    "fp8-channel-w-token-a-noquant": dict(
+        w="channel", a="token", no_a_quant=True, enabled=True
+    ),
+    "fp8-channel-w-tensor-a-noquant": dict(
+        w="channel", a="tensor", no_a_quant=True, enabled=False
+    ),
+}
+
+_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
+
+
+def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str):
+    if w_type == "tensor":
+        scale_b = torch.ones(1, device=device, dtype=torch.float32)
+        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
+    else:
+        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
+    return b_fp8.t(), scale_b_fp8
+
+
+def build_fp8_runner(cfg, a, b, dtype, device):
+    b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device)
+
+    scale_a_const = (
+        torch.ones(1, device=device, dtype=torch.float32)
+        if cfg["a"] == "tensor"
+        else None
+    )
+
+    if cfg["no_a_quant"]:
+        if cfg["a"] == "tensor":
+            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
+        else:
+            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
+
+        def run():
+            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
+
+        return run
+
+    if cfg["a"] == "tensor":
+
+        def run():
+            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
+            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
+
+    else:
+
+        def run():
+            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
+            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
+
+    return run
+
 
 @triton.testing.perf_report(
     triton.testing.Benchmark(
@@ -18,28 +91,8 @@ from vllm.triton_utils import triton
         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
         x_log=False,
         line_arg="provider",
-        line_vals=[
-            "torch-bf16",
-            # "fp8-tensor-w-token-a",
-            "fp8-tensor-w-tensor-a",
-            "fp8-channel-w-token-a",
-            # "fp8-channel-w-tensor-a",
-            # "fp8-tensor-w-token-a-noquant",
-            "fp8-tensor-w-tensor-a-noquant",
-            "fp8-channel-w-token-a-noquant",
-            # "fp8-channel-w-tensor-a-noquant",
-        ],
-        line_names=[
-            "torch-bf16",
-            # "fp8-tensor-w-token-a",
-            "fp8-tensor-w-tensor-a",
-            "fp8-channel-w-token-a",
-            # "fp8-channel-w-tensor-a",
-            # "fp8-tensor-w-token-a-noquant",
-            "fp8-tensor-w-tensor-a-noquant",
-            "fp8-channel-w-token-a-noquant",
-            # "fp8-channel-w-tensor-a-noquant",
-        ],
+        line_vals=_enabled,
+        line_names=_enabled,
         ylabel="TFLOP/s (larger is better)",
         plot_name="BF16 vs FP8 GEMMs",
         args={},
@@ -50,144 +103,34 @@ def benchmark(batch_size, provider, N, K):
     device = "cuda"
     dtype = torch.bfloat16
 
-    # Create input tensors
     a = torch.randn((M, K), device=device, dtype=dtype)
     b = torch.randn((N, K), device=device, dtype=dtype)
 
     quantiles = [0.5, 0.2, 0.8]
 
-    if "torch-bf16" in provider:
+    if provider == "torch-bf16":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
         )
-
-    elif "fp8" in provider:
-        # Weights are always quantized ahead of time
-        if "noquant" in provider:
-            # For no quantization, we just measure the GEMM
-            if "tensor-w-token-a" in provider:
-                # Dynamic per-token quant for A, per-tensor quant for B
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
-                assert scale_b_fp8.numel() == 1
-                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                    a, use_per_token_if_dynamic=True
-                )
-
-                def run_quant():
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "tensor-w-tensor-a" in provider:
-                # Static per-tensor quantization with fixed scales
-                # for both A and B
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                assert scale_b_fp8.numel() == 1
-                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-
-                def run_quant():
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "channel-w-token-a" in provider:
-                # Static per-channel quantization for weights, per-token
-                # quant for A
-                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                assert scale_b_fp8.numel() == N
-                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                    a, use_per_token_if_dynamic=True
-                )
-
-                def run_quant():
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "channel-w-tensor-a" in provider:
-                # Static per-channel quantization for weights, per-tensor
-                # quant for A
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                assert scale_b_fp8.numel() == N
-                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-
-                def run_quant():
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-        else:
-            # In these cases, we quantize the activations during the GEMM call
-            if "tensor-w-token-a" in provider:
-                # Dynamic per-token quant for A, per-tensor quant for B
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
-                assert scale_b_fp8.numel() == 1
-
-                def run_quant():
-                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                        a, use_per_token_if_dynamic=True
-                    )
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "tensor-w-tensor-a" in provider:
-                # Static per-tensor quantization with fixed scales
-                # for both A and B
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                assert scale_b_fp8.numel() == 1
-
-                def run_quant():
-                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "channel-w-token-a" in provider:
-                # Static per-channel quantization for weights, per-token
-                # quant for A
-                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                assert scale_b_fp8.numel() == N
-
-                def run_quant():
-                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                        a, use_per_token_if_dynamic=True
-                    )
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "channel-w-tensor-a" in provider:
-                # Static per-channel quantization for weights, per-tensor
-                # quant for A
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                assert scale_b_fp8.numel() == N
-
-                def run_quant():
-                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-        b_fp8 = b_fp8.t()
-
+    else:
+        cfg = PROVIDER_CFGS[provider]
+        run_quant = build_fp8_runner(cfg, a, b, dtype, device)
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: run_quant(), quantiles=quantiles
        )
 
-    # Calculate TFLOP/s, two flops per multiply-add
-    tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
-    return tflops(ms), tflops(max_ms), tflops(min_ms)
+    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
+    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
 
 
 def prepare_shapes(args):
-    KN_model_names = []
-    models_tps = list(itertools.product(args.models, args.tp_sizes))
-    for model, tp_size in models_tps:
-        assert model in WEIGHT_SHAPES
-        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
-            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
+    out = []
+    for model, tp_size in itertools.product(args.models, args.tp_sizes):
+        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
+            KN[tp_dim] //= tp_size
             KN.append(model)
-            KN_model_names.append(KN)
-    return KN_model_names
+            out.append(KN)
+    return out
 
 
 if __name__ == "__main__":
@@ -197,21 +140,13 @@ if __name__ == "__main__":
         nargs="+",
         type=str,
         default=["meta-llama/Llama-3.1-8B-Instruct"],
-        choices=[*WEIGHT_SHAPES.keys()],
-        help="List of models to benchmark",
-    )
-    parser.add_argument(
-        "--tp-sizes",
-        nargs="+",
-        type=int,
-        default=[1],
-        help="List of tensor parallel sizes",
+        choices=list(WEIGHT_SHAPES.keys()),
     )
+    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
     args = parser.parse_args()
 
-    KN_model_names = prepare_shapes(args)
-    for K, N, model_name in KN_model_names:
-        print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
+    for K, N, model in prepare_shapes(args):
+        print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
         benchmark.run(
             print_data=True,
             show_plots=True,
diff --git a/benchmarks/kernels/bench_int8_gemm.py b/benchmarks/kernels/bench_int8_gemm.py
index e6adcaa00ded0..e9c6d64404d0d 100644
--- a/benchmarks/kernels/bench_int8_gemm.py
+++ b/benchmarks/kernels/bench_int8_gemm.py
@@ -11,6 +11,84 @@ from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
 from vllm.triton_utils import triton
 
+PROVIDER_CFGS = {
+    "torch-bf16": dict(enabled=True),
+    "int8-tensor-w-token-a": dict(
+        w="tensor", a="token", no_a_quant=False, enabled=False
+    ),
+    "int8-tensor-w-tensor-a": dict(
+        w="tensor", a="tensor", no_a_quant=False, enabled=True
+    ),
+    "int8-channel-w-token-a": dict(
+        w="channel", a="token", no_a_quant=False, enabled=True
+    ),
+    "int8-channel-w-tensor-a": dict(
+        w="channel", a="tensor", no_a_quant=False, enabled=False
+    ),
+    "int8-tensor-w-token-a-noquant": dict(
+        w="tensor", a="token", no_a_quant=True, enabled=False
+    ),
+    "int8-tensor-w-tensor-a-noquant": dict(
+        w="tensor", a="tensor", no_a_quant=True, enabled=True
+    ),
+    "int8-channel-w-token-a-noquant": dict(
+        w="channel", a="token", no_a_quant=True, enabled=True
+    ),
+    "int8-channel-w-tensor-a-noquant": dict(
+        w="channel", a="tensor", no_a_quant=True, enabled=False
+    ),
+}
+
+
+def _quant_weight(b, w_type, device):
+    if w_type == "tensor":
+        scale_b = torch.ones(1, device=device, dtype=torch.float32)
+        b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
+        assert scale_b_int8.numel() == 1
+    else:  # channel
+        b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
+        assert scale_b_int8.numel() == b.shape[0]
+    return b_int8.t(), scale_b_int8
+
+
+def build_int8_runner(cfg, a, b, dtype, device):
+    # quant before running the kernel
+    b_int8, scale_b_int8 = _quant_weight(b, cfg["w"], device)
+
+    scale_a_const = None
+    if cfg["a"] == "tensor":
+        scale_a_const = torch.ones(1, device=device, dtype=torch.float32)
+
+    # no quant, create activation ahead
+    if cfg["no_a_quant"]:
+        if cfg["a"] == "tensor":
+            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const)
+        else:  # token
+            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
+
+        def run_quant():
+            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
+
+        return run_quant
+
+    # dynamic quant, create activation inside
+    if cfg["a"] == "tensor":
+
+        def run_quant():
+            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const)
+            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
+
+    else:  # token
+
+        def run_quant():
+            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
+            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
+
+    return run_quant
+
+
+_enabled = [k for k, v in PROVIDER_CFGS.items() if v.get("enabled")]
+
 
 @triton.testing.perf_report(
     triton.testing.Benchmark(
@@ -18,28 +96,8 @@ from vllm.triton_utils import triton
         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
         x_log=False,
         line_arg="provider",
-        line_vals=[
-            "torch-bf16",
-            # "int8-tensor-w-token-a",
-            "int8-tensor-w-tensor-a",
-            "int8-channel-w-token-a",
-            # "int8-channel-w-tensor-a",
-            # "int8-tensor-w-token-a-noquant",
-            "int8-tensor-w-tensor-a-noquant",
-            "int8-channel-w-token-a-noquant",
-            # "int8-channel-w-tensor-a-noquant",
-        ],
-        line_names=[
-            "torch-bf16",
-            # "int8-tensor-w-token-a",
-            "int8-tensor-w-tensor-a",
-            "int8-channel-w-token-a",
-            # "int8-channel-w-tensor-a",
-            # "int8-tensor-w-token-a-noquant",
-            "int8-tensor-w-tensor-a-noquant",
-            "int8-channel-w-token-a-noquant",
-            # "int8-channel-w-tensor-a-noquant",
-        ],
+        line_vals=_enabled,
+        line_names=[k for k in _enabled],
         ylabel="TFLOP/s (larger is better)",
         plot_name="BF16 vs INT8 GEMMs",
         args={},
@@ -54,114 +112,26 @@ def benchmark(batch_size, provider, N, K):
 
     quantiles = [0.5, 0.2, 0.8]
 
-    if "torch-bf16" in provider:
+    if provider == "torch-bf16":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
         )
-
-    elif "int8" in provider:
-        # Weights are always quantized ahead of time
-        if "noquant" in provider:
-            # For "no quant", we don't measure the time for activations
-            if "tensor-w-token-a" in provider:
-                # Dynamic per-token quant for A, static per-tensor quant for B
-                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
-                assert scale_b_int8.numel() == 1
-                a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
-
-            elif "tensor-w-tensor-a" in provider:
-                # Static per-tensor quantization with fixed scales for both A and B
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
-                assert scale_b_int8.numel() == 1
-                a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
-
-            elif "channel-w-token-a" in provider:
-                # Dynamic per-channel quantization for weights, per-token quant for A
-                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
-                assert scale_b_int8.numel() == N
-                a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
-
-            elif "channel-w-tensor-a" in provider:
-                # Dynamic per-channel quantization for weights, per-tensor quant for A
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
-                assert scale_b_int8.numel() == N
-                a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
-
-            def run_quant():
-                return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
-
-        else:
-            # Quantize the activations during the GEMM call
-            if "tensor-w-token-a" in provider:
-                # Dynamic per-token quant for A, static per-tensor quant for B
-                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
-                assert scale_b_int8.numel() == 1
-
-                def run_quant():
-                    a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
-                    return vllm_scaled_mm(
-                        a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
-                    )
-
-            elif "tensor-w-tensor-a" in provider:
-                # Static per-tensor quantization with fixed scales for both A and B
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
-                assert scale_b_int8.numel() == 1
-
-                def run_quant():
-                    a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
-                    return vllm_scaled_mm(
-                        a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
-                    )
-
-            elif "channel-w-token-a" in provider:
-                # Dynamic per-channel quant for weights, per-token quant for A
-                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
-                assert scale_b_int8.numel() == N
-
-                def run_quant():
-                    a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
-                    return vllm_scaled_mm(
-                        a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
-                    )
-
-            elif "channel-w-tensor-a" in provider:
-                # Dynamic per-channel quant for weights, static per-tensor quant for A
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
-                assert scale_b_int8.numel() == N
-
-                def run_quant():
-                    a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a)
-                    return vllm_scaled_mm(
-                        a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
-                    )
-
-        b_int8 = b_int8.t()
-
+    else:
+        cfg = PROVIDER_CFGS[provider]
+        run_quant = build_int8_runner(cfg, a, b, dtype, device)
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: run_quant(), quantiles=quantiles
         )
 
-    # Calculate TFLOP/s, two flops per multiply-add
-    tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
-    return tflops(ms), tflops(max_ms), tflops(min_ms)
+    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
+    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
 
 
 def prepare_shapes(args):
     KN_model_names = []
-    models_tps = list(itertools.product(args.models, args.tp_sizes))
-    for model, tp_size in models_tps:
-        assert model in WEIGHT_SHAPES
-        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
-            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
+    for model, tp_size in itertools.product(args.models, args.tp_sizes):
+        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
+            KN[tp_dim] //= tp_size
             KN.append(model)
             KN_model_names.append(KN)
     return KN_model_names
@@ -174,7 +144,7 @@ if __name__ == "__main__":
         nargs="+",
         type=str,
         default=["meta-llama/Llama-3.1-8B-Instruct"],
-        choices=[*WEIGHT_SHAPES.keys()],
+        choices=list(WEIGHT_SHAPES.keys()),
         help="List of models to benchmark",
     )
     parser.add_argument(
@@ -186,9 +156,8 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    KN_model_names = prepare_shapes(args)
-    for K, N, model_name in KN_model_names:
-        print(f"{model_name}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
+    for K, N, model in prepare_shapes(args):
+        print(f"{model}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
         benchmark.run(
             print_data=True,
            show_plots=True,
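
Note (not part of the patch): for readers who want to drive the refactored helpers outside Triton's perf_report harness, a minimal sketch is shown below. It assumes a CUDA device, an installed vLLM, that benchmarks/kernels is on sys.path (so weight_shapes resolves), and that PROVIDER_CFGS and build_fp8_runner come from the patched bench_fp8_gemm.py; the shapes and provider choice are illustrative only.

    # Illustrative sketch, assuming benchmarks/kernels is on sys.path.
    import torch
    from bench_fp8_gemm import PROVIDER_CFGS, build_fp8_runner

    M, N, K = 16, 4096, 4096  # hypothetical GEMM shape
    a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
    b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16)

    # Per-channel weight scales, dynamic per-token activation quantization.
    cfg = PROVIDER_CFGS["fp8-channel-w-token-a"]
    run = build_fp8_runner(cfg, a, b, torch.bfloat16, "cuda")
    out = run()  # quantizes A each call (no_a_quant=False) and runs the scaled mm
    print(out.shape)  # (M, N) bf16 result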