Add PrefixRepetitionRandomDataset to vllm bench serve datasets (#20638)

Signed-off-by: Seiji Eicher <seiji@anyscale.com>
This commit is contained in:
Seiji Eicher 2025-08-15 14:09:23 -07:00 committed by GitHub
parent 7f89ed248f
commit 00d6cba0cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,6 +26,7 @@ from typing import Any, Callable, Optional, Union
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizerBase
from typing_extensions import deprecated

from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
@ -486,7 +487,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"--dataset-name",
type=str,
default="random",
choices=[
    "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
    "prefix_repetition"
],
help="Name of the dataset to benchmark on.",
)
parser.add_argument(
@ -603,6 +607,37 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"from the sampled HF dataset.", "from the sampled HF dataset.",
) )
prefix_repetition_group = parser.add_argument_group(
"prefix repetition dataset options")
prefix_repetition_group.add_argument(
"--prefix-repetition-prefix-len",
type=int,
default=256,
help="Number of prefix tokens per request, used only for prefix "
"repetition dataset.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-suffix-len",
type=int,
default=256,
help="Number of suffix tokens per request, used only for prefix "
"repetition dataset. Total input length is prefix_len + suffix_len.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-num-prefixes",
type=int,
default=10,
help="Number of prefixes to generate, used only for prefix repetition "
"dataset. Prompts per prefix is num_requests // num_prefixes.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-output-len",
type=int,
default=128,
help="Number of output tokens per request, used only for prefix "
"repetition dataset.",
)
def get_samples(args, tokenizer) -> list[SampleRequest]:
    if args.dataset_name == "custom":
@ -721,6 +756,17 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
output_len=args.random_output_len, output_len=args.random_output_len,
range_ratio=args.random_range_ratio, range_ratio=args.random_range_ratio,
), ),
"prefix_repetition":
lambda: PrefixRepetitionRandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.prefix_repetition_prefix_len,
suffix_len=args.prefix_repetition_suffix_len,
num_prefixes=args.prefix_repetition_num_prefixes,
output_len=args.prefix_repetition_output_len,
),
} }
try: try:
@ -828,7 +874,9 @@ class CustomDataset(BenchmarkDataset):
# Sonnet Dataset Implementation # Sonnet Dataset Implementation
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@deprecated(
"SonnetDataset is deprecated and will be removed in a future version.",
)
class SonnetDataset(BenchmarkDataset): class SonnetDataset(BenchmarkDataset):
""" """
Simplified implementation of the Sonnet dataset. Loads poem lines from a Simplified implementation of the Sonnet dataset. Loads poem lines from a
@ -1537,3 +1585,84 @@ class MLPerfDataset(HuggingFaceDataset):
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
# -----------------------------------------------------------------------------
# Prefix Repetition Dataset Implementation
# -----------------------------------------------------------------------------


class PrefixRepetitionRandomDataset(BenchmarkDataset):
    """Synthetic benchmark dataset of prompts sharing random token prefixes.

    Generates ``num_prefixes`` random prefixes and, for each, a batch of
    prompts that append fresh random suffixes. Each prompt therefore shares
    its leading tokens with the other prompts of the same prefix group —
    useful for benchmarking serving behavior under repeated prefixes
    (e.g. prefix caching).
    """

    # Default values copied from benchmark_serving.py for the repeated prefix
    # dataset.
    DEFAULT_PREFIX_LEN = 256
    DEFAULT_SUFFIX_LEN = 256
    DEFAULT_NUM_PREFIXES = 10
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # NOTE(review): this seeds the *global* `random` and `numpy` RNGs
        # (not instance-local generators), so constructing the dataset
        # affects other users of those modules — presumably intentional,
        # to make benchmark sampling reproducible for a given --seed.
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        suffix_len: int = DEFAULT_SUFFIX_LEN,
        num_prefixes: int = DEFAULT_NUM_PREFIXES,
        output_len: int = DEFAULT_OUTPUT_LEN,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build and return a shuffled list of sampled requests.

        Args:
            tokenizer: Tokenizer used to decode/re-encode random token ids.
            num_requests: Total prompts to generate; must be at least
                ``num_prefixes``. The per-prefix count is the integer
                quotient, so any remainder of ``num_requests`` is dropped.
            prefix_len: Target token length of each shared prefix.
            suffix_len: Target token length of each per-request suffix.
            num_prefixes: Number of distinct shared prefixes to generate.
            output_len: Expected output length recorded on each request.

        Raises:
            ValueError: If ``num_requests // num_prefixes`` is zero.
        """
        vocab_size = tokenizer.vocab_size
        prompts_per_prefix = num_requests // num_prefixes
        if prompts_per_prefix == 0:
            raise ValueError(
                f"num_requests ({num_requests}) must be greater than or equal "
                f"to num_prefixes ({num_prefixes})"
            )

        def _generate_exact_length_tokens(target_length: int) -> list[int]:
            """Generate tokens that decode and re-encode to exactly
            target_length."""
            # Generate random tokens
            tokens = np.random.randint(
                0, vocab_size, size=target_length).tolist()
            # Round-trip through text because decode(encode(x)) is not the
            # identity for arbitrary random ids; the re-encoded ids are the
            # ones a real request would contain.
            text = tokenizer.decode(tokens)
            re_encoded = tokenizer.encode(text, add_special_tokens=False)
            if len(re_encoded) == target_length:
                return re_encoded
            elif len(re_encoded) < target_length:
                # Recursively generate additional consistent tokens
                needed = target_length - len(re_encoded)
                extra_tokens = _generate_exact_length_tokens(needed)
                return re_encoded + extra_tokens
            else:
                # Truncate to target length
                return re_encoded[:target_length]

        requests = []
        for _ in range(num_prefixes):
            # One shared prefix per outer iteration; every prompt in the
            # inner loop reuses it with a fresh random suffix.
            prefix_tokens = _generate_exact_length_tokens(prefix_len)
            for _ in range(prompts_per_prefix):
                suffix_tokens = _generate_exact_length_tokens(suffix_len)
                combined_tokens = prefix_tokens + suffix_tokens
                prompt = tokenizer.decode(combined_tokens)
                prompt_len = len(combined_tokens)
                requests.append(
                    SampleRequest(
                        prompt=prompt,
                        prompt_len=prompt_len,
                        expected_output_len=output_len,
                    )
                )
        # Shuffle so same-prefix requests are not issued back-to-back.
        random.shuffle(requests)
        return requests