mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-29 02:14:00 +08:00
[Spec Decode][Benchmark] Add Spec Bench Dataset for benchmarking (#23563)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
parent
6f4a82f8b5
commit
3feeeb9fea
@ -1020,7 +1020,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
||||
default="random",
|
||||
choices=[
|
||||
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
|
||||
"custom", "prefix_repetition"
|
||||
"custom", "prefix_repetition", "spec_bench"
|
||||
],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
@ -1053,6 +1053,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
||||
"Skip applying chat template to prompt, used only for custom dataset.",
|
||||
)
|
||||
|
||||
spec_bench_group = parser.add_argument_group("spec bench dataset options")
|
||||
spec_bench_group.add_argument(
|
||||
"--spec-bench-output-len",
|
||||
type=int,
|
||||
default=256,
|
||||
help=
|
||||
"Num of output tokens per request, used only for spec bench dataset.",
|
||||
)
|
||||
spec_bench_group.add_argument(
|
||||
"--spec-bench-category",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Category for spec bench dataset. If None, use all categories.",
|
||||
)
|
||||
|
||||
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-input-len",
|
||||
@ -1404,6 +1420,14 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
else:
|
||||
# For datasets that follow a similar structure, use a mapping.
|
||||
dataset_mapping = {
|
||||
"spec_bench":
|
||||
lambda: SpecBench(dataset_path=args.dataset_path,
|
||||
category=args.spec_bench_category).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
output_len=args.spec_bench_output_len,
|
||||
request_id_prefix=args.request_id_prefix,
|
||||
),
|
||||
"sharegpt": lambda: ShareGPTDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
).sample(
|
||||
@ -1541,6 +1565,14 @@ class CustomDataset(BenchmarkDataset):
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
# load all data if needed
|
||||
self.num_available_samples = len(self.data)
|
||||
if num_requests <= 0:
|
||||
num_requests = self.num_available_samples
|
||||
logger.info("num_requests is set to 0 or negative, "
|
||||
"so using all available samples: %d",
|
||||
num_requests)
|
||||
|
||||
sampled_requests = []
|
||||
for i, item in enumerate(self.data):
|
||||
if len(sampled_requests) >= num_requests:
|
||||
@ -1572,6 +1604,52 @@ class CustomDataset(BenchmarkDataset):
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Spec Bench Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SpecBench(CustomDataset):
|
||||
"""
|
||||
Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
|
||||
Download the dataset using:
|
||||
wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self.category = kwargs.pop("category", None)
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if self.dataset_path is None:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
self.data = []
|
||||
|
||||
# Load the JSONL file
|
||||
jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
|
||||
lines=True)
|
||||
|
||||
# check if the JSONL file has a 'turns' column
|
||||
if "turns" not in jsonl_data.columns:
|
||||
raise ValueError("JSONL file must contain a 'turns' column.")
|
||||
|
||||
for _, row in jsonl_data.iterrows():
|
||||
# sample only from a specific category if specified
|
||||
if (not self.category) or (self.category == row['category']):
|
||||
prompt = row["turns"][0]
|
||||
self.data.append({"prompt": prompt})
|
||||
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.data)
|
||||
|
||||
def sample(self, **kwargs) -> list:
|
||||
# leverage CustomDataset sample
|
||||
kwargs["skip_chat_template"] = False
|
||||
return super().sample(**kwargs)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Sonnet Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user