[Spec Decode][Benchmark] Add Spec Bench Dataset for benchmarking (#23563)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
Ekagra Ranjan 2025-09-08 13:32:42 -04:00 committed by GitHub
parent 6f4a82f8b5
commit 3feeeb9fea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1020,7 +1020,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
default="random",
choices=[
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
"custom", "prefix_repetition"
"custom", "prefix_repetition", "spec_bench"
],
help="Name of the dataset to benchmark on.",
)
@@ -1053,6 +1053,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"Skip applying chat template to prompt, used only for custom dataset.",
)
spec_bench_group = parser.add_argument_group("spec bench dataset options")
spec_bench_group.add_argument(
"--spec-bench-output-len",
type=int,
default=256,
help=
"Num of output tokens per request, used only for spec bench dataset.",
)
spec_bench_group.add_argument(
"--spec-bench-category",
type=str,
default=None,
help=
"Category for spec bench dataset. If None, use all categories.",
)
sonnet_group = parser.add_argument_group("sonnet dataset options")
sonnet_group.add_argument(
"--sonnet-input-len",
@@ -1404,6 +1420,14 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
else:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"spec_bench":
lambda: SpecBench(dataset_path=args.dataset_path,
category=args.spec_bench_category).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.spec_bench_output_len,
request_id_prefix=args.request_id_prefix,
),
"sharegpt": lambda: ShareGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
@@ -1541,6 +1565,14 @@ class CustomDataset(BenchmarkDataset):
request_id_prefix: str = "",
**kwargs,
) -> list:
# load all data if needed
self.num_available_samples = len(self.data)
if num_requests <= 0:
num_requests = self.num_available_samples
logger.info("num_requests is set to 0 or negative, "
"so using all available samples: %d",
num_requests)
sampled_requests = []
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
@@ -1572,6 +1604,52 @@ class CustomDataset(BenchmarkDataset):
return sampled_requests
# -----------------------------------------------------------------------------
# Spec Bench Dataset Implementation
# -----------------------------------------------------------------------------
class SpecBench(CustomDataset):
    """
    Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
    Download the dataset using:
    wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
    """  # noqa: E501

    def __init__(self, **kwargs) -> None:
        # Pull out the SpecBench-specific option before handing the rest of
        # the keyword arguments to CustomDataset.
        self.category = kwargs.pop("category", None)
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """Load question.jsonl, keeping the first turn of each question.

        Populates ``self.data`` with ``{"prompt": ...}`` records, optionally
        restricted to ``self.category``, then shuffles them with the
        dataset's random seed.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        # Each JSONL record is expected to carry a 'turns' list of prompts.
        frame = pd.read_json(path_or_buf=self.dataset_path, lines=True)
        if "turns" not in frame.columns:
            raise ValueError("JSONL file must contain a 'turns' column.")

        # Keep only the first turn; short-circuit so the 'category' column is
        # never touched unless a category filter was actually requested.
        self.data = [{
            "prompt": record["turns"][0]
        } for _, record in frame.iterrows()
                     if not self.category or self.category == record["category"]]

        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(self, **kwargs) -> list:
        # Spec Bench prompts always go through the chat template, so force the
        # flag off before delegating to CustomDataset.sample.
        kwargs["skip_chat_template"] = False
        return super().sample(**kwargs)
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------