mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 07:37:03 +08:00
[Spec Decode][Benchmark] Add Spec Bench Dataset for benchmarking (#23563)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
parent
6f4a82f8b5
commit
3feeeb9fea
@ -1020,7 +1020,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
|||||||
default="random",
|
default="random",
|
||||||
choices=[
|
choices=[
|
||||||
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
|
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
|
||||||
"custom", "prefix_repetition"
|
"custom", "prefix_repetition", "spec_bench"
|
||||||
],
|
],
|
||||||
help="Name of the dataset to benchmark on.",
|
help="Name of the dataset to benchmark on.",
|
||||||
)
|
)
|
||||||
@ -1053,6 +1053,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
|||||||
"Skip applying chat template to prompt, used only for custom dataset.",
|
"Skip applying chat template to prompt, used only for custom dataset.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
spec_bench_group = parser.add_argument_group("spec bench dataset options")
|
||||||
|
spec_bench_group.add_argument(
|
||||||
|
"--spec-bench-output-len",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help=
|
||||||
|
"Num of output tokens per request, used only for spec bench dataset.",
|
||||||
|
)
|
||||||
|
spec_bench_group.add_argument(
|
||||||
|
"--spec-bench-category",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=
|
||||||
|
"Category for spec bench dataset. If None, use all categories.",
|
||||||
|
)
|
||||||
|
|
||||||
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
||||||
sonnet_group.add_argument(
|
sonnet_group.add_argument(
|
||||||
"--sonnet-input-len",
|
"--sonnet-input-len",
|
||||||
@ -1404,6 +1420,14 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
else:
|
else:
|
||||||
# For datasets that follow a similar structure, use a mapping.
|
# For datasets that follow a similar structure, use a mapping.
|
||||||
dataset_mapping = {
|
dataset_mapping = {
|
||||||
|
"spec_bench":
|
||||||
|
lambda: SpecBench(dataset_path=args.dataset_path,
|
||||||
|
category=args.spec_bench_category).sample(
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
output_len=args.spec_bench_output_len,
|
||||||
|
request_id_prefix=args.request_id_prefix,
|
||||||
|
),
|
||||||
"sharegpt": lambda: ShareGPTDataset(
|
"sharegpt": lambda: ShareGPTDataset(
|
||||||
random_seed=args.seed, dataset_path=args.dataset_path
|
random_seed=args.seed, dataset_path=args.dataset_path
|
||||||
).sample(
|
).sample(
|
||||||
@ -1541,6 +1565,14 @@ class CustomDataset(BenchmarkDataset):
|
|||||||
request_id_prefix: str = "",
|
request_id_prefix: str = "",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list:
|
) -> list:
|
||||||
|
# load all data if needed
|
||||||
|
self.num_available_samples = len(self.data)
|
||||||
|
if num_requests <= 0:
|
||||||
|
num_requests = self.num_available_samples
|
||||||
|
logger.info("num_requests is set to 0 or negative, "
|
||||||
|
"so using all available samples: %d",
|
||||||
|
num_requests)
|
||||||
|
|
||||||
sampled_requests = []
|
sampled_requests = []
|
||||||
for i, item in enumerate(self.data):
|
for i, item in enumerate(self.data):
|
||||||
if len(sampled_requests) >= num_requests:
|
if len(sampled_requests) >= num_requests:
|
||||||
@ -1572,6 +1604,52 @@ class CustomDataset(BenchmarkDataset):
|
|||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Spec Bench Dataset Implementation
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SpecBench(CustomDataset):
|
||||||
|
"""
|
||||||
|
Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
|
||||||
|
Download the dataset using:
|
||||||
|
wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
self.category = kwargs.pop("category", None)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.load_data()
|
||||||
|
|
||||||
|
def load_data(self) -> None:
|
||||||
|
if self.dataset_path is None:
|
||||||
|
raise ValueError("dataset_path must be provided for loading data.")
|
||||||
|
|
||||||
|
self.data = []
|
||||||
|
|
||||||
|
# Load the JSONL file
|
||||||
|
jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
|
||||||
|
lines=True)
|
||||||
|
|
||||||
|
# check if the JSONL file has a 'turns' column
|
||||||
|
if "turns" not in jsonl_data.columns:
|
||||||
|
raise ValueError("JSONL file must contain a 'turns' column.")
|
||||||
|
|
||||||
|
for _, row in jsonl_data.iterrows():
|
||||||
|
# sample only from a specific category if specified
|
||||||
|
if (not self.category) or (self.category == row['category']):
|
||||||
|
prompt = row["turns"][0]
|
||||||
|
self.data.append({"prompt": prompt})
|
||||||
|
|
||||||
|
random.seed(self.random_seed)
|
||||||
|
random.shuffle(self.data)
|
||||||
|
|
||||||
|
def sample(self, **kwargs) -> list:
|
||||||
|
# leverage CustomDataset sample
|
||||||
|
kwargs["skip_chat_template"] = False
|
||||||
|
return super().sample(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Sonnet Dataset Implementation
|
# Sonnet Dataset Implementation
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user