mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 01:24:27 +08:00
[Spec Decode][Benchmark] Add Blazedit dataset (#23605)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
3feeeb9fea
commit
cd08636926
@ -1101,6 +1101,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
|||||||
"from the ShareGPT dataset.",
|
"from the ShareGPT dataset.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
blazedit_group = parser.add_argument_group("blazedit dataset options")
|
||||||
|
blazedit_group.add_argument(
|
||||||
|
"--blazedit-min-distance",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help=
|
||||||
|
"Minimum distance for blazedit dataset. Min: 0, Max: 1.0",
|
||||||
|
)
|
||||||
|
blazedit_group.add_argument(
|
||||||
|
"--blazedit-max-distance",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help=
|
||||||
|
"Maximum distance for blazedit dataset. Min: 0, Max: 1.0",
|
||||||
|
)
|
||||||
|
|
||||||
random_group = parser.add_argument_group("random dataset options")
|
random_group = parser.add_argument_group("random dataset options")
|
||||||
random_group.add_argument(
|
random_group.add_argument(
|
||||||
"--random-input-len",
|
"--random-input-len",
|
||||||
@ -1333,6 +1349,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
elif args.dataset_name == "hf":
|
elif args.dataset_name == "hf":
|
||||||
# all following datasets are implemented from the
|
# all following datasets are implemented from the
|
||||||
# HuggingFaceDataset base class
|
# HuggingFaceDataset base class
|
||||||
|
hf_kwargs = {}
|
||||||
if (
|
if (
|
||||||
args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS
|
args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS
|
||||||
or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
|
or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
|
||||||
@ -1376,6 +1393,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
):
|
):
|
||||||
dataset_class = ASRDataset
|
dataset_class = ASRDataset
|
||||||
args.hf_split = "train"
|
args.hf_split = "train"
|
||||||
|
elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS:
|
||||||
|
dataset_class = BlazeditDataset
|
||||||
|
args.hf_split = "train"
|
||||||
|
hf_kwargs = {
|
||||||
|
"min_distance": args.blazedit_min_distance,
|
||||||
|
"max_distance": args.blazedit_max_distance,
|
||||||
|
}
|
||||||
elif (
|
elif (
|
||||||
args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS
|
args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS
|
||||||
or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
|
or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
|
||||||
@ -1415,6 +1439,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
output_len=args.hf_output_len,
|
output_len=args.hf_output_len,
|
||||||
request_id_prefix=args.request_id_prefix,
|
request_id_prefix=args.request_id_prefix,
|
||||||
|
**hf_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -2090,6 +2115,94 @@ class MTBenchDataset(HuggingFaceDataset):
|
|||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Blazedit Dataset Implementation
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class BlazeditDataset(HuggingFaceDataset):
    """
    Code-editing benchmark dataset backed by Blazedit.
    https://github.com/ise-uiuc/blazedit

    Supported Hugging Face dataset paths:
      - vdaita/edit_5k_char  (5k char version)
      - vdaita/edit_10k_char (10k char version)
    """  # noqa: E501

    # Output budget: the 5k/10k char versions produce outputs of roughly
    # 5k/10k characters. At an assumed 3 characters per token, 10k chars is
    # about 3333 tokens, so 4000 is used as a safe default.
    DEFAULT_OUTPUT_LEN = 4000
    SUPPORTED_DATASET_PATHS = {
        "vdaita/edit_5k_char",
        "vdaita/edit_10k_char",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        min_distance: float = 0.0,
        max_distance: float = 1.0,
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` code-edit prompts from the dataset.

        Rows are visited in order. A row is kept only when its
        ``norm_distance`` field lies within ``[min_distance, max_distance]``.
        Each kept row is rendered into an edit instruction, wrapped with the
        tokenizer's chat template, and stored as a ``SampleRequest``.

        Args:
            tokenizer: Tokenizer used both to apply the chat template and to
                count prompt tokens.
            num_requests: Number of requests to collect before stopping.
            output_len: Expected output length per request; falls back to
                ``DEFAULT_OUTPUT_LEN`` when ``None``.
            request_id_prefix: String prepended to the row index to form the
                request id.
            min_distance: Lower bound (inclusive) on ``norm_distance``.
            max_distance: Upper bound (inclusive) on ``norm_distance``.
            **kwargs: Accepted and ignored.

        Returns:
            The list of collected requests, after it has been passed to
            ``maybe_oversample_requests`` (presumably to top it up when fewer
            than ``num_requests`` rows qualified -- confirm against the
            helper).
        """
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        requests = []
        for row_idx, row in enumerate(self.data):
            # Stop as soon as enough requests have been gathered.
            if len(requests) >= num_requests:
                break

            code = row["code"]
            change_request = row["change_request"]
            norm_distance = row["norm_distance"]

            # norm_distance is the Levenshtein distance normalized by code
            # length; drop rows that fall outside the requested window.
            if norm_distance < min_distance:
                continue
            if norm_distance > max_distance:
                continue

            # Prompt template copied from
            # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105  # noqa: E501
            instruction = f"""Given a code file, please apply the change requests and generate the new file.

Original file:
```python
{code}
```

Change request:
{change_request}

Please generate the new code file in the "New file" section below."""  # noqa: E501

            # Wrap the instruction as a single user turn and render it to
            # text with the generation prompt appended.
            messages = [{"role": "user", "content": instruction}]
            prompt = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )

            token_ids = tokenizer(prompt).input_ids
            request = SampleRequest(
                prompt=prompt,
                prompt_len=len(token_ids),
                expected_output_len=output_len,
                # The id uses the dataset row index, not the position in
                # the filtered list.
                request_id=request_id_prefix + str(row_idx),
            )
            requests.append(request)

        self.maybe_oversample_requests(requests, num_requests,
                                       request_id_prefix)

        return requests
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# AIMO Dataset Implementation
|
# AIMO Dataset Implementation
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user