add mtbench dataset

LiuXiaoxuanPKU 2025-06-29 22:30:12 -07:00
parent c335930d75
commit 17bccecb1c
2 changed files with 137 additions and 3 deletions

View File

@@ -405,6 +405,13 @@ class ShareGPTDataset(BenchmarkDataset):
entry["conversations"][1]["value"],
)
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
lora_request, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
prompt_ids = tokenizer(prompt).input_ids
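For reference, the apply_chat_template call that this hunk (and each one below) adds before tokenization is the standard Hugging Face tokenizer API. A minimal standalone sketch, not part of the diff; the model name is only an example, and any tokenizer that ships a chat template works:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
text = tok.apply_chat_template(
    [{"role": "user", "content": "What is 2 + 2?"}],
    add_generation_prompt=True,  # append the assistant header so generation starts a fresh reply
    tokenize=False)              # return the formatted string, not token ids
print(text)  # the prompt wrapped in the model's chat markup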
@@ -763,6 +770,14 @@ class InstructCoderDataset(HuggingFaceDataset):
if len(sampled_requests) >= num_requests:
break
prompt = f"{item['instruction']}:\n{item['input']}"
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
@@ -801,6 +816,13 @@ class AIMODataset(HuggingFaceDataset):
break
prompt, completion = item['problem'], item["solution"]
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
@@ -898,3 +920,103 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.", skipped)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
class MTBenchDataset(HuggingFaceDataset):
"""
MT-Bench Dataset.
https://huggingface.co/datasets/philschmid/mt-bench
We build a single-turn dataset from the first turn of each MT-Bench conversation.
This mirrors the spec decoding benchmark setup in vLLM:
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
""" # noqa: E501
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
SUPPORTED_DATASET_PATHS = {
"philschmid/mt-bench",
}
def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = item['turns'][0]
# apply template
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
class CNNDailyMailDataset(HuggingFaceDataset):
"""
CNN/DailyMail Dataset.
https://huggingface.co/datasets/abisee/cnn_dailymail
Each request asks the model to summarize one news article
(single-turn summarization prompts).
""" # noqa: E501
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
SUPPORTED_DATASET_PATHS = {
"abisee/cnn_dailymail",
}
def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
instruction = "Could you summarize the following article, " \
"please reuse text from the article if possible: "
prompt = instruction + item['article']
# apply template
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
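A hypothetical usage sketch for the new class, not part of the commit: it mirrors how get_requests in the second file builds these datasets, and assumes the HuggingFaceDataset constructor accepts the dataset_path / dataset_subset / dataset_split keywords that common_kwargs carries there.

from benchmark_dataset import MTBenchDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Assumed constructor kwargs, matching the common_kwargs set in get_requests().
dataset = MTBenchDataset(dataset_path="philschmid/mt-bench",
                         dataset_subset=None,
                         dataset_split="train")
requests = dataset.sample(tokenizer=tokenizer, num_requests=8)
for req in requests:
    # Each SampleRequest carries the templated prompt and the fixed
    # 256-token expected output length (DEFAULT_OUTPUT_LEN).
    print(req.prompt_len, req.expected_output_len)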

View File

@@ -12,7 +12,8 @@ from typing import Any, Optional, Union
import torch
import uvloop
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
ConversationDataset, InstructCoderDataset,
CNNDailyMailDataset, ConversationDataset,
InstructCoderDataset, MTBenchDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
@@ -339,6 +340,14 @@ def get_requests(args, tokenizer):
dataset_cls = AIMODataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = MTBenchDataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
elif args.dataset_path in CNNDailyMailDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = CNNDailyMailDataset
common_kwargs['dataset_subset'] = '3.0.0'
common_kwargs['dataset_split'] = "train"
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
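The "# Remove None values" context line is where this hunk ends; the existing tail of get_requests is not shown. A sketch of the likely continuation, under the assumption that it filters unset kwargs and instantiates the selected class (sample_kwargs here is hypothetical):

# Hypothetical continuation -- the actual code lies outside this hunk.
common_kwargs = {k: v for k, v in common_kwargs.items() if v is not None}
requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)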
@@ -477,8 +486,11 @@ def validate_args(args):
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
| ConversationDataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS):
elif args.dataset_path in (
InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS
| MTBenchDataset.SUPPORTED_DATASET_PATHS
| CNNDailyMailDataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
else:
raise ValueError(