From 17bccecb1c45a5859e68d42b0f6d6184ed3dfe03 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 29 Jun 2025 22:30:12 -0700 Subject: [PATCH] add mtbench dataste --- benchmarks/benchmark_dataset.py | 122 +++++++++++++++++++++++++++++ benchmarks/benchmark_throughput.py | 18 ++++- 2 files changed, 137 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 0fdd0f5e4d8f4..6c08dff3c8aa7 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -405,6 +405,13 @@ class ShareGPTDataset(BenchmarkDataset): entry["conversations"][1]["value"], ) + prompt = tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False) + lora_request, tokenizer = self.get_random_lora_request( tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) prompt_ids = tokenizer(prompt).input_ids @@ -763,6 +770,14 @@ class InstructCoderDataset(HuggingFaceDataset): if len(sampled_requests) >= num_requests: break prompt = f"{item['instruction']}:\n{item['input']}" + + prompt = tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False) + prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( SampleRequest( @@ -801,6 +816,13 @@ class AIMODataset(HuggingFaceDataset): break prompt, completion = item['problem'], item["solution"] + prompt = tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False) + prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids prompt_len = len(prompt_ids) @@ -898,3 +920,103 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests + + +class MTBenchDataset(HuggingFaceDataset): + """ + MT-Bench Dataset. 
+    https://huggingface.co/datasets/philschmid/mt-bench
+
+    We create a single turn dataset for MT-Bench.
+    This is similar to Spec decoding benchmark setup in vLLM
+    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
+    """  # noqa: E501
+
+    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
+    SUPPORTED_DATASET_PATHS = {
+        "philschmid/mt-bench",
+    }
+
+    def sample(self,
+               tokenizer: PreTrainedTokenizerBase,
+               num_requests: int,
+               output_len: Optional[int] = None,
+               enable_multimodal_chat: bool = False,
+               **kwargs) -> list:
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        sampled_requests = []
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            prompt = item['turns'][0]
+
+            # apply template
+            prompt = tokenizer.apply_chat_template([{
+                "role": "user",
+                "content": prompt
+            }],
+                                                   add_generation_prompt=True,
+                                                   tokenize=False)
+
+            prompt_len = len(tokenizer(prompt).input_ids)
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
+
+
+class CNNDailyMailDataset(HuggingFaceDataset):
+    """
+    CNN/DailyMail Dataset.
+    https://huggingface.co/datasets/abisee/cnn_dailymail
+
+    We create a single-turn summarization benchmark from the articles.
+ This is similar to Spec decoding benchmark setup in vLLM + https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 + """ # noqa: E501 + + DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM + SUPPORTED_DATASET_PATHS = { + "abisee/cnn_dailymail", + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + instruction = "Could you summarize the following article, " \ + "please reuse text from the article if possible: " + prompt = instruction + item['article'] + + # apply template + prompt = tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 19528e417a4a0..02078871aee85 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -12,7 +12,8 @@ from typing import Any, Optional, Union import torch import uvloop from benchmark_dataset import (AIMODataset, BurstGPTDataset, - ConversationDataset, InstructCoderDataset, + CNNDailyMailDataset, ConversationDataset, + InstructCoderDataset, MTBenchDataset, RandomDataset, SampleRequest, ShareGPTDataset, SonnetDataset, VisionArenaDataset) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json @@ -339,6 +340,14 @@ def get_requests(args, tokenizer): dataset_cls = AIMODataset 
common_kwargs['dataset_subset'] = None common_kwargs['dataset_split'] = "train" + elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = MTBenchDataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + elif args.dataset_path in CNNDailyMailDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = CNNDailyMailDataset + common_kwargs['dataset_subset'] = '3.0.0' + common_kwargs['dataset_split'] = "train" else: raise ValueError(f"Unknown dataset name: {args.dataset_name}") # Remove None values @@ -477,8 +486,11 @@ def validate_args(args): VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() | ConversationDataset.SUPPORTED_DATASET_PATHS): assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 - elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS - | AIMODataset.SUPPORTED_DATASET_PATHS): + elif args.dataset_path in ( + InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS + | MTBenchDataset.SUPPORTED_DATASET_PATHS + | CNNDailyMailDataset.SUPPORTED_DATASET_PATHS): assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 else: raise ValueError(