diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 5299dcf54b39..72d7ce49b8e1 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -26,6 +26,7 @@ from typing import Any, Callable, Optional, Union import numpy as np from PIL import Image from transformers import PreTrainedTokenizerBase +from typing_extensions import deprecated from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path @@ -486,7 +487,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--dataset-name", type=str, default="random", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], + choices=[ + "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom", + "prefix_repetition" + ], help="Name of the dataset to benchmark on.", ) parser.add_argument( @@ -603,6 +607,37 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "from the sampled HF dataset.", ) + prefix_repetition_group = parser.add_argument_group( + "prefix repetition dataset options") + prefix_repetition_group.add_argument( + "--prefix-repetition-prefix-len", + type=int, + default=256, + help="Number of prefix tokens per request, used only for prefix " + "repetition dataset.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-suffix-len", + type=int, + default=256, + help="Number of suffix tokens per request, used only for prefix " + "repetition dataset. Total input length is prefix_len + suffix_len.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-num-prefixes", + type=int, + default=10, + help="Number of prefixes to generate, used only for prefix repetition " + "dataset. Prompts per prefix is num_requests // num_prefixes.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for prefix " + "repetition dataset.", + ) + def get_samples(args, tokenizer) -> list[SampleRequest]: if args.dataset_name == "custom": @@ -721,6 +756,17 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: output_len=args.random_output_len, range_ratio=args.random_range_ratio, ), + "prefix_repetition": + lambda: PrefixRepetitionRandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.prefix_repetition_prefix_len, + suffix_len=args.prefix_repetition_suffix_len, + num_prefixes=args.prefix_repetition_num_prefixes, + output_len=args.prefix_repetition_output_len, + ), } try: @@ -828,7 +874,9 @@ class CustomDataset(BenchmarkDataset): # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- - +@deprecated( + "SonnetDataset is deprecated and will be removed in a future version.", +) class SonnetDataset(BenchmarkDataset): """ Simplified implementation of the Sonnet dataset. Loads poem lines from a @@ -1537,3 +1585,84 @@ class MLPerfDataset(HuggingFaceDataset): self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests + + +# ----------------------------------------------------------------------------- +# Prefix Repetition Dataset Implementation +# ----------------------------------------------------------------------------- + + +class PrefixRepetitionRandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the repeated prefix + # dataset. + DEFAULT_PREFIX_LEN = 256 + DEFAULT_SUFFIX_LEN = 256 + DEFAULT_NUM_PREFIXES = 10 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + random.seed(self.random_seed) + np.random.seed(self.random_seed) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + suffix_len: int = DEFAULT_SUFFIX_LEN, + num_prefixes: int = DEFAULT_NUM_PREFIXES, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs, + ) -> list[SampleRequest]: + vocab_size = tokenizer.vocab_size + prompts_per_prefix = num_requests // num_prefixes + if prompts_per_prefix == 0: + raise ValueError( + f"num_requests ({num_requests}) must be greater than or equal " + f"to num_prefixes ({num_prefixes})" + ) + + def _generate_exact_length_tokens(target_length: int) -> list[int]: + """Generate tokens that decode and re-encode to exactly + target_length.""" + # Generate random tokens + tokens = np.random.randint( + 0, vocab_size, size=target_length).tolist() + text = tokenizer.decode(tokens) + re_encoded = tokenizer.encode(text, add_special_tokens=False) + + if len(re_encoded) == target_length: + return re_encoded + elif len(re_encoded) < target_length: + # Recursively generate additional consistent tokens + needed = target_length - len(re_encoded) + extra_tokens = _generate_exact_length_tokens(needed) + return re_encoded + extra_tokens + else: + # Truncate to target length + return re_encoded[:target_length] + + requests = [] + for _ in range(num_prefixes): + prefix_tokens = _generate_exact_length_tokens(prefix_len) + + for _ in range(prompts_per_prefix): + suffix_tokens = _generate_exact_length_tokens(suffix_len) + + combined_tokens = prefix_tokens + suffix_tokens + prompt = tokenizer.decode(combined_tokens) + prompt_len = len(combined_tokens) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + ) + ) + + random.shuffle(requests) + return requests