diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index b81c2f8192db7..98d3360cd6ffa 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -887,6 +887,94 @@ class AIMODataset(HuggingFaceDataset): return sampled_requests +# ----------------------------------------------------------------------------- +# Next Edit Prediction Dataset Implementation +# ----------------------------------------------------------------------------- + + +zeta_prompt = """### Instruction: +You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. + +### User Edits: + +{} + +### User Excerpt: + +{} + +### Response: + +""" # noqa: E501 + + +def _format_zeta_prompt( + sample: dict, + original_start_marker: str = "<|editable_region_start|>") -> dict: + """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be + further extended to support more NEP datasets. + + Args: + sample: The dataset sample containing events, + inputs, and outputs. + original_start_marker: The marker indicating the + start of the editable region. Defaults to + "<|editable_region_start|>". + + Returns: + A dictionary with the formatted prompts and expected outputs. + """ + events = sample["events"] + input = sample["input"] + output = sample["output"] + prompt = zeta_prompt.format(events, input) + + # following the original implementation, extract the focused region + # from the raw output + output_start_index = output.find(original_start_marker) + output_focused_region = output[output_start_index:] + expected_output = output_focused_region + + return {"prompt": prompt, "expected_output": expected_output} + + +class NextEditPredictionDataset(HuggingFaceDataset): + """ + Dataset class for processing a Next Edit Prediction dataset. + """ + + SUPPORTED_DATASET_PATHS = { + "zed-industries/zeta", + } + MAPPING_PROMPT_FUNCS = { + "zed-industries/zeta": _format_zeta_prompt, + } + + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( + self.dataset_path) + if formatting_prompt_func is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + samples = [] + for sample in self.data: + sample = formatting_prompt_func(sample) + samples.append( + SampleRequest( + prompt=sample["prompt"], + prompt_len=len(tokenizer(sample["prompt"]).input_ids), + expected_output_len=len( + tokenizer(sample["expected_output"]).input_ids), + )) + if len(samples) >= num_requests: + break + self.maybe_oversample_requests(samples, num_requests) + return samples + + # ----------------------------------------------------------------------------- # ASR Dataset Implementation # ----------------------------------------------------------------------------- diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index c236d64261d0d..89fb0e1df0353 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -53,8 +53,9 @@ except ImportError: from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, ConversationDataset, HuggingFaceDataset, InstructCoderDataset, MTBenchDataset, - RandomDataset, SampleRequest, ShareGPTDataset, - SonnetDataset, VisionArenaDataset) + NextEditPredictionDataset, RandomDataset, + SampleRequest, ShareGPTDataset, SonnetDataset, + VisionArenaDataset) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -603,6 +604,9 @@ def main(args: argparse.Namespace): elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_class = AIMODataset args.hf_split = "train" + elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + dataset_class = NextEditPredictionDataset + args.hf_split = "train" elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: dataset_class = ASRDataset args.hf_split = "train" diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 299c888c2e7b9..fab44fb6062d1 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -829,3 +829,91 @@ class AIMODataset(HuggingFaceDataset): )) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests + + +# ----------------------------------------------------------------------------- +# Next Edit Prediction Dataset Implementation +# ----------------------------------------------------------------------------- + + +zeta_prompt = """### Instruction: +You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. + +### User Edits: + +{} + +### User Excerpt: + +{} + +### Response: + +""" # noqa: E501 + + +def _format_zeta_prompt( + sample: dict, + original_start_marker: str = "<|editable_region_start|>") -> dict: + """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be + further extended to support more NEP datasets. + + Args: + sample: The dataset sample containing events, + inputs, and outputs. + original_start_marker: The marker indicating the + start of the editable region. Defaults to + "<|editable_region_start|>". + + Returns: + A dictionary with the formatted prompts and expected outputs. + """ + events = sample["events"] + input = sample["input"] + output = sample["output"] + prompt = zeta_prompt.format(events, input) + + # following the original implementation, extract the focused region + # from the raw output + output_start_index = output.find(original_start_marker) + output_focused_region = output[output_start_index:] + expected_output = output_focused_region + + return {"prompt": prompt, "expected_output": expected_output} + + +class NextEditPredictionDataset(HuggingFaceDataset): + """ + Dataset class for processing a Next Edit Prediction dataset. + """ + + SUPPORTED_DATASET_PATHS = { + "zed-industries/zeta", + } + MAPPING_PROMPT_FUNCS = { + "zed-industries/zeta": _format_zeta_prompt, + } + + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( + self.dataset_path) + if formatting_prompt_func is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + samples = [] + for sample in self.data: + sample = formatting_prompt_func(sample) + samples.append( + SampleRequest( + prompt=sample["prompt"], + prompt_len=len(tokenizer(sample["prompt"]).input_ids), + expected_output_len=len( + tokenizer(sample["expected_output"]).input_ids), + )) + if len(samples) >= num_requests: + break + self.maybe_oversample_requests(samples, num_requests) + return samples