From 90304003533b68e08d425688c371ba76f017aeba Mon Sep 17 00:00:00 2001
From: LiuXiaoxuanPKU
Date: Mon, 31 Mar 2025 15:25:08 -0700
Subject: [PATCH] add datasets to benchmark_latency

---
Review notes (this commentary area is ignored by `git am`):
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Line counts match every hunk header and the diffstat,
  but exact leading whitespace on context/removed lines is a best effort;
  apply with `git apply --ignore-whitespace` or regenerate from the commit
  if a hunk is rejected.
* Text inside angle brackets appears to have been stripped from the copy
  (author address; the removed `--dataset` help string) -- TODO confirm
  against the original commit.
* Two added ('+') spots differ from the original commit, without changing
  any line count: the `--prefix-len` help text gains the missing space
  between its two sentences, and the three `warnings.warn` messages in
  `validate_dataset` use adjacent string literals instead of a backslash
  continuation, which embedded source indentation in the message. The
  `index` blob ids are those of the original commit.
* NOTE(review): `validate_dataset` asserts on `args.backend` and
  `get_requests` reads `args.hf_subset` / `args.hf_split`; none of these
  options is added to benchmark_latency.py in the hunks below -- confirm
  `--dataset-name hf` is usable from the latency script.

 benchmarks/benchmark_latency.py    |  67 ++++++++++++++---
 benchmarks/benchmark_throughput.py | 113 ++---------------------------
 benchmarks/benchmark_utils.py      | 110 ++++++++++++++++++++++++++++
 3 files changed, 171 insertions(+), 119 deletions(-)

diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index dfd9bb1e6a4d0..5e85a92304d31 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -5,18 +5,21 @@ import argparse
 import dataclasses
 import json
 import os
+import random
 import time
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import numpy as np
 import torch
-from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
+from benchmark_utils import (convert_to_pytorch_benchmark_format, get_requests,
+                             validate_dataset, write_to_json)
 from tqdm import tqdm
+from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
-from vllm.inputs import PromptType
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser
 
@@ -55,21 +58,27 @@ def main(args: argparse.Namespace):
         detokenize=not args.disable_detokenize,
     )
     print(sampling_params)
-    dummy_prompt_token_ids = np.random.randint(10000,
-                                               size=(args.batch_size,
-                                                     args.input_len))
-    dummy_prompts: list[PromptType] = [{
-        "prompt_token_ids": batch
-    } for batch in dummy_prompt_token_ids.tolist()]
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.tokenizer, trust_remote_code=args.trust_remote_code)
+    requests = get_requests(args.batch_size, args, tokenizer)
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
+    for request in requests:
+        prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
 
     def llm_generate():
         if not args.use_beam_search:
-            llm.generate(dummy_prompts,
+            llm.generate(prompts,
                          sampling_params=sampling_params,
                          use_tqdm=False)
         else:
             llm.beam_search(
-                dummy_prompts,
+                prompts,
                 BeamSearchParams(
                     beam_width=args.n,
                     max_tokens=args.output_len,
@@ -180,7 +189,43 @@ if __name__ == "__main__":
         help=("Do not detokenize responses (i.e. do not include "
               "detokenization time in the latency measurement)"),
     )
+    parser.add_argument(
+        "--dataset-name",
+        type=str,
+        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
+        help="Name of the dataset to benchmark on.",
+        default="sharegpt")
+    # random dataset
+    parser.add_argument(
+        "--random-range-ratio",
+        type=float,
+        default=None,
+        help="Range of sampled ratio of input/output length, "
+        "used only for RandomDataSet.",
+    )
+    parser.add_argument("--dataset-path",
+                        type=str,
+                        default=None,
+                        help="Path to the dataset")
+
+    # LoRA
+    parser.add_argument(
+        "--lora-path",
+        type=str,
+        default=None,
+        help="Path to the lora adapters to use. This can be an absolute path, "
+        "a relative path, or a Hugging Face model identifier.")
+
+    parser.add_argument("--prefix-len",
+                        type=int,
+                        default=None,
+                        help="Number of prefix tokens per request. "
+                        "This is for the RandomDataset and SonnetDataset")
 
     parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()
+    if args.tokenizer is None:
+        args.tokenizer = args.model
+    validate_dataset(args)
+    random.seed(0)
     main(args)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 1ff63f0a44795..3e3e4272f2ba4 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -11,11 +11,9 @@ from typing import Any, Optional, Union
 
 import torch
 import uvloop
-from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
-                               InstructCoderDataset, RandomDataset,
-                               SampleRequest, ShareGPTDataset, SonnetDataset,
-                               VisionArenaDataset)
-from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
+from benchmark_dataset import SampleRequest
+from benchmark_utils import (convert_to_pytorch_benchmark_format, get_requests,
+                             validate_dataset, write_to_json)
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
@@ -287,59 +285,6 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
         write_to_json(pt_file, pt_records)
 
 
-def get_requests(args, tokenizer):
-    # Common parameters for all dataset types.
-    common_kwargs = {
-        "dataset_path": args.dataset_path,
-        "random_seed": args.seed,
-    }
-    sample_kwargs = {
-        "tokenizer": tokenizer,
-        "lora_path": args.lora_path,
-        "max_loras": args.max_loras,
-        "num_requests": args.num_prompts,
-        "input_len": args.input_len,
-        "output_len": args.output_len,
-    }
-
-    if args.dataset_path is None or args.dataset_name == "random":
-        sample_kwargs["range_ratio"] = args.random_range_ratio
-        sample_kwargs["prefix_len"] = args.prefix_len
-        dataset_cls = RandomDataset
-    elif args.dataset_name == "sharegpt":
-        dataset_cls = ShareGPTDataset
-        if args.backend == "vllm-chat":
-            sample_kwargs["enable_multimodal_chat"] = True
-    elif args.dataset_name == "sonnet":
-        assert tokenizer.chat_template or tokenizer.default_chat_template, (
-            "Tokenizer/model must have chat template for sonnet dataset.")
-        dataset_cls = SonnetDataset
-        sample_kwargs["prefix_len"] = args.prefix_len
-        sample_kwargs["return_prompt_formatted"] = True
-    elif args.dataset_name == "burstgpt":
-        dataset_cls = BurstGPTDataset
-    elif args.dataset_name == "hf":
-        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
-            dataset_cls = VisionArenaDataset
-            common_kwargs['dataset_subset'] = None
-            common_kwargs['dataset_split'] = "train"
-            sample_kwargs["enable_multimodal_chat"] = True
-        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
-            dataset_cls = InstructCoderDataset
-            common_kwargs['dataset_split'] = "train"
-        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
-            dataset_cls = ConversationDataset
-            common_kwargs['dataset_subset'] = args.hf_subset
-            common_kwargs['dataset_split'] = args.hf_split
-            sample_kwargs["enable_multimodal_chat"] = True
-
-    else:
-        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
-    # Remove None values
-    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
-    return dataset_cls(**common_kwargs).sample(**sample_kwargs)
-
-
 def main(args: argparse.Namespace):
     if args.seed is None:
         args.seed = 0
@@ -348,7 +293,7 @@ def main(args: argparse.Namespace):
     # Sample the requests.
     tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer, trust_remote_code=args.trust_remote_code)
-    requests = get_requests(args, tokenizer)
+    requests = get_requests(args.num_prompts, args, tokenizer)
     is_multi_modal = any(request.multi_modal_data is not None
                          for request in requests)
     request_outputs: Optional[list[RequestOutput]] = None
@@ -449,47 +394,7 @@ def validate_args(args):
     if args.backend not in valid_backends:
         raise ValueError(f"Unsupported backend: {args.backend}")
 
-    # === Dataset Configuration ===
-    if not args.dataset and not args.dataset_path:
-        print(
-            "When dataset path is not set, it will default to random dataset")
-        args.dataset_name = 'random'
-        if args.input_len is None:
-            raise ValueError("input_len must be provided for a random dataset")
-
-    # === Dataset Name Specific Checks ===
-    # --hf-subset and --hf-split: only used
-    # when dataset_name is 'hf'
-    if args.dataset_name != "hf" and (
-            getattr(args, "hf_subset", None) is not None
-            or getattr(args, "hf_split", None) is not None):
-        warnings.warn("--hf-subset and --hf-split will be ignored \
-                since --dataset-name is not 'hf'.",
-                      stacklevel=2)
-    elif args.dataset_name == "hf":
-        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
-            assert args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend."  #noqa: E501
-        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
-            assert args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend."  #noqa: E501
-        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
-            assert args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend."  #noqa: E501
-        else:
-            raise ValueError(
-                f"{args.dataset_path} is not supported by hf dataset.")
-
-    # --random-range-ratio: only used when dataset_name is 'random'
-    if args.dataset_name != 'random' and args.random_range_ratio is not None:
-        warnings.warn("--random-range-ratio will be ignored since \
-                --dataset-name is not 'random'.",
-                      stacklevel=2)
-
-    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
-    # set.
-    if args.dataset_name not in {"random", "sonnet", None
-                                 } and args.prefix_len is not None:
-        warnings.warn("--prefix-len will be ignored since --dataset-name\
-                 is not 'random', 'sonnet', or not set.",
-                      stacklevel=2)
+    validate_dataset(args)
 
     # === LoRA Settings ===
     if getattr(args, "enable_lora", False) and args.backend != "vllm":
@@ -529,14 +434,6 @@ if __name__ == "__main__":
         choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
         help="Name of the dataset to benchmark on.",
         default="sharegpt")
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        default=None,
-        help="Path to the ShareGPT dataset, will be deprecated in\
-            the next release. The dataset is expected to "
-        "be a json in form of list[dict[..., conversations: "
-        "list[dict[..., value: ]]]]")
     parser.add_argument("--dataset-path",
                         type=str,
                         default=None,
diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py
index 45a0ddbd5d08d..9400a7e736e2d 100644
--- a/benchmarks/benchmark_utils.py
+++ b/benchmarks/benchmark_utils.py
@@ -4,8 +4,14 @@ import argparse
 import json
 import math
 import os
+import warnings
 from typing import Any
 
+from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
+                               InstructCoderDataset, RandomDataset,
+                               SampleRequest, ShareGPTDataset, SonnetDataset,
+                               VisionArenaDataset)
+
 
 def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
                                         metrics: dict[str, list],
@@ -67,3 +73,107 @@ class InfEncoder(json.JSONEncoder):
 def write_to_json(filename: str, records: list) -> None:
     with open(filename, "w") as f:
         json.dump(records, f, cls=InfEncoder)
+
+
+def get_requests(num_requests: int, args: argparse.Namespace,
+                 tokenizer: Any) -> list[SampleRequest]:
+    """
+    Sample the requests for the benchmark.
+    """
+    # Common parameters for all dataset types.
+    common_kwargs = {
+        "dataset_path": args.dataset_path,
+        "random_seed": args.seed,
+    }
+    sample_kwargs = {
+        "tokenizer": tokenizer,
+        "lora_path": args.lora_path,
+        "max_loras": args.max_loras,
+        "num_requests": num_requests,
+        "input_len": args.input_len,
+        "output_len": args.output_len,
+    }
+
+    if args.dataset_path is None or args.dataset_name == "random":
+        sample_kwargs["range_ratio"] = args.random_range_ratio
+        sample_kwargs["prefix_len"] = args.prefix_len
+        dataset_cls = RandomDataset
+    elif args.dataset_name == "sharegpt":
+        dataset_cls = ShareGPTDataset
+        if getattr(args, "backend", False) and args.backend == "vllm-chat":
+            sample_kwargs["enable_multimodal_chat"] = True
+    elif args.dataset_name == "sonnet":
+        assert tokenizer.chat_template or tokenizer.default_chat_template, (
+            "Tokenizer/model must have chat template for sonnet dataset.")
+        dataset_cls = SonnetDataset
+        sample_kwargs["prefix_len"] = args.prefix_len
+        sample_kwargs["return_prompt_formatted"] = True
+    elif args.dataset_name == "burstgpt":
+        dataset_cls = BurstGPTDataset
+    elif args.dataset_name == "hf":
+        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
+            dataset_cls = VisionArenaDataset
+            common_kwargs['dataset_subset'] = None
+            common_kwargs['dataset_split'] = "train"
+            sample_kwargs["enable_multimodal_chat"] = True
+        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
+            dataset_cls = InstructCoderDataset
+            common_kwargs['dataset_split'] = "train"
+        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
+            dataset_cls = ConversationDataset
+            common_kwargs['dataset_subset'] = args.hf_subset
+            common_kwargs['dataset_split'] = args.hf_split
+            sample_kwargs["enable_multimodal_chat"] = True
+
+    else:
+        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
+    # Remove None values
+    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
+    return dataset_cls(**common_kwargs).sample(**sample_kwargs)
+
+
+def validate_dataset(args: argparse.Namespace, ):
+    """
+    Validate the dataset arguments.
+    """
+    # === Dataset Configuration ===
+    if not args.dataset_path:
+        print(
+            "When dataset path is not set, it will default to random dataset")
+        args.dataset_name = 'random'
+        if args.input_len is None:
+            raise ValueError("input_len must be provided for a random dataset")
+
+    # === Dataset Name Specific Checks ===
+    # --hf-subset and --hf-split: only used
+    # when dataset_name is 'hf'
+    if args.dataset_name != "hf" and (
+            getattr(args, "hf_subset", None) is not None
+            or getattr(args, "hf_split", None) is not None):
+        warnings.warn("--hf-subset and --hf-split will be ignored "
+                      "since --dataset-name is not 'hf'.",
+                      stacklevel=2)
+    elif args.dataset_name == "hf":
+        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
+            assert args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend."  #noqa: E501
+        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
+            assert args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend."  #noqa: E501
+        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
+            assert args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend."  #noqa: E501
+        else:
+            raise ValueError(
+                f"{args.dataset_path} is not supported by hf dataset.")
+
+    # --random-range-ratio: only used when dataset_name is 'random'
+    if args.dataset_name != 'random' and args.random_range_ratio is not None:
+        warnings.warn("--random-range-ratio will be ignored since "
+                      "--dataset-name is not 'random'.",
+                      stacklevel=2)
+
+    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
+    # set.
+    if args.dataset_name not in {"random", "sonnet", None
+                                 } and args.prefix_len is not None:
+        warnings.warn("--prefix-len will be ignored since --dataset-name "
+                      "is not 'random', 'sonnet', or not set.",
+                      stacklevel=2)