mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 10:06:03 +08:00
Add PrefixRepetitionRandomDataset to vllm bench serve datasets (#20638)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
This commit is contained in:
parent
7f89ed248f
commit
00d6cba0cf
@ -26,6 +26,7 @@ from typing import Any, Callable, Optional, Union
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
@ -486,7 +487,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="random",
|
||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
|
||||
choices=[
|
||||
"sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
|
||||
"prefix_repetition"
|
||||
],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -603,6 +607,37 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
||||
"from the sampled HF dataset.",
|
||||
)
|
||||
|
||||
prefix_repetition_group = parser.add_argument_group(
|
||||
"prefix repetition dataset options")
|
||||
prefix_repetition_group.add_argument(
|
||||
"--prefix-repetition-prefix-len",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of prefix tokens per request, used only for prefix "
|
||||
"repetition dataset.",
|
||||
)
|
||||
prefix_repetition_group.add_argument(
|
||||
"--prefix-repetition-suffix-len",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of suffix tokens per request, used only for prefix "
|
||||
"repetition dataset. Total input length is prefix_len + suffix_len.",
|
||||
)
|
||||
prefix_repetition_group.add_argument(
|
||||
"--prefix-repetition-num-prefixes",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of prefixes to generate, used only for prefix repetition "
|
||||
"dataset. Prompts per prefix is num_requests // num_prefixes.",
|
||||
)
|
||||
prefix_repetition_group.add_argument(
|
||||
"--prefix-repetition-output-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Number of output tokens per request, used only for prefix "
|
||||
"repetition dataset.",
|
||||
)
|
||||
|
||||
|
||||
def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
if args.dataset_name == "custom":
|
||||
@ -721,6 +756,17 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
output_len=args.random_output_len,
|
||||
range_ratio=args.random_range_ratio,
|
||||
),
|
||||
"prefix_repetition":
|
||||
lambda: PrefixRepetitionRandomDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
prefix_len=args.prefix_repetition_prefix_len,
|
||||
suffix_len=args.prefix_repetition_suffix_len,
|
||||
num_prefixes=args.prefix_repetition_num_prefixes,
|
||||
output_len=args.prefix_repetition_output_len,
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
@ -828,7 +874,9 @@ class CustomDataset(BenchmarkDataset):
|
||||
# Sonnet Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@deprecated(
|
||||
"SonnetDataset is deprecated and will be removed in a future version.",
|
||||
)
|
||||
class SonnetDataset(BenchmarkDataset):
|
||||
"""
|
||||
Simplified implementation of the Sonnet dataset. Loads poem lines from a
|
||||
@ -1537,3 +1585,84 @@ class MLPerfDataset(HuggingFaceDataset):
|
||||
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Prefix Repetition Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PrefixRepetitionRandomDataset(BenchmarkDataset):
|
||||
# Default values copied from benchmark_serving.py for the repeated prefix
|
||||
# dataset.
|
||||
DEFAULT_PREFIX_LEN = 256
|
||||
DEFAULT_SUFFIX_LEN = 256
|
||||
DEFAULT_NUM_PREFIXES = 10
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
random.seed(self.random_seed)
|
||||
np.random.seed(self.random_seed)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
prefix_len: int = DEFAULT_PREFIX_LEN,
|
||||
suffix_len: int = DEFAULT_SUFFIX_LEN,
|
||||
num_prefixes: int = DEFAULT_NUM_PREFIXES,
|
||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
vocab_size = tokenizer.vocab_size
|
||||
prompts_per_prefix = num_requests // num_prefixes
|
||||
if prompts_per_prefix == 0:
|
||||
raise ValueError(
|
||||
f"num_requests ({num_requests}) must be greater than or equal "
|
||||
f"to num_prefixes ({num_prefixes})"
|
||||
)
|
||||
|
||||
def _generate_exact_length_tokens(target_length: int) -> list[int]:
|
||||
"""Generate tokens that decode and re-encode to exactly
|
||||
target_length."""
|
||||
# Generate random tokens
|
||||
tokens = np.random.randint(
|
||||
0, vocab_size, size=target_length).tolist()
|
||||
text = tokenizer.decode(tokens)
|
||||
re_encoded = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
if len(re_encoded) == target_length:
|
||||
return re_encoded
|
||||
elif len(re_encoded) < target_length:
|
||||
# Recursively generate additional consistent tokens
|
||||
needed = target_length - len(re_encoded)
|
||||
extra_tokens = _generate_exact_length_tokens(needed)
|
||||
return re_encoded + extra_tokens
|
||||
else:
|
||||
# Truncate to target length
|
||||
return re_encoded[:target_length]
|
||||
|
||||
requests = []
|
||||
for _ in range(num_prefixes):
|
||||
prefix_tokens = _generate_exact_length_tokens(prefix_len)
|
||||
|
||||
for _ in range(prompts_per_prefix):
|
||||
suffix_tokens = _generate_exact_length_tokens(suffix_len)
|
||||
|
||||
combined_tokens = prefix_tokens + suffix_tokens
|
||||
prompt = tokenizer.decode(combined_tokens)
|
||||
prompt_len = len(combined_tokens)
|
||||
requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
)
|
||||
)
|
||||
|
||||
random.shuffle(requests)
|
||||
return requests
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user