Add benchmark dataset for mlperf llama tasks (#20338)
Signed-off-by: mgoin <mgoin64@gmail.com>
commit 8bb43b9c9e
parent 559756214b
@@ -654,6 +654,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
        elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
            dataset_class = ASRDataset
            args.hf_split = "train"
        elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS:
            dataset_class = MLPerfDataset
            args.hf_split = "train"
        else:
            supported_datasets = set([
                dataset_name for cls in HuggingFaceDataset.__subclasses__()
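The hunk above registers the new class with the existing dispatch in get_samples: each HuggingFaceDataset subclass advertises the Hugging Face dataset paths it handles through SUPPORTED_DATASET_PATHS, and an unrecognized path falls through to an error listing every supported path. A minimal standalone sketch of that convention follows; the Toy* class names and the non-MLPerf example path are illustrative only, not the benchmark's own code.

# Standalone sketch of the SUPPORTED_DATASET_PATHS dispatch convention.
# Toy* names are illustrative; only the MLPerf paths come from the commit.
class ToyHFDataset:
    SUPPORTED_DATASET_PATHS: set[str] = set()


class ToyMLPerfDataset(ToyHFDataset):
    SUPPORTED_DATASET_PATHS = {
        "mgoin/mlperf-inference-llama2-data",
        "mgoin/mlperf-inference-llama3.1-data",
    }


class ToyOtherDataset(ToyHFDataset):
    SUPPORTED_DATASET_PATHS = {"some-org/some-other-dataset"}


def pick_dataset_class(dataset_path: str) -> type[ToyHFDataset]:
    # Walk the registered subclasses and match on the advertised paths.
    for cls in ToyHFDataset.__subclasses__():
        if dataset_path in cls.SUPPORTED_DATASET_PATHS:
            return cls
    supported = {
        name
        for cls in ToyHFDataset.__subclasses__()
        for name in cls.SUPPORTED_DATASET_PATHS
    }
    raise ValueError(f"Unsupported dataset path. Supported: {sorted(supported)}")


print(pick_dataset_class("mgoin/mlperf-inference-llama2-data").__name__)
# -> ToyMLPerfDataset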
@@ -1447,3 +1450,82 @@ class ASRDataset(HuggingFaceDataset):
        )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests


# -----------------------------------------------------------------------------
# MLPerf Dataset Implementation
# -----------------------------------------------------------------------------


class MLPerfDataset(HuggingFaceDataset):
    """
    MLPerf Inference Dataset.

    Dataset on HF:
    https://huggingface.co/datasets/mgoin/mlperf-inference-llama2-data
    https://huggingface.co/datasets/mgoin/mlperf-inference-llama3.1-data

    Each record contains:
      - "system_prompt": system role instruction.
      - "question": user question.
      - "output": reference answer.

    We combine the system prompt and question into a chat-formatted prompt
    (using the tokenizer's chat template) and set the expected output length to
    the tokenized length of the provided reference answer.
    """

    SUPPORTED_DATASET_PATHS = {
        "mgoin/mlperf-inference-llama2-data",
        "mgoin/mlperf-inference-llama3.1-data",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        **kwargs,
    ) -> list[SampleRequest]:
        # Force dynamic output length based on reference completion.
        dynamic_output = output_len is None
        sampled_requests: list[SampleRequest] = []

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break

            system_prompt = item["system_prompt"]
            question = item["question"]
            reference_answer = item["output"]

            # Build chat-style prompt using tokenizer template, if available.
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ]
            prompt_formatted = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=False
            )
            prompt_len = len(tokenizer(prompt_formatted).input_ids)

            # Determine output length from reference answer tokens.
            ref_out_len = len(
                tokenizer(reference_answer, add_special_tokens=False).input_ids
            )
            expected_output_len = ref_out_len if dynamic_output else output_len

            # Validate sequence lengths.
            if not is_valid_sequence(prompt_len, expected_output_len):
                continue

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt_formatted,
                    prompt_len=prompt_len,
                    expected_output_len=expected_output_len,
                )
            )

        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests
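For reference, here is a minimal standalone sketch of the per-record construction described in the class docstring: the system prompt and question are rendered through the tokenizer's chat template, and the expected output length is taken from the token count of the reference answer. The tokenizer name and the record contents below are placeholders, not data from the commit; any tokenizer that ships a chat template will do.

# Sketch of the prompt/output-length construction used by MLPerfDataset.sample.
# The model name and record values are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

record = {
    "system_prompt": "You are a helpful assistant.",
    "question": "Summarize the plot of Hamlet in two sentences.",
    "output": "Prince Hamlet feigns madness while plotting revenge ...",
}

messages = [
    {"role": "system", "content": record["system_prompt"]},
    {"role": "user", "content": record["question"]},
]
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
prompt_len = len(tokenizer(prompt).input_ids)
# The expected output length is the reference answer's token count.
expected_output_len = len(
    tokenizer(record["output"], add_special_tokens=False).input_ids
)
print(prompt_len, expected_output_len)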