[Test] Add Benchmark and Unit Test for per_token_group_quant (#21860)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent e91d3c9cda
commit 0271c2ff2f
benchmarks/kernels/benchmark_per_token_group_quant.py | 159 (new file)
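For context, per-token group quantization splits each token's hidden dimension into contiguous groups of group_size elements and computes one scale per group from that group's maximum absolute value, so a single outlier only degrades precision within its own group. Below is a minimal sketch of the operation, assuming the fp8 e4m3 format (max value 448.0) and that group_size divides hidden_dim; per_token_group_quant_ref is an illustrative name, not vLLM's API, and the benchmarked kernel's column_major_scales / use_ue8m0 variations (the latter, as I understand it, constrains scales to power-of-two UE8M0 values) are not modeled here:

import torch

# Reference sketch only: quantize each (token, group) block with its own
# scale, mirroring the (quantized tensor, scales) pair the kernels return.
def per_token_group_quant_ref(
    x: torch.Tensor, group_size: int, fp8_max: float = 448.0
) -> tuple[torch.Tensor, torch.Tensor]:
    num_tokens, hidden_dim = x.shape
    assert hidden_dim % group_size == 0
    groups = x.view(num_tokens, hidden_dim // group_size, group_size)
    # One scale per (token, group): map the group's max-abs onto fp8_max.
    scales = groups.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    scales = scales.clamp(min=torch.finfo(torch.float32).tiny)
    q = (groups / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(num_tokens, hidden_dim), scales.squeeze(-1)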
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import math
from contextlib import contextmanager
from typing import Callable
from unittest.mock import patch

import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
from vllm.platforms import current_platform


@contextmanager
def _triton_mode():
    """Temporarily force the Triton fallback path"""
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        yield


def _time_cuda(
    fn: Callable[[], tuple[torch.Tensor, torch.Tensor]],
    warmup_iters: int,
    bench_iters: int,
) -> float:
    # warmup
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(bench_iters):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / bench_iters  # ms/iter


def _run_single(
    shape: tuple[int, int],
    group_size: int,
    dtype: str,
    *,
    column_major: bool = False,
    scale_ue8m0: bool = False,
    warmup_iters: int,
    bench_iters: int,
) -> None:
    num_tokens, hidden_dim = shape

    device = torch.device("cuda")
    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8

    if dtype == "fp8":

        def cuda_impl():
            return fp8_utils.per_token_group_quant_fp8(
                x,
                group_size,
                column_major_scales=column_major,
                use_ue8m0=scale_ue8m0,
            )

        def triton_impl():
            with _triton_mode():
                return fp8_utils.per_token_group_quant_fp8(
                    x,
                    group_size,
                    column_major_scales=column_major,
                    use_ue8m0=scale_ue8m0,
                )
    elif dtype == "int8":

        def cuda_impl():
            return int8_utils.per_token_group_quant_int8(x, group_size)

        def triton_impl():
            with _triton_mode():
                return int8_utils.per_token_group_quant_int8(x, group_size)
    else:
        raise ValueError("dtype must be 'fp8' or 'int8'")

    cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
    triton_ms = _time_cuda(triton_impl, warmup_iters, bench_iters)

    speedup = triton_ms / cuda_ms if cuda_ms else math.inf

    cfg_desc = (
        f"shape={shape} gs={group_size:<3} col_major={column_major:<5} "
        f"ue8m0={scale_ue8m0:<5} dtype={dtype}"
    )
    print(
        f"{cfg_desc:55} | CUDA {cuda_ms:7.3f} ms | Triton {triton_ms:7.3f} ms | "
        f"speed-up ×{speedup:5.2f}"
    )


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup-iters", type=int, default=10)
    parser.add_argument("--bench-iters", type=int, default=100)
    parser.add_argument("--dtype", choices=["fp8", "int8", "both"], default="both")
    return parser.parse_args()


if __name__ == "__main__":
    if not current_platform.is_cuda():
        raise RuntimeError("CUDA device is required to run this benchmark.")

    args = parse_args()
    warmup_iters, bench_iters = args.warmup_iters, args.bench_iters

    shapes = [(32, 128), (64, 256), (16, 512)]
    group_sizes = [64, 128]

    dtypes = ["fp8", "int8"] if args.dtype == "both" else [args.dtype]

    header = (
        "Configuration".ljust(55)
        + " | "
        + "CUDA (ms)".center(12)
        + " | "
        + "Triton (ms)".center(13)
        + " | "
        + "Speed-up"
    )
    print(header)
    print("-" * len(header))

    for dtype in dtypes:
        for shape in shapes:
            for gs in group_sizes:
                if dtype == "fp8":
                    for col_major in (False, True):
                        for ue8m0 in (False, True):
                            _run_single(
                                shape,
                                gs,
                                dtype,
                                column_major=col_major,
                                scale_ue8m0=ue8m0,
                                warmup_iters=warmup_iters,
                                bench_iters=bench_iters,
                            )
                else:  # INT8 has no col-major / ue8m0 switches
                    _run_single(
                        shape,
                        gs,
                        dtype,
                        warmup_iters=warmup_iters,
                        bench_iters=bench_iters,
                    )
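With the flags defined in parse_args, the benchmark can be invoked as, for example:

    python benchmarks/kernels/benchmark_per_token_group_quant.py --dtype fp8 --warmup-iters 10 --bench-iters 100

printing one CUDA-vs-Triton timing row per (shape, group size, flag) configuration. The hunks below extend the existing fp8 per-token-group-quant test module with a matching int8 unit test: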
@@ -5,7 +5,7 @@ from unittest.mock import patch
 import pytest
 import torch
 
-from vllm.model_executor.layers.quantization.utils import fp8_utils
+from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
 
 
 @pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@@ -42,3 +42,32 @@ def test_per_token_group_quant_fp8(shape, column_major: bool,
 
     assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
     assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
+
+
+@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
+@pytest.mark.parametrize("group_size", [64, 128])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_per_token_group_quant_int8(shape, group_size: int):
+    device = "cuda"
+
+    torch.manual_seed(42)
+    num_tokens, hidden_dim = shape
+
+    x = (torch.randn(
+        (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
+
+    # cuda path
+    out_q, scale = int8_utils.per_token_group_quant_int8(
+        x,
+        group_size,
+    )
+
+    # triton ref
+    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
+        ref_q, ref_s = int8_utils.per_token_group_quant_int8(
+            x,
+            group_size,
+        )
+
+    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
+    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
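The new test can be selected by name, e.g. pytest -k test_per_token_group_quant_int8; per the skipif marker it requires a CUDA device, and it checks the CUDA kernel against the Triton reference using the same tolerances as the existing fp8 test.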