mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 03:55:42 +08:00
[Spec Decode][Benchmark] Generalize spec decode offline benchmark to more methods and datasets (#18847)
This commit is contained in:
parent
4b25ab14e2
commit
017ef648e9
@ -137,4 +137,8 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
print(
|
||||||
|
"[WARNING] Use examples/offline_inference/spec_decode.py"
|
||||||
|
" instead of this script."
|
||||||
|
)
|
||||||
main()
|
main()
|
||||||
|
|||||||
137
examples/offline_inference/spec_decode.py
Normal file
137
examples/offline_inference/spec_decode.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
|
||||||
|
from vllm.v1.metrics.reader import Counter, Vector
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
except ImportError:
|
||||||
|
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = FlexibleArgumentParser()
|
||||||
|
add_dataset_parser(parser)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
default="./examples/data/gsm8k.jsonl",
|
||||||
|
help="downloaded from the eagle repo "
|
||||||
|
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3"]
|
||||||
|
)
|
||||||
|
parser.add_argument("--max-num-seqs", type=int, default=8)
|
||||||
|
parser.add_argument("--num-spec-tokens", type=int, default=2)
|
||||||
|
parser.add_argument("--prompt-lookup-max", type=int, default=5)
|
||||||
|
parser.add_argument("--prompt-lookup-min", type=int, default=2)
|
||||||
|
parser.add_argument("--tp", type=int, default=1)
|
||||||
|
parser.add_argument("--draft-tp", type=int, default=1)
|
||||||
|
parser.add_argument("--enforce-eager", action="store_true")
|
||||||
|
parser.add_argument("--enable-chunked-prefill", action="store_true")
|
||||||
|
parser.add_argument("--max-num-batched-tokens", type=int, default=2048)
|
||||||
|
parser.add_argument("--temp", type=float, default=0)
|
||||||
|
parser.add_argument("--top-p", type=float, default=1.0)
|
||||||
|
parser.add_argument("--top-k", type=int, default=-1)
|
||||||
|
parser.add_argument("--print-output", action="store_true")
|
||||||
|
parser.add_argument("--output-len", type=int, default=256)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
args.endpoint_type = "openai-chat"
|
||||||
|
|
||||||
|
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
|
max_model_len = 2048
|
||||||
|
|
||||||
|
prompts = get_samples(args, tokenizer)
|
||||||
|
# add_special_tokens is False to avoid adding bos twice when using chat templates
|
||||||
|
prompt_ids = [
|
||||||
|
tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
|
||||||
|
]
|
||||||
|
|
||||||
|
if args.method == "eagle" or args.method == "eagle3":
|
||||||
|
if args.method == "eagle":
|
||||||
|
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||||
|
elif args.method == "eagle3":
|
||||||
|
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||||
|
speculative_config = {
|
||||||
|
"method": args.method,
|
||||||
|
"model": eagle_dir,
|
||||||
|
"num_speculative_tokens": args.num_spec_tokens,
|
||||||
|
"draft_tensor_parallel_size": args.draft_tp,
|
||||||
|
"max_model_len": max_model_len,
|
||||||
|
}
|
||||||
|
elif args.method == "ngram":
|
||||||
|
speculative_config = {
|
||||||
|
"method": "ngram",
|
||||||
|
"num_speculative_tokens": args.num_spec_tokens,
|
||||||
|
"prompt_lookup_max": args.prompt_lookup_max,
|
||||||
|
"prompt_lookup_min": args.prompt_lookup_min,
|
||||||
|
"max_model_len": max_model_len,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown method: {args.method}")
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model=model_dir,
|
||||||
|
trust_remote_code=True,
|
||||||
|
tensor_parallel_size=args.tp,
|
||||||
|
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||||
|
max_num_batched_tokens=args.max_num_batched_tokens,
|
||||||
|
enforce_eager=args.enforce_eager,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
max_num_seqs=args.max_num_seqs,
|
||||||
|
gpu_memory_utilization=0.8,
|
||||||
|
speculative_config=speculative_config,
|
||||||
|
disable_log_stats=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||||
|
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
|
||||||
|
|
||||||
|
# print the generated text
|
||||||
|
if args.print_output:
|
||||||
|
for output in outputs:
|
||||||
|
print("-" * 50)
|
||||||
|
print(f"prompt: {output.prompt}")
|
||||||
|
print(f"generated text: {output.outputs[0].text}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
metrics = llm.get_metrics()
|
||||||
|
except AssertionError:
|
||||||
|
print("Metrics are not supported in the V0 engine.")
|
||||||
|
return
|
||||||
|
|
||||||
|
num_drafts = num_accepted = 0
|
||||||
|
acceptance_counts = [0] * args.num_spec_tokens
|
||||||
|
for metric in metrics:
|
||||||
|
if metric.name == "vllm:spec_decode_num_drafts":
|
||||||
|
assert isinstance(metric, Counter)
|
||||||
|
num_drafts += metric.value
|
||||||
|
elif metric.name == "vllm:spec_decode_num_accepted_tokens":
|
||||||
|
assert isinstance(metric, Counter)
|
||||||
|
num_accepted += metric.value
|
||||||
|
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
|
||||||
|
assert isinstance(metric, Vector)
|
||||||
|
for pos in range(len(metric.values)):
|
||||||
|
acceptance_counts[pos] += metric.values[pos]
|
||||||
|
|
||||||
|
print("-" * 50)
|
||||||
|
print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
# print acceptance at each token position
|
||||||
|
for i in range(len(acceptance_counts)):
|
||||||
|
print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -31,6 +31,8 @@ def test_bench_serve(server):
|
|||||||
server.host,
|
server.host,
|
||||||
"--port",
|
"--port",
|
||||||
str(server.port),
|
str(server.port),
|
||||||
|
"--dataset-name",
|
||||||
|
"random",
|
||||||
"--random-input-len",
|
"--random-input-len",
|
||||||
"32",
|
"32",
|
||||||
"--random-output-len",
|
"--random-output-len",
|
||||||
|
|||||||
@ -50,6 +50,11 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
librosa = PlaceholderModule("librosa")
|
librosa = PlaceholderModule("librosa")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
except ImportError:
|
||||||
|
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@ -458,6 +463,253 @@ class ShareGPTDataset(BenchmarkDataset):
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def add_dataset_parser(parser: FlexibleArgumentParser):
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-prompts",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Number of prompts to process.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-name",
|
||||||
|
type=str,
|
||||||
|
default="random",
|
||||||
|
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
|
||||||
|
help="Name of the dataset to benchmark on.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to the sharegpt/sonnet dataset. "
|
||||||
|
"Or the huggingface dataset ID if using HF dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# group for dataset specific arguments
|
||||||
|
custom_group = parser.add_argument_group("custom dataset options")
|
||||||
|
custom_group.add_argument(
|
||||||
|
"--custom-output-len",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help=
|
||||||
|
"Number of output tokens per request, used only for custom dataset.",
|
||||||
|
)
|
||||||
|
custom_group.add_argument(
|
||||||
|
"--custom-skip-chat-template",
|
||||||
|
action="store_true",
|
||||||
|
help=
|
||||||
|
"Skip applying chat template to prompt, used only for custom dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
|
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
||||||
|
sonnet_group.add_argument(
|
||||||
|
"--sonnet-input-len",
|
||||||
|
type=int,
|
||||||
|
default=550,
|
||||||
|
help=
|
||||||
|
"Number of input tokens per request, used only for sonnet dataset.",
|
||||||
|
)
|
||||||
|
sonnet_group.add_argument(
|
||||||
|
"--sonnet-output-len",
|
||||||
|
type=int,
|
||||||
|
default=150,
|
||||||
|
help=
|
||||||
|
"Number of output tokens per request, used only for sonnet dataset.",
|
||||||
|
)
|
||||||
|
sonnet_group.add_argument(
|
||||||
|
"--sonnet-prefix-len",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help=
|
||||||
|
"Number of prefix tokens per request, used only for sonnet dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
|
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
|
||||||
|
sharegpt_group.add_argument(
|
||||||
|
"--sharegpt-output-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Output length for each request. Overrides the output length "
|
||||||
|
"from the ShareGPT dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
|
random_group = parser.add_argument_group("random dataset options")
|
||||||
|
random_group.add_argument(
|
||||||
|
"--random-input-len",
|
||||||
|
type=int,
|
||||||
|
default=1024,
|
||||||
|
help=
|
||||||
|
"Number of input tokens per request, used only for random sampling.",
|
||||||
|
)
|
||||||
|
random_group.add_argument(
|
||||||
|
"--random-output-len",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help=
|
||||||
|
"Number of output tokens per request, used only for random sampling.",
|
||||||
|
)
|
||||||
|
random_group.add_argument(
|
||||||
|
"--random-range-ratio",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Range ratio for sampling input/output length, "
|
||||||
|
"used only for random sampling. Must be in the range [0, 1) to define "
|
||||||
|
"a symmetric sampling range"
|
||||||
|
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||||
|
)
|
||||||
|
random_group.add_argument(
|
||||||
|
"--random-prefix-len",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help=("Number of fixed prefix tokens before the random context "
|
||||||
|
"in a request. "
|
||||||
|
"The total input length is the sum of `random-prefix-len` and "
|
||||||
|
"a random "
|
||||||
|
"context length sampled from [input_len * (1 - range_ratio), "
|
||||||
|
"input_len * (1 + range_ratio)]."),
|
||||||
|
)
|
||||||
|
|
||||||
|
hf_group = parser.add_argument_group("hf dataset options")
|
||||||
|
hf_group.add_argument("--hf-subset",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Subset of the HF dataset.")
|
||||||
|
hf_group.add_argument("--hf-split",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Split of the HF dataset.")
|
||||||
|
hf_group.add_argument(
|
||||||
|
"--hf-output-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Output length for each request. Overrides the output lengths "
|
||||||
|
"from the sampled HF dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||||
|
if args.dataset_name == "custom":
|
||||||
|
dataset = CustomDataset(dataset_path=args.dataset_path)
|
||||||
|
input_requests = dataset.sample(
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
output_len=args.custom_output_len,
|
||||||
|
skip_chat_template=args.custom_skip_chat_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif args.dataset_name == "sonnet":
|
||||||
|
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||||
|
# For the "sonnet" dataset, formatting depends on the backend.
|
||||||
|
if args.endpoint_type == "openai-chat":
|
||||||
|
input_requests = dataset.sample(
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
input_len=args.sonnet_input_len,
|
||||||
|
output_len=args.sonnet_output_len,
|
||||||
|
prefix_len=args.sonnet_prefix_len,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
return_prompt_formatted=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||||
|
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||||
|
input_requests = dataset.sample(
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
input_len=args.sonnet_input_len,
|
||||||
|
output_len=args.sonnet_output_len,
|
||||||
|
prefix_len=args.sonnet_prefix_len,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
return_prompt_formatted=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif args.dataset_name == "hf":
|
||||||
|
# all following datasets are implemented from the
|
||||||
|
# HuggingFaceDataset base class
|
||||||
|
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||||
|
dataset_class = VisionArenaDataset
|
||||||
|
args.hf_split = "train"
|
||||||
|
args.hf_subset = None
|
||||||
|
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||||
|
dataset_class = InstructCoderDataset
|
||||||
|
args.hf_split = "train"
|
||||||
|
elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
|
||||||
|
dataset_class = MTBenchDataset
|
||||||
|
args.hf_split = "train"
|
||||||
|
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||||
|
dataset_class = ConversationDataset
|
||||||
|
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||||
|
dataset_class = AIMODataset
|
||||||
|
args.hf_split = "train"
|
||||||
|
elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501
|
||||||
|
dataset_class = NextEditPredictionDataset
|
||||||
|
args.hf_split = "train"
|
||||||
|
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
|
||||||
|
dataset_class = ASRDataset
|
||||||
|
args.hf_split = "train"
|
||||||
|
else:
|
||||||
|
supported_datasets = set([
|
||||||
|
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||||
|
for dataset_name in cls.SUPPORTED_DATASET_PATHS
|
||||||
|
])
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported dataset path: {args.dataset_path}. "
|
||||||
|
"Huggingface dataset only supports dataset_path"
|
||||||
|
f" from one of following: {supported_datasets}. "
|
||||||
|
"Please consider contributing if you would "
|
||||||
|
"like to add support for additional dataset formats.")
|
||||||
|
|
||||||
|
if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [
|
||||||
|
"openai-chat",
|
||||||
|
"openai-audio",
|
||||||
|
]:
|
||||||
|
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||||
|
raise ValueError(
|
||||||
|
"Multi-modal content is only supported on 'openai-chat' and "
|
||||||
|
"'openai-audio' backend.")
|
||||||
|
input_requests = dataset_class(
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
dataset_subset=args.hf_subset,
|
||||||
|
dataset_split=args.hf_split,
|
||||||
|
random_seed=args.seed,
|
||||||
|
).sample(
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
output_len=args.hf_output_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# For datasets that follow a similar structure, use a mapping.
|
||||||
|
dataset_mapping = {
|
||||||
|
"sharegpt":
|
||||||
|
lambda: ShareGPTDataset(random_seed=args.seed,
|
||||||
|
dataset_path=args.dataset_path).sample(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
output_len=args.sharegpt_output_len,
|
||||||
|
),
|
||||||
|
"burstgpt":
|
||||||
|
lambda: BurstGPTDataset(random_seed=args.seed,
|
||||||
|
dataset_path=args.dataset_path).
|
||||||
|
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
|
||||||
|
"random":
|
||||||
|
lambda: RandomDataset(dataset_path=args.dataset_path).sample(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
prefix_len=args.random_prefix_len,
|
||||||
|
input_len=args.random_input_len,
|
||||||
|
output_len=args.random_output_len,
|
||||||
|
range_ratio=args.random_range_ratio,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
input_requests = dataset_mapping[args.dataset_name]()
|
||||||
|
except KeyError as err:
|
||||||
|
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
|
||||||
|
|
||||||
|
return input_requests
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Custom Dataset Implementation
|
# Custom Dataset Implementation
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
@ -32,12 +32,8 @@ import numpy as np
|
|||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm.benchmarks.datasets import (AIMODataset, ASRDataset, BurstGPTDataset,
|
from vllm.benchmarks.datasets import (SampleRequest, add_dataset_parser,
|
||||||
ConversationDataset, HuggingFaceDataset,
|
get_samples)
|
||||||
InstructCoderDataset, MTBenchDataset,
|
|
||||||
NextEditPredictionDataset, RandomDataset,
|
|
||||||
SampleRequest, ShareGPTDataset,
|
|
||||||
SonnetDataset, VisionArenaDataset)
|
|
||||||
from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS,
|
from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS,
|
||||||
OPENAI_COMPATIBLE_BACKENDS,
|
OPENAI_COMPATIBLE_BACKENDS,
|
||||||
RequestFuncInput,
|
RequestFuncInput,
|
||||||
@ -543,6 +539,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
|||||||
|
|
||||||
|
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
|
add_dataset_parser(parser)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--endpoint-type",
|
"--endpoint-type",
|
||||||
type=str,
|
type=str,
|
||||||
@ -571,20 +568,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
|||||||
default="/v1/completions",
|
default="/v1/completions",
|
||||||
help="API endpoint.",
|
help="API endpoint.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--dataset-name",
|
|
||||||
type=str,
|
|
||||||
default="random",
|
|
||||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
|
|
||||||
help="Name of the dataset to benchmark on.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dataset-path",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to the sharegpt/sonnet dataset. "
|
|
||||||
"Or the huggingface dataset ID if using HF dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-concurrency",
|
"--max-concurrency",
|
||||||
type=int,
|
type=int,
|
||||||
@ -611,12 +594,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
|||||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||||
)
|
)
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
parser.add_argument("--use-beam-search", action="store_true")
|
||||||
parser.add_argument(
|
|
||||||
"--num-prompts",
|
|
||||||
type=int,
|
|
||||||
default=1000,
|
|
||||||
help="Number of prompts to process.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logprobs",
|
"--logprobs",
|
||||||
type=int,
|
type=int,
|
||||||
@ -648,7 +625,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
|||||||
"bursty requests. A higher burstiness value (burstiness > 1) "
|
"bursty requests. A higher burstiness value (burstiness > 1) "
|
||||||
"results in a more uniform arrival of requests.",
|
"results in a more uniform arrival of requests.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--trust-remote-code",
|
"--trust-remote-code",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@ -739,89 +715,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
|||||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
|
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
|
||||||
)
|
)
|
||||||
|
|
||||||
# group for dataset specific arguments
|
|
||||||
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
|
||||||
sonnet_group.add_argument(
|
|
||||||
"--sonnet-input-len",
|
|
||||||
type=int,
|
|
||||||
default=550,
|
|
||||||
help=
|
|
||||||
"Number of input tokens per request, used only for sonnet dataset.",
|
|
||||||
)
|
|
||||||
sonnet_group.add_argument(
|
|
||||||
"--sonnet-output-len",
|
|
||||||
type=int,
|
|
||||||
default=150,
|
|
||||||
help=
|
|
||||||
"Number of output tokens per request, used only for sonnet dataset.",
|
|
||||||
)
|
|
||||||
sonnet_group.add_argument(
|
|
||||||
"--sonnet-prefix-len",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help=
|
|
||||||
"Number of prefix tokens per request, used only for sonnet dataset.",
|
|
||||||
)
|
|
||||||
|
|
||||||
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
|
|
||||||
sharegpt_group.add_argument(
|
|
||||||
"--sharegpt-output-len",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Output length for each request. Overrides the output length "
|
|
||||||
"from the ShareGPT dataset.",
|
|
||||||
)
|
|
||||||
|
|
||||||
random_group = parser.add_argument_group("random dataset options")
|
|
||||||
random_group.add_argument(
|
|
||||||
"--random-input-len",
|
|
||||||
type=int,
|
|
||||||
default=1024,
|
|
||||||
help=
|
|
||||||
"Number of input tokens per request, used only for random sampling.",
|
|
||||||
)
|
|
||||||
random_group.add_argument(
|
|
||||||
"--random-output-len",
|
|
||||||
type=int,
|
|
||||||
default=128,
|
|
||||||
help=
|
|
||||||
"Number of output tokens per request, used only for random sampling.",
|
|
||||||
)
|
|
||||||
random_group.add_argument(
|
|
||||||
"--random-range-ratio",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
help="Range ratio for sampling input/output length, "
|
|
||||||
"used only for random sampling. Must be in the range [0, 1) to define "
|
|
||||||
"a symmetric sampling range"
|
|
||||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
|
||||||
)
|
|
||||||
random_group.add_argument(
|
|
||||||
"--random-prefix-len",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Number of fixed prefix tokens before random "
|
|
||||||
" context. The length range of context in a random "
|
|
||||||
" request is [random-prefix-len, "
|
|
||||||
" random-prefix-len + random-prefix-len * random-range-ratio).")
|
|
||||||
|
|
||||||
hf_group = parser.add_argument_group("hf dataset options")
|
|
||||||
hf_group.add_argument("--hf-subset",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Subset of the HF dataset.")
|
|
||||||
hf_group.add_argument("--hf-split",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Split of the HF dataset.")
|
|
||||||
hf_group.add_argument(
|
|
||||||
"--hf-output-len",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Output length for each request. Overrides the output lengths "
|
|
||||||
"from the sampled HF dataset.",
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_group = parser.add_argument_group("sampling parameters")
|
sampling_group = parser.add_argument_group("sampling parameters")
|
||||||
sampling_group.add_argument(
|
sampling_group.add_argument(
|
||||||
"--top-p",
|
"--top-p",
|
||||||
@ -884,7 +777,6 @@ def main(args: argparse.Namespace):
|
|||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
endpoint_type = args.endpoint_type
|
|
||||||
label = args.label
|
label = args.label
|
||||||
model_id = args.model
|
model_id = args.model
|
||||||
model_name = args.served_model_name
|
model_name = args.served_model_name
|
||||||
@ -907,115 +799,8 @@ def main(args: argparse.Namespace):
|
|||||||
"Please specify '--dataset-name' and the corresponding "
|
"Please specify '--dataset-name' and the corresponding "
|
||||||
"'--dataset-path' if required.")
|
"'--dataset-path' if required.")
|
||||||
|
|
||||||
if args.dataset_name == "sonnet":
|
# Load the dataset.
|
||||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
input_requests = get_samples(args, tokenizer)
|
||||||
# For the "sonnet" dataset, formatting depends on the backend.
|
|
||||||
if args.backend == "openai-chat":
|
|
||||||
input_requests = dataset.sample(
|
|
||||||
num_requests=args.num_prompts,
|
|
||||||
input_len=args.sonnet_input_len,
|
|
||||||
output_len=args.sonnet_output_len,
|
|
||||||
prefix_len=args.sonnet_prefix_len,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
return_prompt_formatted=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
|
||||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
|
||||||
input_requests = dataset.sample(
|
|
||||||
num_requests=args.num_prompts,
|
|
||||||
input_len=args.sonnet_input_len,
|
|
||||||
output_len=args.sonnet_output_len,
|
|
||||||
prefix_len=args.sonnet_prefix_len,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
return_prompt_formatted=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif args.dataset_name == "hf":
|
|
||||||
# all following datasets are implemented from the
|
|
||||||
# HuggingFaceDataset base class
|
|
||||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
|
||||||
dataset_class = VisionArenaDataset
|
|
||||||
args.hf_split = "train"
|
|
||||||
args.hf_subset = None
|
|
||||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
|
||||||
dataset_class = InstructCoderDataset
|
|
||||||
args.hf_split = "train"
|
|
||||||
elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
|
|
||||||
dataset_class = MTBenchDataset
|
|
||||||
args.hf_split = "train"
|
|
||||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
|
||||||
dataset_class = ConversationDataset
|
|
||||||
args.hf_split = "train"
|
|
||||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
|
||||||
dataset_class = AIMODataset
|
|
||||||
args.hf_split = "train"
|
|
||||||
elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501
|
|
||||||
dataset_class = NextEditPredictionDataset
|
|
||||||
args.hf_split = "train"
|
|
||||||
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
|
|
||||||
dataset_class = ASRDataset
|
|
||||||
args.hf_split = "train"
|
|
||||||
else:
|
|
||||||
supported_datasets = set([
|
|
||||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
|
||||||
for dataset_name in cls.SUPPORTED_DATASET_PATHS
|
|
||||||
])
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported dataset path: {args.dataset_path}. "
|
|
||||||
"Huggingface dataset only supports dataset_path"
|
|
||||||
f" from one of following: {supported_datasets}. "
|
|
||||||
"Please consider contributing if you would "
|
|
||||||
"like to add support for additional dataset formats.")
|
|
||||||
|
|
||||||
if dataset_class.IS_MULTIMODAL and endpoint_type not in [
|
|
||||||
"openai-chat",
|
|
||||||
"openai-audio",
|
|
||||||
]:
|
|
||||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
|
||||||
raise ValueError(
|
|
||||||
"Multi-modal content is only supported on 'openai-chat' and "
|
|
||||||
"'openai-audio' backend.")
|
|
||||||
input_requests = dataset_class(
|
|
||||||
dataset_path=args.dataset_path,
|
|
||||||
dataset_subset=args.hf_subset,
|
|
||||||
dataset_split=args.hf_split,
|
|
||||||
random_seed=args.seed,
|
|
||||||
).sample(
|
|
||||||
num_requests=args.num_prompts,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
output_len=args.hf_output_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# For datasets that follow a similar structure, use a mapping.
|
|
||||||
dataset_mapping = {
|
|
||||||
"sharegpt":
|
|
||||||
lambda: ShareGPTDataset(random_seed=args.seed,
|
|
||||||
dataset_path=args.dataset_path).sample(
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_requests=args.num_prompts,
|
|
||||||
output_len=args.sharegpt_output_len,
|
|
||||||
),
|
|
||||||
"burstgpt":
|
|
||||||
lambda: BurstGPTDataset(random_seed=args.seed,
|
|
||||||
dataset_path=args.dataset_path).
|
|
||||||
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
|
|
||||||
"random":
|
|
||||||
lambda: RandomDataset(dataset_path=args.dataset_path).sample(
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
num_requests=args.num_prompts,
|
|
||||||
prefix_len=args.random_prefix_len,
|
|
||||||
input_len=args.random_input_len,
|
|
||||||
output_len=args.random_output_len,
|
|
||||||
range_ratio=args.random_range_ratio,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
input_requests = dataset_mapping[args.dataset_name]()
|
|
||||||
except KeyError as err:
|
|
||||||
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
|
|
||||||
goodput_config_dict = check_goodput_args(args)
|
goodput_config_dict = check_goodput_args(args)
|
||||||
|
|
||||||
# Collect the sampling parameters.
|
# Collect the sampling parameters.
|
||||||
@ -1043,7 +828,7 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
benchmark_result = asyncio.run(
|
benchmark_result = asyncio.run(
|
||||||
benchmark(
|
benchmark(
|
||||||
endpoint_type=endpoint_type,
|
endpoint_type=args.endpoint_type,
|
||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
@ -1073,7 +858,7 @@ def main(args: argparse.Namespace):
|
|||||||
# Setup
|
# Setup
|
||||||
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
result_json["date"] = current_dt
|
result_json["date"] = current_dt
|
||||||
result_json["endpoint_type"] = endpoint_type
|
result_json["endpoint_type"] = args.endpoint_type
|
||||||
result_json["label"] = label
|
result_json["label"] = label
|
||||||
result_json["model_id"] = model_id
|
result_json["model_id"] = model_id
|
||||||
result_json["tokenizer_id"] = tokenizer_id
|
result_json["tokenizer_id"] = tokenizer_id
|
||||||
@ -1118,7 +903,7 @@ def main(args: argparse.Namespace):
|
|||||||
base_model_id = model_id.split("/")[-1]
|
base_model_id = model_id.split("/")[-1]
|
||||||
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
||||||
if args.max_concurrency is not None else "")
|
if args.max_concurrency is not None else "")
|
||||||
label = label or endpoint_type
|
label = label or args.endpoint_type
|
||||||
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
|
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
|
||||||
if args.result_filename:
|
if args.result_filename:
|
||||||
file_name = args.result_filename
|
file_name = args.result_filename
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user