# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools

import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser

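# Sweep grid: batch sizes 1, 4, 16, 64; sequence lengths 64-512; 32 or 48 heads
# (32 configurations in total, enumerated for triton.testing.perf_report below).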
batch_size_range = [2**i for i in range(0, 8, 2)]
seq_len_range = [2**i for i in range(6, 10, 1)]
num_heads_range = [32, 48]
configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))


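# Build a perf_report-decorated benchmark that times rotary embedding across the
# "torch" (forward_native), "flashinfer", and "vllm" (forward_cuda) providers for a
# fixed head size, rotary dim, and NeoX-style flag.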
def get_benchmark(head_size, rotary_dim, is_neox_style, device):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len", "num_heads"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["torch", "flashinfer", "vllm"],
            line_names=["PyTorch", "FlashInfer", "vLLM"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
            args={},
        )
    )
    def benchmark(batch_size, seq_len, num_heads, provider):
        dtype = torch.bfloat16
        max_position = 8192
        base = 10000
        rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
        rope = rope.to(dtype=dtype, device=device)
        cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)

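        # Random token positions plus bfloat16 query/key activations of shape
        # (batch_size, seq_len, num_heads * head_size).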
        positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
        query = torch.randn(
            (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
        )
        key = torch.randn_like(query)

        quantiles = [0.5, 0.2, 0.8]

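        # Time the selected provider with triton.testing.do_bench; query/key are
        # cloned inside each lambda so every timed call starts from identical inputs.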
        if provider == "torch":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rope.forward_native(positions, query.clone(), key.clone()),
                quantiles=quantiles,
            )
        elif provider == "flashinfer":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: torch.ops.vllm.flashinfer_rotary_embedding(
                    positions,
                    query.clone(),
                    key.clone(),
                    head_size,
                    cos_sin_cache,
                    is_neox_style,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
                quantiles=quantiles,
            )

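        # do_bench reports milliseconds; scale to microseconds to match the "us" ylabel.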
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
parser.add_argument("--is-neox-style", type=bool, default=True)
|
|
parser.add_argument("--batch-size", type=int, default=16)
|
|
parser.add_argument("--seq-len", type=int, default=512)
|
|
parser.add_argument("--num-heads", type=int, default=8)
|
|
parser.add_argument(
|
|
"--head-size",
|
|
type=int,
|
|
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
|
default=128,
|
|
)
|
|
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
|
|
parser.add_argument(
|
|
"--dtype", type=str, choices=["bfloat16", "float"], default="float"
|
|
)
|
|
parser.add_argument("--seed", type=int, default=0)
|
|
parser.add_argument(
|
|
"--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
|
|
)
|
|
parser.add_argument("--save-path", type=str, default="./configs/rope/")
|
|
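    # Note: only --head-size, --rotary-dim, --is-neox-style, --device, and --save-path
    # are consumed below; --batch-size, --seq-len, --num-heads, --dtype, and --seed are
    # accepted but unused, since the sweep comes from `configs` and dtype is hard-coded.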
    args = parser.parse_args()

    # Get the benchmark function
    benchmark = get_benchmark(
        args.head_size, args.rotary_dim, args.is_neox_style, args.device
    )
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
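# Example invocation (the file name below is assumed, not taken from the source):
#   python benchmark_rope.py --head-size 128 --rotary-dim 32 --device cuda:0
# Results are printed as a table and saved under --save-path (./configs/rope/ by default).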