From 0271c2ff2fd15bd1a7c19484572a81e056e75620 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Wed, 30 Jul 2025 10:15:02 -0400
Subject: [PATCH] [Test] Add Benchmark and Unit Test for `per_token_group_quant` (#21860)

Signed-off-by: yewentao256
---
 .../benchmark_per_token_group_quant.py | 159 ++++++++++++++++++
 .../test_per_token_group_quant.py      |  31 +++-
 2 files changed, 189 insertions(+), 1 deletion(-)
 create mode 100644 benchmarks/kernels/benchmark_per_token_group_quant.py

diff --git a/benchmarks/kernels/benchmark_per_token_group_quant.py b/benchmarks/kernels/benchmark_per_token_group_quant.py
new file mode 100644
index 000000000000..1ccb5e08b3d5
--- /dev/null
+++ b/benchmarks/kernels/benchmark_per_token_group_quant.py
@@ -0,0 +1,159 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import argparse
+import math
+from contextlib import contextmanager
+from typing import Callable
+from unittest.mock import patch
+
+import torch
+
+from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
+from vllm.platforms import current_platform
+
+
+@contextmanager
+def _triton_mode():
+    """Temporarily force the Triton fallback path"""
+    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
+        yield
+
+
+def _time_cuda(
+    fn: Callable[[], tuple[torch.Tensor, torch.Tensor]],
+    warmup_iters: int,
+    bench_iters: int,
+) -> float:
+    # warmup
+    for _ in range(warmup_iters):
+        fn()
+    torch.cuda.synchronize()
+
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+
+    start.record()
+    for _ in range(bench_iters):
+        fn()
+    end.record()
+    torch.cuda.synchronize()
+
+    return start.elapsed_time(end) / bench_iters  # ms/iter
+
+
+def _run_single(
+    shape: tuple[int, int],
+    group_size: int,
+    dtype: str,
+    *,
+    column_major: bool = False,
+    scale_ue8m0: bool = False,
+    warmup_iters: int,
+    bench_iters: int,
+) -> None:
+    num_tokens, hidden_dim = shape
+
+    device = torch.device("cuda")
+    torch.manual_seed(42)
+    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8
+
+    if dtype == "fp8":
+
+        def cuda_impl():
+            return fp8_utils.per_token_group_quant_fp8(
+                x,
+                group_size,
+                column_major_scales=column_major,
+                use_ue8m0=scale_ue8m0,
+            )
+
+        def triton_impl():
+            with _triton_mode():
+                return fp8_utils.per_token_group_quant_fp8(
+                    x,
+                    group_size,
+                    column_major_scales=column_major,
+                    use_ue8m0=scale_ue8m0,
+                )
+    elif dtype == "int8":
+
+        def cuda_impl():
+            return int8_utils.per_token_group_quant_int8(x, group_size)
+
+        def triton_impl():
+            with _triton_mode():
+                return int8_utils.per_token_group_quant_int8(x, group_size)
+    else:
+        raise ValueError("dtype must be 'fp8' or 'int8'")
+
+    cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
+    triton_ms = _time_cuda(triton_impl, warmup_iters, bench_iters)
+
+    speedup = triton_ms / cuda_ms if cuda_ms else math.inf
+
+    cfg_desc = (
+        f"shape={shape} gs={group_size:<3} col_major={column_major:<5} "
+        f"ue8m0={scale_ue8m0:<5} dtype={dtype}"
+    )
+    print(
+        f"{cfg_desc:55} | CUDA {cuda_ms:7.3f} ms | Triton {triton_ms:7.3f} ms | "
+        f"speed-up ×{speedup:5.2f}"
+    )
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--warmup-iters", type=int, default=10)
+    parser.add_argument("--bench-iters", type=int, default=100)
+    parser.add_argument("--dtype", choices=["fp8", "int8", "both"], default="both")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    if not current_platform.is_cuda():
+        raise RuntimeError("CUDA device is required to run this benchmark.")
+
+    args = parse_args()
+    warmup_iters, bench_iters = args.warmup_iters, args.bench_iters
+
+    shapes = [(32, 128), (64, 256), (16, 512)]
+    group_sizes = [64, 128]
+
+    dtypes = ["fp8", "int8"] if args.dtype == "both" else [args.dtype]
+
+    header = (
+        "Configuration".ljust(55)
+        + " | "
+        + "CUDA (ms)".center(12)
+        + " | "
+        + "Triton (ms)".center(13)
+        + " | "
+        + "Speed-up"
+    )
+    print(header)
+    print("-" * len(header))
+
+    for dtype in dtypes:
+        for shape in shapes:
+            for gs in group_sizes:
+                if dtype == "fp8":
+                    for col_major in (False, True):
+                        for ue8m0 in (False, True):
+                            _run_single(
+                                shape,
+                                gs,
+                                dtype,
+                                column_major=col_major,
+                                scale_ue8m0=ue8m0,
+                                warmup_iters=warmup_iters,
+                                bench_iters=bench_iters,
+                            )
+                else:  # INT8 has no col-major / ue8m0 switches
+                    _run_single(
+                        shape,
+                        gs,
+                        dtype,
+                        warmup_iters=warmup_iters,
+                        bench_iters=bench_iters,
+                    )
diff --git a/tests/kernels/quantization/test_per_token_group_quant.py b/tests/kernels/quantization/test_per_token_group_quant.py
index f826983fe94e..07f17d1efe64 100644
--- a/tests/kernels/quantization/test_per_token_group_quant.py
+++ b/tests/kernels/quantization/test_per_token_group_quant.py
@@ -5,7 +5,7 @@ from unittest.mock import patch
 
 import pytest
 import torch
 
-from vllm.model_executor.layers.quantization.utils import fp8_utils
+from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
 
 
@@ -42,3 +42,32 @@ def test_per_token_group_quant_fp8(shape, column_major: bool,
 
     assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
     assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
+
+
+@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
+@pytest.mark.parametrize("group_size", [64, 128])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_per_token_group_quant_int8(shape, group_size: int):
+    device = "cuda"
+
+    torch.manual_seed(42)
+    num_tokens, hidden_dim = shape
+
+    x = (torch.randn(
+        (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
+
+    # cuda path
+    out_q, scale = int8_utils.per_token_group_quant_int8(
+        x,
+        group_size,
+    )
+
+    # triton ref
+    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
+        ref_q, ref_s = int8_utils.per_token_group_quant_int8(
+            x,
+            group_size,
+        )
+
+    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
+    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
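
With the patch applied, the benchmark can be run directly; only the flags defined in parse_args() are available, e.g.:

    python benchmarks/kernels/benchmark_per_token_group_quant.py --dtype fp8 --warmup-iters 20 --bench-iters 200

For a quick manual check outside pytest, the CUDA-vs-Triton comparison exercised by the tests can be reproduced in a few lines. The sketch below is illustrative and not part of the patch: it assumes a CUDA-capable GPU, reuses per_token_group_quant_fp8, the unittest.mock.patch trick for forcing the Triton fallback, and the tolerances from the unit test; the shape (32, 128) and group size 64 are arbitrary picks from the test parametrization.

    # Minimal sketch (not part of the patch): compare the CUDA and Triton
    # paths of per_token_group_quant_fp8 on one tensor, using the same
    # tolerances as the unit test. Assumes a CUDA device is available.
    from unittest.mock import patch

    import torch

    from vllm.model_executor.layers.quantization.utils import fp8_utils

    x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16) * 8

    # CUDA kernel path.
    q_cuda, s_cuda = fp8_utils.per_token_group_quant_fp8(
        x, 64, column_major_scales=False, use_ue8m0=False)

    # Force the Triton fallback by pretending the platform is not CUDA.
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        q_triton, s_triton = fp8_utils.per_token_group_quant_fp8(
            x, 64, column_major_scales=False, use_ue8m0=False)

    # Quantized values and scales should agree within the test tolerances.
    assert torch.allclose(q_cuda.float(), q_triton.float(), atol=0.15, rtol=0.15)
    assert torch.allclose(s_cuda, s_triton, atol=0.01, rtol=0.01)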