[Spec Decode][Benchmark] Add Blazedit dataset (#23605)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent 3feeeb9fea
commit cd08636926
@@ -1101,6 +1101,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
        "from the ShareGPT dataset.",
    )

    blazedit_group = parser.add_argument_group("blazedit dataset options")
    blazedit_group.add_argument(
        "--blazedit-min-distance",
        type=float,
        default=0.0,
        help="Minimum distance for blazedit dataset. Min: 0, Max: 1.0",
    )
    blazedit_group.add_argument(
        "--blazedit-max-distance",
        type=float,
        default=1.0,
        help="Maximum distance for blazedit dataset. Min: 0, Max: 1.0",
    )

    random_group = parser.add_argument_group("random dataset options")
    random_group.add_argument(
        "--random-input-len",
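Not part of the commit: a minimal sketch of what the two new flags mean. They bound an inclusive band [--blazedit-min-distance, --blazedit-max-distance] over each sample's precomputed norm_distance field; the actual filter lives in BlazeditDataset.sample further down in this diff, and the helper name here is hypothetical.

def in_distance_band(norm_distance: float, lo: float = 0.0, hi: float = 1.0) -> bool:
    # Keep a sample only if its normalized edit distance falls in [lo, hi].
    return lo <= norm_distance <= hi

assert in_distance_band(0.35, lo=0.2, hi=0.5)      # moderate edit: kept
assert not in_distance_band(0.05, lo=0.2, hi=0.5)  # near-identical files: skipped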
@@ -1333,6 +1349,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
    elif args.dataset_name == "hf":
        # all following datasets are implemented from the
        # HuggingFaceDataset base class
        hf_kwargs = {}
        if (
            args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
@@ -1376,6 +1393,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
        ):
            dataset_class = ASRDataset
            args.hf_split = "train"
        elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS:
            dataset_class = BlazeditDataset
            args.hf_split = "train"
            hf_kwargs = {
                "min_distance": args.blazedit_min_distance,
                "max_distance": args.blazedit_max_distance,
            }
        elif (
            args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
@@ -1415,6 +1439,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
            tokenizer=tokenizer,
            output_len=args.hf_output_len,
            request_id_prefix=args.request_id_prefix,
            **hf_kwargs
        )

    else:
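Not part of the commit: a self-contained sketch of the dispatch pattern the three get_samples hunks above implement. Dataset-specific CLI options are collected into a plain dict (hf_kwargs) and forwarded to the selected dataset's sample(); classes whose sample() does not declare those parameters absorb them via **kwargs. The class names below are illustrative stand-ins, not the benchmark's real classes.

class DatasetA:
    def sample(self, num_requests: int, **kwargs) -> list:
        # Takes no extra options; any forwarded kwargs are ignored.
        return [f"A-{i}" for i in range(num_requests)]


class DatasetB:
    def sample(self, num_requests: int, min_distance: float = 0.0,
               max_distance: float = 1.0, **kwargs) -> list:
        # Consumes the dataset-specific options passed through hf_kwargs.
        return [f"B-{i}-[{min_distance},{max_distance}]"
                for i in range(num_requests)]


def get_samples(dataset_name: str, num_requests: int) -> list:
    hf_kwargs = {}  # stays empty for datasets that take no extra options
    if dataset_name == "blazedit-like":
        dataset_class = DatasetB
        hf_kwargs = {"min_distance": 0.2, "max_distance": 0.8}
    else:
        dataset_class = DatasetA
    return dataset_class().sample(num_requests, **hf_kwargs)


print(get_samples("blazedit-like", 2))  # ['B-0-[0.2,0.8]', 'B-1-[0.2,0.8]']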
@@ -2090,6 +2115,94 @@ class MTBenchDataset(HuggingFaceDataset):
        return sampled_requests


# -----------------------------------------------------------------------------
# Blazedit Dataset Implementation
# -----------------------------------------------------------------------------


class BlazeditDataset(HuggingFaceDataset):
    """
    Blazedit Dataset.
    https://github.com/ise-uiuc/blazedit

    5k char version: vdaita/edit_5k_char
    10k char version: vdaita/edit_10k_char
    """  # noqa: E501

    # The 5k char version has outputs of ~5k chars and the 10k char version
    # of ~10k chars. Assuming ~3 chars per token, 10k chars is ~3333 tokens,
    # so the default is set to 4000 to be safe.
    DEFAULT_OUTPUT_LEN = 4000
    SUPPORTED_DATASET_PATHS = {
        "vdaita/edit_5k_char",
        "vdaita/edit_10k_char",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        min_distance: float = 0.0,
        max_distance: float = 1.0,
        **kwargs,
    ) -> list:
        output_len = (output_len
                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
        sampled_requests = []

        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            code = item["code"]
            change_request = item["change_request"]
            norm_distance = item["norm_distance"]

            # filter by the Levenshtein distance normalized by code length
            if norm_distance < min_distance or norm_distance > max_distance:
                continue

            # template copied from
            # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105  # noqa: E501
            instruction = f"""Given a code file, please apply the change requests and generate the new file.

Original file:
```python
{code}
```

Change request:
{change_request}

Please generate the new code file in the "New file" section below."""  # noqa: E501

            # apply the chat template
            prompt = tokenizer.apply_chat_template(
                [{
                    "role": "user",
                    "content": instruction
                }],
                add_generation_prompt=True,
                tokenize=False,
            )

            prompt_len = len(tokenizer(prompt).input_ids)

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(i),
                ))
        self.maybe_oversample_requests(sampled_requests, num_requests,
                                       request_id_prefix)

        return sampled_requests


# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------
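Not part of the commit: the filter in BlazeditDataset.sample above relies on the dataset's precomputed norm_distance column, described as the Levenshtein distance normalized by code length. A minimal sketch of that quantity, assuming normalization by the original file's length (the dataset's exact normalization may differ):

def levenshtein(a: str, b: str) -> int:
    # Classic DP edit distance: insertions, deletions, substitutions cost 1.
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution
        prev = cur
    return prev[-1]


def norm_distance(old_code: str, new_code: str) -> float:
    # Normalize by the original file length, so 0.0 means unchanged and
    # values near (or above) 1.0 mean the file was largely rewritten.
    return levenshtein(old_code, new_code) / max(len(old_code), 1)


print(norm_distance("print('hi')", "print('hello')"))  # small edit -> small value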