[Spec Decode][Benchmark] Add Blazedit dataset (#23605)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Ekagra Ranjan 2025-09-08 13:32:52 -04:00 committed by GitHub
parent 3feeeb9fea
commit cd08636926

@@ -1101,6 +1101,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
        "from the ShareGPT dataset.",
    )

    blazedit_group = parser.add_argument_group("blazedit dataset options")
    blazedit_group.add_argument(
        "--blazedit-min-distance",
        type=float,
        default=0.0,
        help="Minimum normalized edit distance for the blazedit dataset. "
        "Min: 0.0, Max: 1.0",
    )
    blazedit_group.add_argument(
        "--blazedit-max-distance",
        type=float,
        default=1.0,
        help="Maximum normalized edit distance for the blazedit dataset. "
        "Min: 0.0, Max: 1.0",
    )

    random_group = parser.add_argument_group("random dataset options")
    random_group.add_argument(
        "--random-input-len",
@@ -1333,6 +1349,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
    elif args.dataset_name == "hf":
        # all following datasets are implemented from the
        # HuggingFaceDataset base class
        hf_kwargs = {}
        if (
            args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
@@ -1376,6 +1393,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
        ):
            dataset_class = ASRDataset
            args.hf_split = "train"
        elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS:
            dataset_class = BlazeditDataset
            args.hf_split = "train"
            hf_kwargs = {
                "min_distance": args.blazedit_min_distance,
                "max_distance": args.blazedit_max_distance,
            }
        elif (
            args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS
            or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
@@ -1415,6 +1439,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
            tokenizer=tokenizer,
            output_len=args.hf_output_len,
            request_id_prefix=args.request_id_prefix,
            **hf_kwargs,
        )
    else:
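
The empty `hf_kwargs` dict initialized in the first hunk is what keeps the shared `sample()` call site dataset-agnostic: only datasets that need extra knobs populate it. A minimal sketch of that dispatch pattern, with hypothetical class names standing in for the real dataset classes:

```python
# Hypothetical stand-ins for HuggingFaceDataset subclasses; only the
# kwargs-forwarding pattern matches the diff above.
class BaseDataset:
    def sample(self, num_requests: int, **kwargs) -> list:
        return [f"request-{i}" for i in range(num_requests)]


class FilteredDataset(BaseDataset):
    def sample(self, num_requests: int, min_distance: float = 0.0,
               max_distance: float = 1.0, **kwargs) -> list:
        # Dataset-specific knobs arrive through the shared **kwargs channel.
        assert 0.0 <= min_distance <= max_distance <= 1.0
        return super().sample(num_requests)


# Datasets that need extra knobs populate hf_kwargs; everyone else leaves it
# empty, so the single sample() call below serves all of them.
dataset = FilteredDataset()
hf_kwargs = {"min_distance": 0.2, "max_distance": 0.8}
requests = dataset.sample(num_requests=4, **hf_kwargs)
```
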
@@ -2090,6 +2115,94 @@ class MTBenchDataset(HuggingFaceDataset):
        return sampled_requests


# -----------------------------------------------------------------------------
# Blazedit Dataset Implementation
# -----------------------------------------------------------------------------


class BlazeditDataset(HuggingFaceDataset):
    """
    Blazedit Dataset.
    https://github.com/ise-uiuc/blazedit

    5k char version: vdaita/edit_5k_char
    10k char version: vdaita/edit_10k_char
    """  # noqa: E501

    # The 5k char version has outputs of ~5k chars; the 10k char version,
    # ~10k chars. Assuming ~3 chars per token, 10k chars is ~3333 tokens,
    # so we set the default to 4000 to be safe.
    DEFAULT_OUTPUT_LEN = 4000
    SUPPORTED_DATASET_PATHS = {
        "vdaita/edit_5k_char",
        "vdaita/edit_10k_char",
    }
    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        min_distance: float = 0.0,
        max_distance: float = 1.0,
        **kwargs,
    ) -> list:
        output_len = (output_len
                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break

            code = item["code"]
            change_request = item["change_request"]
            norm_distance = item["norm_distance"]

            # Skip samples whose Levenshtein distance (normalized by code
            # length) falls outside the requested band.
            if norm_distance < min_distance or norm_distance > max_distance:
                continue
            # template copied from
            # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
            instruction = f"""Given a code file, please apply the change requests and generate the new file.

Original file:
```python
{code}
```

Change request:
{change_request}

Please generate the new code file in the "New file" section below."""  # noqa: E501
            # apply template
            prompt = tokenizer.apply_chat_template(
                [{
                    "role": "user",
                    "content": instruction
                }],
                add_generation_prompt=True,
                tokenize=False,
            )

            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(i),
                ))

        self.maybe_oversample_requests(sampled_requests, num_requests,
                                       request_id_prefix)
        return sampled_requests
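
For intuition about the filter above: per the in-code comment, `norm_distance` is the Levenshtein edit distance between the original and edited file, normalized by code length, and the benchmark reads it as a precomputed dataset field. A self-contained sketch of that criterion (normalizing by the original file's length is an assumption here, not confirmed by the commit):

```python
def levenshtein(a: str, b: str) -> int:
    # Classic O(len(a) * len(b)) dynamic-programming edit distance.
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,                # deletion
                           cur[j - 1] + 1,             # insertion
                           prev[j - 1] + (ca != cb)))  # substitution
        prev = cur
    return prev[-1]


def in_distance_band(code: str, edited: str,
                     min_distance: float, max_distance: float) -> bool:
    # Assumed normalization: edit distance over original code length.
    norm_distance = levenshtein(code, edited) / max(len(code), 1)
    return min_distance <= norm_distance <= max_distance


# Small, targeted edits pass a tight band; full rewrites do not.
assert in_distance_band("print('hi')", "print('ho')", 0.0, 0.2)
assert not in_distance_band("print('hi')", "x = 42", 0.0, 0.2)
```
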
# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------