Add PrefixRepetitionRandomDataset to vllm bench serve datasets (#20638)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
parent 7f89ed248f
commit 00d6cba0cf
@@ -26,6 +26,7 @@ from typing import Any, Callable, Optional, Union
 import numpy as np
 from PIL import Image
 from transformers import PreTrainedTokenizerBase
+from typing_extensions import deprecated

 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
@@ -486,7 +487,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         "--dataset-name",
         type=str,
         default="random",
-        choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
+        choices=[
+            "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
+            "prefix_repetition"
+        ],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument(
@@ -603,6 +607,37 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         "from the sampled HF dataset.",
     )
+
+    prefix_repetition_group = parser.add_argument_group(
+        "prefix repetition dataset options")
+    prefix_repetition_group.add_argument(
+        "--prefix-repetition-prefix-len",
+        type=int,
+        default=256,
+        help="Number of prefix tokens per request, used only for prefix "
+        "repetition dataset.",
+    )
+    prefix_repetition_group.add_argument(
+        "--prefix-repetition-suffix-len",
+        type=int,
+        default=256,
+        help="Number of suffix tokens per request, used only for prefix "
+        "repetition dataset. Total input length is prefix_len + suffix_len.",
+    )
+    prefix_repetition_group.add_argument(
+        "--prefix-repetition-num-prefixes",
+        type=int,
+        default=10,
+        help="Number of prefixes to generate, used only for prefix repetition "
+        "dataset. Prompts per prefix is num_requests // num_prefixes.",
+    )
+    prefix_repetition_group.add_argument(
+        "--prefix-repetition-output-len",
+        type=int,
+        default=128,
+        help="Number of output tokens per request, used only for prefix "
+        "repetition dataset.",
+    )


 def get_samples(args, tokenizer) -> list[SampleRequest]:
     if args.dataset_name == "custom":
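For reference, a minimal sketch of how the new flags parse. The import paths are assumptions (add_dataset_parser lives in this file; FlexibleArgumentParser is assumed to come from vllm.utils), and all other dataset options are left at their defaults:

    from vllm.utils import FlexibleArgumentParser  # assumed import path
    from vllm.benchmarks.datasets import add_dataset_parser  # assumed import path

    parser = FlexibleArgumentParser()
    add_dataset_parser(parser)
    # The four new options all have defaults, so only overrides need passing.
    args = parser.parse_args([
        "--dataset-name", "prefix_repetition",
        "--prefix-repetition-prefix-len", "512",
        "--prefix-repetition-num-prefixes", "8",
    ])
    assert args.prefix_repetition_suffix_len == 256  # default from the group above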
@@ -721,6 +756,17 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             output_len=args.random_output_len,
             range_ratio=args.random_range_ratio,
         ),
+        "prefix_repetition":
+        lambda: PrefixRepetitionRandomDataset(
+            random_seed=args.seed, dataset_path=args.dataset_path
+        ).sample(
+            tokenizer=tokenizer,
+            num_requests=args.num_prompts,
+            prefix_len=args.prefix_repetition_prefix_len,
+            suffix_len=args.prefix_repetition_suffix_len,
+            num_prefixes=args.prefix_repetition_num_prefixes,
+            output_len=args.prefix_repetition_output_len,
+        ),
     }

     try:
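The new entry follows the existing convention in get_samples: each dataset name maps to a zero-argument lambda, so construction and sampling run only for the dataset actually selected via --dataset-name. A standalone sketch of the pattern, with illustrative names not taken from the source:

    def build_and_sample(name: str) -> list[str]:
        # Stand-in for dataset construction plus .sample(); imagine this is costly.
        return [f"sample from {name}"]

    dataset_factories = {
        "random": lambda: build_and_sample("random"),
        "prefix_repetition": lambda: build_and_sample("prefix_repetition"),
    }
    # Only the looked-up lambda is invoked; the other datasets are never built.
    samples = dataset_factories["prefix_repetition"]()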
@@ -828,7 +874,9 @@ class CustomDataset(BenchmarkDataset):
 # Sonnet Dataset Implementation
 # -----------------------------------------------------------------------------

+@deprecated(
+    "SonnetDataset is deprecated and will be removed in a future version.",
+)
 class SonnetDataset(BenchmarkDataset):
     """
     Simplified implementation of the Sonnet dataset. Loads poem lines from a
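typing_extensions.deprecated (imported in the first hunk) marks the class for static type checkers and also emits a DeprecationWarning when the decorated class is instantiated at runtime. A quick illustrative example, not from the source:

    from typing_extensions import deprecated

    @deprecated("OldDataset is deprecated and will be removed in a future version.")
    class OldDataset:
        pass

    OldDataset()  # no error, but emits DeprecationWarning at runtime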
@@ -1537,3 +1585,84 @@ class MLPerfDataset(HuggingFaceDataset):

         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# Prefix Repetition Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class PrefixRepetitionRandomDataset(BenchmarkDataset):
+    # Default values copied from benchmark_serving.py for the repeated prefix
+    # dataset.
+    DEFAULT_PREFIX_LEN = 256
+    DEFAULT_SUFFIX_LEN = 256
+    DEFAULT_NUM_PREFIXES = 10
+    DEFAULT_OUTPUT_LEN = 128
+
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        random.seed(self.random_seed)
+        np.random.seed(self.random_seed)
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        prefix_len: int = DEFAULT_PREFIX_LEN,
+        suffix_len: int = DEFAULT_SUFFIX_LEN,
+        num_prefixes: int = DEFAULT_NUM_PREFIXES,
+        output_len: int = DEFAULT_OUTPUT_LEN,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        vocab_size = tokenizer.vocab_size
+        prompts_per_prefix = num_requests // num_prefixes
+        if prompts_per_prefix == 0:
+            raise ValueError(
+                f"num_requests ({num_requests}) must be greater than or equal "
+                f"to num_prefixes ({num_prefixes})"
+            )
+
+        def _generate_exact_length_tokens(target_length: int) -> list[int]:
+            """Generate tokens that decode and re-encode to exactly
+            target_length."""
+            # Generate random tokens
+            tokens = np.random.randint(
+                0, vocab_size, size=target_length).tolist()
+            text = tokenizer.decode(tokens)
+            re_encoded = tokenizer.encode(text, add_special_tokens=False)
+
+            if len(re_encoded) == target_length:
+                return re_encoded
+            elif len(re_encoded) < target_length:
+                # Recursively generate additional consistent tokens
+                needed = target_length - len(re_encoded)
+                extra_tokens = _generate_exact_length_tokens(needed)
+                return re_encoded + extra_tokens
+            else:
+                # Truncate to target length
+                return re_encoded[:target_length]
+
+        requests = []
+        for _ in range(num_prefixes):
+            prefix_tokens = _generate_exact_length_tokens(prefix_len)
+
+            for _ in range(prompts_per_prefix):
+                suffix_tokens = _generate_exact_length_tokens(suffix_len)
+
+                combined_tokens = prefix_tokens + suffix_tokens
+                prompt = tokenizer.decode(combined_tokens)
+                prompt_len = len(combined_tokens)
+                requests.append(
+                    SampleRequest(
+                        prompt=prompt,
+                        prompt_len=prompt_len,
+                        expected_output_len=output_len,
+                    )
+                )
+
+        random.shuffle(requests)
+        return requests
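To exercise the new dataset class directly, outside the CLI, something like the following should work. The tokenizer choice and numbers are illustrative; the constructor arguments mirror the lambda in get_samples above:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer
    dataset = PrefixRepetitionRandomDataset(random_seed=0, dataset_path=None)
    requests = dataset.sample(
        tokenizer=tokenizer,
        num_requests=20,
        prefix_len=64,
        suffix_len=64,
        num_prefixes=4,  # 20 // 4 = 5 prompts share each of the 4 prefixes
        output_len=32,
    )
    assert len(requests) == 20  # num_prefixes * (num_requests // num_prefixes)

Because prompts within a group share the same 64-token prefix and the final list is shuffled, this workload is well suited to stressing prefix caching: each prefix recurs across five interleaved requests.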