Fix rotary embedding benchmark script (#28323)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
parent f2d9ad0620
commit 57201a6a4c
@@ -1,97 +1,76 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from itertools import accumulate
+import itertools

-import nvtx
 import torch

-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
-from vllm.platforms import current_platform
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.triton_utils import triton
 from vllm.utils.argparse_utils import FlexibleArgumentParser

+batch_size_range = [2**i for i in range(0, 8, 2)]
+seq_len_range = [2**i for i in range(6, 10, 1)]
+num_heads_range = [32, 48]
+configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))

-def benchmark_rope_kernels_multi_lora(
-    is_neox_style: bool,
-    batch_size: int,
-    seq_len: int,
-    num_heads: int,
-    head_size: int,
-    rotary_dim: int | None,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-    max_position: int = 8192,
-    base: float = 10000,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    if rotary_dim is None:
-        rotary_dim = head_size
-    # simulating serving 4 LoRAs
-    scaling_factors = [1, 2, 4, 8]
-    # batched RoPE can take multiple scaling factors
-    batched_rope = get_rope(
-        head_size,
-        rotary_dim,
-        max_position,
-        base,
-        is_neox_style,
-        {"rope_type": "linear", "factor": tuple(scaling_factors)},
-    )
-    # non-batched RoPE takes only one scaling factor, so we create multiple
-    # instances to simulate the same behavior
-    non_batched_ropes: list[RotaryEmbedding] = []
-    for scaling_factor in scaling_factors:
-        non_batched_ropes.append(
-            get_rope(
-                head_size,
-                rotary_dim,
-                max_position,
-                base,
-                is_neox_style,
-                {"rope_type": "linear", "factor": (scaling_factor,)},
-            )
-        )
-
-    positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
-    key = torch.randn_like(query)
-
-    # create query offsets for batched RoPE; we concat multiple kv caches
-    # together and each query needs to find the right kv cache of its type
-    offset_map = torch.tensor(
-        list(
-            accumulate(
-                [0]
-                + [
-                    max_position * scaling_factor * 2
-                    for scaling_factor in scaling_factors[:-1]
-                ]
-            )
+def get_benchmark(head_size, rotary_dim, is_neox_style, device):
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["batch_size", "seq_len", "num_heads"],
+            x_vals=[list(_) for _ in configs],
+            line_arg="provider",
+            line_vals=["torch", "flashinfer", "vllm"],
+            line_names=["PyTorch", "FlashInfer", "vLLM"],
+            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
+            ylabel="us",
+            plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
+            args={},
+        )
     )
-    query_types = torch.randint(
-        0, len(scaling_factors), (batch_size, seq_len), device=device
-    )
-    # map query types to offsets
-    query_offsets = offset_map[query_types]
-    # the kernel takes flattened offsets
-    flatten_offsets = query_offsets.flatten()
+    def benchmark(batch_size, seq_len, num_heads, provider):
+        dtype = torch.bfloat16
+        max_position = 8192
+        base = 10000
+        rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+        rope = rope.to(dtype=dtype, device=device)
+        cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)

-    # batch queries of the same type together for non-batched RoPE
-    queries = [query[query_types == i] for i in range(len(scaling_factors))]
-    keys = [key[query_types == i] for i in range(len(scaling_factors))]
-    packed_qkr = zip(queries, keys, non_batched_ropes)
-    # synchronize before start timing
-    torch.cuda.synchronize()
-    with nvtx.annotate("non-batched", color="yellow"):
-        for q, k, r in packed_qkr:
-            r.forward(positions, q, k)
-    torch.cuda.synchronize()
-    with nvtx.annotate("batched", color="green"):
-        batched_rope.forward(positions, query, key, flatten_offsets)
-    torch.cuda.synchronize()
+        positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
+        query = torch.randn(
+            (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
+        )
+        key = torch.randn_like(query)
+
+        quantiles = [0.5, 0.2, 0.8]
+
+        if provider == "torch":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rope.forward_native(positions, query.clone(), key.clone()),
+                quantiles=quantiles,
+            )
+        elif provider == "flashinfer":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: torch.ops.vllm.flashinfer_rotary_embedding(
+                    positions,
+                    query.clone(),
+                    key.clone(),
+                    head_size,
+                    cos_sin_cache,
+                    is_neox_style,
+                ),
+                quantiles=quantiles,
+            )
+        else:
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
+                quantiles=quantiles,
+            )
+
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+    return benchmark


 if __name__ == "__main__":
@@ -116,17 +95,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
     )
     parser.add_argument("--save-path", type=str, default="./configs/rope/")
     args = parser.parse_args()
     print(args)

-    benchmark_rope_kernels_multi_lora(
-        is_neox_style=args.is_neox_style,
-        batch_size=args.batch_size,
-        seq_len=args.seq_len,
-        num_heads=args.num_heads,
-        head_size=args.head_size,
-        rotary_dim=args.rotary_dim,
-        dtype=getattr(torch, args.dtype),
-        seed=args.seed,
-        device=args.device,
+    # Get the benchmark function
+    benchmark = get_benchmark(
+        args.head_size, args.rotary_dim, args.is_neox_style, args.device
     )
+    # Run performance benchmark
+    benchmark.run(print_data=True, save_path=args.save_path)
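For context, the updated script follows Triton's standard benchmarking pattern: triton.testing.perf_report sweeps the configurations listed in x_vals, and triton.testing.do_bench times each provider, returning timings in milliseconds at the requested quantiles, which the script scales to microseconds to match ylabel="us". Below is a minimal standalone sketch of that pattern. It is not part of this commit: the bench function, the matmul workload, and the swept sizes are illustrative only, and it assumes a CUDA device with Triton installed.

import torch
import triton


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["n"],  # swept problem size (illustrative)
        x_vals=[2**i for i in range(10, 13)],
        line_arg="provider",
        line_vals=["torch"],
        line_names=["PyTorch"],
        styles=[("blue", "-")],
        ylabel="us",
        plot_name="example-perf",
        args={},
    )
)
def bench(n, provider):
    # Illustrative workload; the real script times RoPE forward passes instead.
    x = torch.randn(n, n, device="cuda")
    quantiles = [0.5, 0.2, 0.8]
    # do_bench returns timings in ms at the requested quantiles
    # (median, 20th, and 80th percentile here).
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: x @ x, quantiles=quantiles)
    # Report microseconds to match ylabel="us".
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    bench.run(print_data=True)

As in the commit, run(print_data=True, save_path=...) prints the timing table and, when a save path is given, writes the plot and CSV output for each plot_name.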