Add PrefixRepetitionRandomDataset to vllm bench serve datasets (#20638)

Signed-off-by: Seiji Eicher <seiji@anyscale.com>
Seiji Eicher 2025-08-15 14:09:23 -07:00 committed by GitHub
parent 7f89ed248f
commit 00d6cba0cf


@@ -26,6 +26,7 @@ from typing import Any, Callable, Optional, Union
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizerBase
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
@@ -486,7 +487,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"--dataset-name",
type=str,
default="random",
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
choices=[
"sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
"prefix_repetition"
],
help="Name of the dataset to benchmark on.",
)
parser.add_argument(
@@ -603,6 +607,37 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"from the sampled HF dataset.",
)
prefix_repetition_group = parser.add_argument_group(
"prefix repetition dataset options")
prefix_repetition_group.add_argument(
"--prefix-repetition-prefix-len",
type=int,
default=256,
help="Number of prefix tokens per request, used only for prefix "
"repetition dataset.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-suffix-len",
type=int,
default=256,
help="Number of suffix tokens per request, used only for prefix "
"repetition dataset. Total input length is prefix_len + suffix_len.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-num-prefixes",
type=int,
default=10,
help="Number of prefixes to generate, used only for prefix repetition "
"dataset. Prompts per prefix is num_requests // num_prefixes.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-output-len",
type=int,
default=128,
help="Number of output tokens per request, used only for prefix "
"repetition dataset.",
)
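
For context, a minimal sketch of how these flags drive sampling. The FlexibleArgumentParser import path and the gpt2 tokenizer are assumptions for illustration, and the remaining flags (--num-prompts, --seed) are expected to come from the same parser helper:

    from transformers import AutoTokenizer
    from vllm.utils import FlexibleArgumentParser  # assumed import path

    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)
    args = parser.parse_args([
        "--dataset-name", "prefix_repetition",
        "--prefix-repetition-prefix-len", "512",
        "--prefix-repetition-num-prefixes", "8",
        "--num-prompts", "64",
    ])
    # get_samples dispatches on args.dataset_name (see below).
    samples = get_samples(args, AutoTokenizer.from_pretrained("gpt2"))
    print(len(samples), samples[0].prompt_len)  # 64 requests, 512 + 256 tokens each
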
def get_samples(args, tokenizer) -> list[SampleRequest]:
if args.dataset_name == "custom":
@@ -721,6 +756,17 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
),
"prefix_repetition":
lambda: PrefixRepetitionRandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.prefix_repetition_prefix_len,
suffix_len=args.prefix_repetition_suffix_len,
num_prefixes=args.prefix_repetition_num_prefixes,
output_len=args.prefix_repetition_output_len,
),
}
try:
@@ -828,7 +874,9 @@ class CustomDataset(BenchmarkDataset):
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
@deprecated(
"SonnetDataset is deprecated and will be removed in a future version.",
)
class SonnetDataset(BenchmarkDataset):
"""
Simplified implementation of the Sonnet dataset. Loads poem lines from a
@@ -1537,3 +1585,84 @@ class MLPerfDataset(HuggingFaceDataset):
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# -----------------------------------------------------------------------------
# Prefix Repetition Dataset Implementation
# -----------------------------------------------------------------------------
class PrefixRepetitionRandomDataset(BenchmarkDataset):
# Default values copied from benchmark_serving.py for the repeated prefix
# dataset.
DEFAULT_PREFIX_LEN = 256
DEFAULT_SUFFIX_LEN = 256
DEFAULT_NUM_PREFIXES = 10
DEFAULT_OUTPUT_LEN = 128
def __init__(
self,
**kwargs,
) -> None:
super().__init__(**kwargs)
random.seed(self.random_seed)
np.random.seed(self.random_seed)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
prefix_len: int = DEFAULT_PREFIX_LEN,
suffix_len: int = DEFAULT_SUFFIX_LEN,
num_prefixes: int = DEFAULT_NUM_PREFIXES,
output_len: int = DEFAULT_OUTPUT_LEN,
**kwargs,
) -> list[SampleRequest]:
vocab_size = tokenizer.vocab_size
prompts_per_prefix = num_requests // num_prefixes
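# Note: integer division drops any remainder, so e.g. num_requests=64
# with num_prefixes=10 produces 6 prompts per prefix (60 requests total).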
if prompts_per_prefix == 0:
raise ValueError(
f"num_requests ({num_requests}) must be greater than or equal "
f"to num_prefixes ({num_prefixes})"
)
def _generate_exact_length_tokens(target_length: int) -> list[int]:
"""Generate tokens that decode and re-encode to exactly
target_length."""
# Generate random tokens
tokens = np.random.randint(
0, vocab_size, size=target_length).tolist()
text = tokenizer.decode(tokens)
re_encoded = tokenizer.encode(text, add_special_tokens=False)
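# Decoding random ids and re-encoding the text rarely round-trips
# exactly (BPE merges can fuse adjacent tokens), so correct the length.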
if len(re_encoded) == target_length:
return re_encoded
elif len(re_encoded) < target_length:
# Recursively generate additional consistent tokens
needed = target_length - len(re_encoded)
extra_tokens = _generate_exact_length_tokens(needed)
return re_encoded + extra_tokens
else:
# Truncate to target length
return re_encoded[:target_length]
requests = []
for _ in range(num_prefixes):
prefix_tokens = _generate_exact_length_tokens(prefix_len)
for _ in range(prompts_per_prefix):
suffix_tokens = _generate_exact_length_tokens(suffix_len)
combined_tokens = prefix_tokens + suffix_tokens
prompt = tokenizer.decode(combined_tokens)
prompt_len = len(combined_tokens)
requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
)
)
random.shuffle(requests)
return requests
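
As a sanity check, the class can also be driven directly, outside the CLI. A minimal sketch, assuming an arbitrary gpt2 tokenizer (not part of this commit):

    from transformers import AutoTokenizer

    dataset = PrefixRepetitionRandomDataset(random_seed=0)
    requests = dataset.sample(
        tokenizer=AutoTokenizer.from_pretrained("gpt2"),
        num_requests=20,
        num_prefixes=4,  # -> 5 prompts per prefix
        prefix_len=128,
        suffix_len=64,
        output_len=32,
    )
    assert len(requests) == 20
    assert all(r.expected_output_len == 32 for r in requests)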