diff --git a/benchmarks/README.md b/benchmarks/README.md
index cbf2f281bdde7..6f9fbb91cbd91 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -64,6 +64,12 @@ become available.
 ✅ | lmms-lab/LLaVA-OneVision-Data, Aeala/ShareGPT_Vicuna_unfiltered |
+| Custom | ✅ | ✅ | Local file: data.jsonl |
@@ -124,6 +130,38 @@ P99 ITL (ms): 8.39
==================================================
```
+### Custom Dataset
+If the dataset you want to benchmark is not supported in vLLM yet, you can still benchmark it using `CustomDataset`. Your data needs to be in `.jsonl` format, with a "prompt" field per entry, e.g. `data.jsonl`:
+
+```json
+{"prompt": "What is the capital of India?"}
+{"prompt": "What is the capital of Iran?"}
+{"prompt": "What is the capital of China?"}
+```
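+
+If your prompts start out in a Python list, here is a minimal sketch to produce this format (the `data.jsonl` name matches the example above; the prompts are illustrative):
+
+```python
+import json
+
+prompts = [
+    "What is the capital of India?",
+    "What is the capital of Iran?",
+    "What is the capital of China?",
+]
+
+with open("data.jsonl", "w", encoding="utf-8") as f:
+    for prompt in prompts:
+        # One JSON object per line, each with the required "prompt" field.
+        f.write(json.dumps({"prompt": prompt}) + "\n")
+```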
+
+```bash
+# start server
+VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --port 9001 --disable-log-requests
+```
+
+```bash
+# run benchmarking script
+python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detailed \
+    --backend vllm \
+    --model meta-llama/Llama-3.1-8B-Instruct \
+    --endpoint /v1/completions \
+    --dataset-name custom \
+    --dataset-path data.jsonl \
+    --custom-skip-chat-template \
+    --num-prompts 80 \
+    --max-concurrency 1 \
+    --temperature=0.3 \
+    --top-p=0.75 \
+    --result-dir "./log/"
+```
+
+You can skip applying the chat template if your data already includes it, by passing `--custom-skip-chat-template`.
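+
+If you want to pre-apply the template yourself, here is a minimal sketch of roughly what `CustomDataset` otherwise does per prompt (the model name is taken from the example above; assumes `transformers` is installed and the model ships a chat template):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+
+# Wrap each raw prompt in a single user turn and render it as plain text.
+prompt = tokenizer.apply_chat_template(
+    [{"role": "user", "content": "What is the capital of India?"}],
+    add_generation_prompt=True,
+    tokenize=False,
+)
+
+# Store prompts like this in data.jsonl, then pass --custom-skip-chat-template.
+```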
+
### VisionArena Benchmark for Vision Language Models
```bash
@@ -203,6 +241,16 @@ python3 vllm/benchmarks/benchmark_serving.py \
     --seed 42
```
+**`philschmid/mt-bench`**
+
+```bash
+python3 vllm/benchmarks/benchmark_serving.py \
+    --model Qwen/QwQ-32B \
+    --dataset-name hf \
+    --dataset-path philschmid/mt-bench \
+    --num-prompts 80
+```
+
### Running With Sampling Parameters
When using OpenAI-compatible backends such as `vllm`, optional sampling
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index 5513a5f78f1ce..d86bf045ea47e 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -9,9 +9,6 @@ generation. Supported dataset types include:
- BurstGPT
- HuggingFace
- VisionArena
-
-TODO: Implement CustomDataset to parse a JSON file and convert its contents into
-SampleRequest instances, similar to the approach used in ShareGPT.
"""
import base64
@@ -442,6 +439,97 @@ class ShareGPTDataset(BenchmarkDataset):
         return samples
+# -----------------------------------------------------------------------------
+# Custom Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class CustomDataset(BenchmarkDataset):
+    """
+    Implements the Custom dataset. Loads data from a JSONL file and generates
+    sample requests from its "prompt" entries, e.g.,
+    ```
+    {"prompt": "What is the capital of India?"}
+    {"prompt": "What is the capital of Iran?"}
+    {"prompt": "What is the capital of China?"}
+    ```
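+
+    A minimal usage sketch (`tokenizer` here stands for any loaded
+    `PreTrainedTokenizerBase`):
+    ```
+    dataset = CustomDataset(dataset_path="data.jsonl")
+    requests = dataset.sample(tokenizer=tokenizer, num_requests=80)
+    ```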
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.load_data()
+
+ def load_data(self) -> None:
+ if self.dataset_path is None:
+ raise ValueError("dataset_path must be provided for loading data.")
+
+ # self.data will be a list of dictionaries
+ # e.g., [{"prompt": "What is the capital of India?"}, ...]
+ # This will be the standardized format which load_data()
+ # has to convert into depending on the filetype of dataset_path.
+ # sample() will assume this standardized format of self.data
+ self.data = []
+
+ # Load the JSONL file
+ if self.dataset_path.endswith(".jsonl"):
+ jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True)
+
+ # check if the JSONL file has a 'prompt' column
+ if "prompt" not in jsonl_data.columns:
+ raise ValueError("JSONL file must contain a 'prompt' column.")
+
+ # Convert each row to a dictionary and append to self.data
+ # This will convert the DataFrame to a list of dictionaries
+ # where each dictionary corresponds to a row in the DataFrame.
+ # This is the standardized format we want for self.data
+ for _, row in jsonl_data.iterrows():
+ self.data.append(row.to_dict())
+ else:
+ raise NotImplementedError(
+ "Only JSONL format is supported for CustomDataset."
+ )
+
+ random.seed(self.random_seed)
+ random.shuffle(self.data)
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ lora_path: Optional[str] = None,
+ max_loras: Optional[int] = None,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ skip_chat_template: bool = False,
+ **kwargs,
+ ) -> list:
+ sampled_requests = []
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+ prompt = item["prompt"]
+
+ # apply template
+ if not skip_chat_template:
+ prompt = tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}],
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+
+ prompt_len = len(tokenizer(prompt).input_ids)
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ )
+ )
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+
+ return sampled_requests
+
+
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 79024a9d61c51..6bd9f1b49c2ec 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -60,6 +60,7 @@ from benchmark_dataset import (
     ASRDataset,
     BurstGPTDataset,
     ConversationDataset,
+    CustomDataset,
     HuggingFaceDataset,
     InstructCoderDataset,
     MTBenchDataset,
@@ -627,7 +628,16 @@ def main(args: argparse.Namespace):
"'--dataset-path' if required."
)
- if args.dataset_name == "sonnet":
+ if args.dataset_name == "custom":
+ dataset = CustomDataset(dataset_path=args.dataset_path)
+ input_requests = dataset.sample(
+ num_requests=args.num_prompts,
+ tokenizer=tokenizer,
+ output_len=args.custom_output_len,
+ skip_chat_template=args.custom_skip_chat_template,
+ )
+
+ elif args.dataset_name == "sonnet":
dataset = SonnetDataset(dataset_path=args.dataset_path)
# For the "sonnet" dataset, formatting depends on the backend.
if args.backend == "openai-chat":
@@ -838,6 +848,8 @@ def main(args: argparse.Namespace):
             ]:
                 if field in result_json:
                     del result_json[field]
+                if field in benchmark_result:
+                    del benchmark_result[field]

         # Save to file
         base_model_id = model_id.split("/")[-1]
@@ -850,6 +862,7 @@ def main(args: argparse.Namespace):
         if args.result_filename:
             file_name = args.result_filename
         if args.result_dir:
+            os.makedirs(args.result_dir, exist_ok=True)
             file_name = os.path.join(args.result_dir, file_name)
         with open(
             file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
@@ -890,7 +903,7 @@ if __name__ == "__main__":
"--dataset-name",
type=str,
default="sharegpt",
- choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
+ choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument(
@@ -1060,6 +1073,19 @@ if __name__ == "__main__":
     )
     # group for dataset specific arguments
+    custom_group = parser.add_argument_group("custom dataset options")
+    custom_group.add_argument(
+        "--custom-output-len",
+        type=int,
+        default=256,
+        help="Number of output tokens per request, used only for custom dataset.",
+    )
+    custom_group.add_argument(
+        "--custom-skip-chat-template",
+        action="store_true",
+        help="Skip applying chat template to prompt, used only for custom dataset.",
+    )
+
     sonnet_group = parser.add_argument_group("sonnet dataset options")
     sonnet_group.add_argument(
         "--sonnet-input-len",
diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index 712e83528f122..35cc303f60eeb 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -9,9 +9,6 @@ generation. Supported dataset types include:
- BurstGPT
- HuggingFace
- VisionArena
-
-TODO: Implement CustomDataset to parse a JSON file and convert its contents into
-SampleRequest instances, similar to the approach used in ShareGPT.
"""
import base64
import io
@@ -26,6 +23,7 @@ from io import BytesIO
from typing import Any, Callable, Optional, Union
import numpy as np
+import pandas as pd
from PIL import Image
from transformers import PreTrainedTokenizerBase
@@ -443,6 +441,99 @@ class ShareGPTDataset(BenchmarkDataset):
         return samples
+# -----------------------------------------------------------------------------
+# Custom Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class CustomDataset(BenchmarkDataset):
+    """
+    Implements the Custom dataset. Loads data from a JSONL file and generates
+    sample requests from its "prompt" entries, e.g.,
+    ```
+    {"prompt": "What is the capital of India?"}
+    {"prompt": "What is the capital of Iran?"}
+    {"prompt": "What is the capital of China?"}
+    ```
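+
+    A minimal usage sketch (`tokenizer` here stands for any loaded
+    `PreTrainedTokenizerBase`):
+    ```
+    dataset = CustomDataset(dataset_path="data.jsonl")
+    requests = dataset.sample(tokenizer=tokenizer, num_requests=80)
+    ```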
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.load_data()
+
+ def load_data(self) -> None:
+ if self.dataset_path is None:
+ raise ValueError("dataset_path must be provided for loading data.")
+
+ # self.data will be a list of dictionaries
+ # e.g., [{"prompt": "What is the capital of India?"}, ...]
+ # This will be the standardized format which load_data()
+ # has to convert into depending on the filetype of dataset_path.
+ # sample() will assume this standardized format of self.data
+ self.data = []
+
+ # Load the JSONL file
+ if self.dataset_path.endswith(".jsonl"):
+ jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
+ lines=True)
+
+ # check if the JSONL file has a 'prompt' column
+ if "prompt" not in jsonl_data.columns:
+ raise ValueError("JSONL file must contain a 'prompt' column.")
+
+ # Convert each row to a dictionary and append to self.data
+ # This will convert the DataFrame to a list of dictionaries
+ # where each dictionary corresponds to a row in the DataFrame.
+ # This is the standardized format we want for self.data
+ for _, row in jsonl_data.iterrows():
+ self.data.append(row.to_dict())
+ else:
+ raise NotImplementedError(
+ "Only JSONL format is supported for CustomDataset.")
+
+ random.seed(self.random_seed)
+ random.shuffle(self.data)
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ lora_path: Optional[str] = None,
+ max_loras: Optional[int] = None,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ skip_chat_template: bool = False,
+ **kwargs,
+ ) -> list:
+ sampled_requests = []
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+ prompt = item["prompt"]
+
+ # apply template
+ if not skip_chat_template:
+ prompt = tokenizer.apply_chat_template(
+ [{
+ "role": "user",
+ "content": prompt
+ }],
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+
+ prompt_len = len(tokenizer(prompt).input_ids)
+ sampled_requests.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ ))
+ self.maybe_oversample_requests(sampled_requests, num_requests)
+
+ return sampled_requests
+
+
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py
index 040815e879f0c..858a0c6a00e4b 100644
--- a/vllm/benchmarks/serve.py
+++ b/vllm/benchmarks/serve.py
@@ -1110,6 +1110,8 @@ def main(args: argparse.Namespace):
             ]:
                 if field in result_json:
                     del result_json[field]
+                if field in benchmark_result:
+                    del benchmark_result[field]

         # Save to file
         base_model_id = model_id.split("/")[-1]
@@ -1120,6 +1122,7 @@ def main(args: argparse.Namespace):
         if args.result_filename:
             file_name = args.result_filename
         if args.result_dir:
+            os.makedirs(args.result_dir, exist_ok=True)
             file_name = os.path.join(args.result_dir, file_name)
         with open(file_name,
                   mode="a+" if args.append_result else "w",