mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 23:05:59 +08:00
[Misc] Add Next Edit Prediction (NEP) datasets support in benchmark_serving.py (#16839)
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal> Signed-off-by: dtransposed <> Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
This commit is contained in:
parent
621ca2c0ab
commit
d456aea71f
@ -887,6 +887,94 @@ class AIMODataset(HuggingFaceDataset):
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Next Edit Prediction Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
zeta_prompt = """### Instruction:
|
||||
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
|
||||
|
||||
### User Edits:
|
||||
|
||||
{}
|
||||
|
||||
### User Excerpt:
|
||||
|
||||
{}
|
||||
|
||||
### Response:
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
def _format_zeta_prompt(
|
||||
sample: dict,
|
||||
original_start_marker: str = "<|editable_region_start|>") -> dict:
|
||||
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
|
||||
|
||||
This function formats examples from the NEP dataset
|
||||
into prompts and expected outputs. It could be
|
||||
further extended to support more NEP datasets.
|
||||
|
||||
Args:
|
||||
sample: The dataset sample containing events,
|
||||
inputs, and outputs.
|
||||
original_start_marker: The marker indicating the
|
||||
start of the editable region. Defaults to
|
||||
"<|editable_region_start|>".
|
||||
|
||||
Returns:
|
||||
A dictionary with the formatted prompts and expected outputs.
|
||||
"""
|
||||
events = sample["events"]
|
||||
input = sample["input"]
|
||||
output = sample["output"]
|
||||
prompt = zeta_prompt.format(events, input)
|
||||
|
||||
# following the original implementation, extract the focused region
|
||||
# from the raw output
|
||||
output_start_index = output.find(original_start_marker)
|
||||
output_focused_region = output[output_start_index:]
|
||||
expected_output = output_focused_region
|
||||
|
||||
return {"prompt": prompt, "expected_output": expected_output}
|
||||
|
||||
|
||||
class NextEditPredictionDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Dataset class for processing a Next Edit Prediction dataset.
|
||||
"""
|
||||
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"zed-industries/zeta",
|
||||
}
|
||||
MAPPING_PROMPT_FUNCS = {
|
||||
"zed-industries/zeta": _format_zeta_prompt,
|
||||
}
|
||||
|
||||
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
|
||||
**kwargs):
|
||||
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
|
||||
self.dataset_path)
|
||||
if formatting_prompt_func is None:
|
||||
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
|
||||
samples = []
|
||||
for sample in self.data:
|
||||
sample = formatting_prompt_func(sample)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=sample["prompt"],
|
||||
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
|
||||
expected_output_len=len(
|
||||
tokenizer(sample["expected_output"]).input_ids),
|
||||
))
|
||||
if len(samples) >= num_requests:
|
||||
break
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ASR Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@ -53,8 +53,9 @@ except ImportError:
|
||||
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
|
||||
ConversationDataset, HuggingFaceDataset,
|
||||
InstructCoderDataset, MTBenchDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
NextEditPredictionDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
@ -603,6 +604,9 @@ def main(args: argparse.Namespace):
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = AIMODataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501
|
||||
dataset_class = NextEditPredictionDataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ASRDataset
|
||||
args.hf_split = "train"
|
||||
|
||||
@ -829,3 +829,91 @@ class AIMODataset(HuggingFaceDataset):
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Next Edit Prediction Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
zeta_prompt = """### Instruction:
|
||||
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
|
||||
|
||||
### User Edits:
|
||||
|
||||
{}
|
||||
|
||||
### User Excerpt:
|
||||
|
||||
{}
|
||||
|
||||
### Response:
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
def _format_zeta_prompt(
|
||||
sample: dict,
|
||||
original_start_marker: str = "<|editable_region_start|>") -> dict:
|
||||
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
|
||||
|
||||
This function formats examples from the NEP dataset
|
||||
into prompts and expected outputs. It could be
|
||||
further extended to support more NEP datasets.
|
||||
|
||||
Args:
|
||||
sample: The dataset sample containing events,
|
||||
inputs, and outputs.
|
||||
original_start_marker: The marker indicating the
|
||||
start of the editable region. Defaults to
|
||||
"<|editable_region_start|>".
|
||||
|
||||
Returns:
|
||||
A dictionary with the formatted prompts and expected outputs.
|
||||
"""
|
||||
events = sample["events"]
|
||||
input = sample["input"]
|
||||
output = sample["output"]
|
||||
prompt = zeta_prompt.format(events, input)
|
||||
|
||||
# following the original implementation, extract the focused region
|
||||
# from the raw output
|
||||
output_start_index = output.find(original_start_marker)
|
||||
output_focused_region = output[output_start_index:]
|
||||
expected_output = output_focused_region
|
||||
|
||||
return {"prompt": prompt, "expected_output": expected_output}
|
||||
|
||||
|
||||
class NextEditPredictionDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Dataset class for processing a Next Edit Prediction dataset.
|
||||
"""
|
||||
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"zed-industries/zeta",
|
||||
}
|
||||
MAPPING_PROMPT_FUNCS = {
|
||||
"zed-industries/zeta": _format_zeta_prompt,
|
||||
}
|
||||
|
||||
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
|
||||
**kwargs):
|
||||
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
|
||||
self.dataset_path)
|
||||
if formatting_prompt_func is None:
|
||||
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
|
||||
samples = []
|
||||
for sample in self.data:
|
||||
sample = formatting_prompt_func(sample)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=sample["prompt"],
|
||||
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
|
||||
expected_output_len=len(
|
||||
tokenizer(sample["expected_output"]).input_ids),
|
||||
))
|
||||
if len(samples) >= num_requests:
|
||||
break
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user