diff --git a/benchmarks/README.md b/benchmarks/README.md index cbf2f281bdde7..6f9fbb91cbd91 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -64,6 +64,12 @@ become available. ✅ lmms-lab/LLaVA-OneVision-Data, Aeala/ShareGPT_Vicuna_unfiltered + + Custom + ✅ + ✅ + Local file: data.jsonl + @@ -124,6 +130,38 @@ P99 ITL (ms): 8.39 ================================================== ``` +### Custom Dataset +If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl + +``` +{"prompt": "What is the capital of India?"} +{"prompt": "What is the capital of Iran?"} +{"prompt": "What is the capital of China?"} +``` + +```bash +# start server +VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests +``` + +```bash +# run benchmarking script +python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detailed \ + --backend vllm \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --endpoint /v1/completions \ + --dataset-name custom \ + --dataset-path \ + --custom-skip-chat-template \ + --num-prompts 80 \ + --max-concurrency 1 \ + --temperature=0.3 \ + --top-p=0.75 \ + --result-dir "./log/" +``` + +You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. + ### VisionArena Benchmark for Vision Language Models ```bash @@ -203,6 +241,16 @@ python3 vllm/benchmarks/benchmark_serving.py \ --seed 42 ``` +**`philschmid/mt-bench`** + +``` bash +python3 vllm/benchmarks/benchmark_serving.py \ + --model Qwen/QwQ-32B \ + --dataset-name hf \ + --dataset-path philschmid/mt-bench \ + --num-prompts 80 +``` + ### Running With Sampling Parameters When using OpenAI-compatible backends such as `vllm`, optional sampling diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 5513a5f78f1ce..d86bf045ea47e 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -9,9 +9,6 @@ generation. Supported dataset types include: - BurstGPT - HuggingFace - VisionArena - -TODO: Implement CustomDataset to parse a JSON file and convert its contents into -SampleRequest instances, similar to the approach used in ShareGPT. """ import base64 @@ -442,6 +439,97 @@ class ShareGPTDataset(BenchmarkDataset): return samples +# ----------------------------------------------------------------------------- +# Custom Dataset Implementation +# ----------------------------------------------------------------------------- + + +class CustomDataset(BenchmarkDataset): + """ + Implements the Custom dataset. Loads data from a JSONL file and generates + sample requests based on conversation turns. E.g., + ``` + {"prompt": "What is the capital of India?"} + {"prompt": "What is the capital of Iran?"} + {"prompt": "What is the capital of China?"} + ``` + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + # self.data will be a list of dictionaries + # e.g., [{"prompt": "What is the capital of India?"}, ...] + # This will be the standardized format which load_data() + # has to convert into depending on the filetype of dataset_path. + # sample() will assume this standardized format of self.data + self.data = [] + + # Load the JSONL file + if self.dataset_path.endswith(".jsonl"): + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True) + + # check if the JSONL file has a 'prompt' column + if "prompt" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'prompt' column.") + + # Convert each row to a dictionary and append to self.data + # This will convert the DataFrame to a list of dictionaries + # where each dictionary corresponds to a row in the DataFrame. + # This is the standardized format we want for self.data + for _, row in jsonl_data.iterrows(): + self.data.append(row.to_dict()) + else: + raise NotImplementedError( + "Only JSONL format is supported for CustomDataset." + ) + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + **kwargs, + ) -> list: + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["prompt"] + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + ) + ) + self.maybe_oversample_requests(sampled_requests, num_requests) + + return sampled_requests + + # ----------------------------------------------------------------------------- # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 79024a9d61c51..6bd9f1b49c2ec 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -60,6 +60,7 @@ from benchmark_dataset import ( ASRDataset, BurstGPTDataset, ConversationDataset, + CustomDataset, HuggingFaceDataset, InstructCoderDataset, MTBenchDataset, @@ -627,7 +628,16 @@ def main(args: argparse.Namespace): "'--dataset-path' if required." ) - if args.dataset_name == "sonnet": + if args.dataset_name == "custom": + dataset = CustomDataset(dataset_path=args.dataset_path) + input_requests = dataset.sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.custom_output_len, + skip_chat_template=args.custom_skip_chat_template, + ) + + elif args.dataset_name == "sonnet": dataset = SonnetDataset(dataset_path=args.dataset_path) # For the "sonnet" dataset, formatting depends on the backend. if args.backend == "openai-chat": @@ -838,6 +848,8 @@ def main(args: argparse.Namespace): ]: if field in result_json: del result_json[field] + if field in benchmark_result: + del benchmark_result[field] # Save to file base_model_id = model_id.split("/")[-1] @@ -850,6 +862,7 @@ def main(args: argparse.Namespace): if args.result_filename: file_name = args.result_filename if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) file_name = os.path.join(args.result_dir, file_name) with open( file_name, mode="a+" if args.append_result else "w", encoding="utf-8" @@ -890,7 +903,7 @@ if __name__ == "__main__": "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], help="Name of the dataset to benchmark on.", ) parser.add_argument( @@ -1060,6 +1073,19 @@ if __name__ == "__main__": ) # group for dataset specific arguments + custom_group = parser.add_argument_group("custom dataset options") + custom_group.add_argument( + "--custom-output-len", + type=int, + default=256, + help="Number of output tokens per request, used only for custom dataset.", + ) + custom_group.add_argument( + "--custom-skip-chat-template", + action="store_true", + help="Skip applying chat template to prompt, used only for custom dataset.", + ) + sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group.add_argument( "--sonnet-input-len", diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 712e83528f122..35cc303f60eeb 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -9,9 +9,6 @@ generation. Supported dataset types include: - BurstGPT - HuggingFace - VisionArena - -TODO: Implement CustomDataset to parse a JSON file and convert its contents into -SampleRequest instances, similar to the approach used in ShareGPT. """ import base64 import io @@ -26,6 +23,7 @@ from io import BytesIO from typing import Any, Callable, Optional, Union import numpy as np +import pandas as pd from PIL import Image from transformers import PreTrainedTokenizerBase @@ -443,6 +441,99 @@ class ShareGPTDataset(BenchmarkDataset): return samples +# ----------------------------------------------------------------------------- +# Custom Dataset Implementation +# ----------------------------------------------------------------------------- + + +class CustomDataset(BenchmarkDataset): + """ + Implements the Custom dataset. Loads data from a JSONL file and generates + sample requests based on conversation turns. E.g., + ``` + {"prompt": "What is the capital of India?"} + {"prompt": "What is the capital of Iran?"} + {"prompt": "What is the capital of China?"} + ``` + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + # self.data will be a list of dictionaries + # e.g., [{"prompt": "What is the capital of India?"}, ...] + # This will be the standardized format which load_data() + # has to convert into depending on the filetype of dataset_path. + # sample() will assume this standardized format of self.data + self.data = [] + + # Load the JSONL file + if self.dataset_path.endswith(".jsonl"): + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, + lines=True) + + # check if the JSONL file has a 'prompt' column + if "prompt" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'prompt' column.") + + # Convert each row to a dictionary and append to self.data + # This will convert the DataFrame to a list of dictionaries + # where each dictionary corresponds to a row in the DataFrame. + # This is the standardized format we want for self.data + for _, row in jsonl_data.iterrows(): + self.data.append(row.to_dict()) + else: + raise NotImplementedError( + "Only JSONL format is supported for CustomDataset.") + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + **kwargs, + ) -> list: + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["prompt"] + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + + return sampled_requests + + # ----------------------------------------------------------------------------- # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 040815e879f0c..858a0c6a00e4b 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -1110,6 +1110,8 @@ def main(args: argparse.Namespace): ]: if field in result_json: del result_json[field] + if field in benchmark_result: + del benchmark_result[field] # Save to file base_model_id = model_id.split("/")[-1] @@ -1120,6 +1122,7 @@ def main(args: argparse.Namespace): if args.result_filename: file_name = args.result_filename if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) file_name = os.path.join(args.result_dir, file_name) with open(file_name, mode="a+" if args.append_result else "w",