[Benchmarking] Add disable_shuffle option for dataset loading (#26258)

Signed-off-by: Yasmin Moslem <48152713+ymoslem@users.noreply.github.com>
Yasmin Moslem 2025-10-06 08:05:44 +01:00 committed by GitHub
parent 039b6bade3
commit 7c2ec0fe87

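The change applies one pattern throughout the diff: each dataset class gains a disable_shuffle constructor flag, and every shuffle call is guarded by it. A minimal sketch of that pattern, using an illustrative class rather than the actual vLLM dataset classes:

import random

class ExampleDataset:
    """Illustrative stand-in for the benchmark dataset classes changed below."""

    def __init__(self, data, random_seed: int = 0, disable_shuffle: bool = False):
        self.random_seed = random_seed
        self.disable_shuffle = disable_shuffle
        self.data = list(data)

    def load_data(self) -> None:
        # The seed is still set unconditionally; only the shuffle itself is skipped.
        random.seed(self.random_seed)
        if not self.disable_shuffle:
            random.shuffle(self.data)

ds = ExampleDataset(["a", "b", "c"], disable_shuffle=True)
ds.load_data()
assert ds.data == ["a", "b", "c"]  # original order preserved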

@@ -96,6 +96,8 @@ class BenchmarkDataset(ABC):
self,
dataset_path: Optional[str] = None,
random_seed: int = DEFAULT_SEED,
disable_shuffle: bool = False,
**kwargs,
) -> None:
"""
Initialize the BenchmarkDataset with an optional dataset path and random
@@ -111,6 +113,7 @@ class BenchmarkDataset(ABC):
# Set the random seed, ensuring that a None value is replaced with the
# default seed.
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
self.disable_shuffle = disable_shuffle
self.data = None
def apply_multimodal_chat_transformation(
@@ -1044,7 +1047,8 @@ class ShareGPTDataset(BenchmarkDataset):
if "conversations" in entry and len(entry["conversations"]) >= 2
]
random.seed(self.random_seed)
random.shuffle(self.data)
if not getattr(self, "disable_shuffle", False):
random.shuffle(self.data)
def sample(
self,
@@ -1175,6 +1179,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
action="store_true",
help="Skip applying chat template to prompt for datasets that support it.",
)
parser.add_argument(
"--disable-shuffle",
action="store_true",
help="Disable shuffling of dataset samples for deterministic ordering.",
)
# group for dataset specific arguments
custom_group = parser.add_argument_group("custom dataset options")
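The new flag is a plain store_true switch. For reference, a standalone sketch of its behavior, using argparse directly rather than vLLM's FlexibleArgumentParser, purely for illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-shuffle",
    action="store_true",
    help="Disable shuffling of dataset samples for deterministic ordering.",
)

args = parser.parse_args(["--disable-shuffle"])
assert args.disable_shuffle is True   # flag given: keep the original sample order

args = parser.parse_args([])
assert args.disable_shuffle is False  # default: shuffling stays enabled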
@@ -1441,7 +1450,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
args.request_id_prefix = ""
if args.dataset_name == "custom":
dataset = CustomDataset(dataset_path=args.dataset_path)
dataset = CustomDataset(
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
)
input_requests = dataset.sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
@@ -1452,7 +1463,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
)
elif args.dataset_name == "sonnet":
dataset = SonnetDataset(dataset_path=args.dataset_path)
dataset = SonnetDataset(
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
)
# For the "sonnet" dataset, formatting depends on the backend.
if args.backend == "openai-chat":
input_requests = dataset.sample(
@@ -1586,6 +1599,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
random_seed=args.seed,
no_stream=args.no_stream,
hf_name=args.hf_name,
disable_shuffle=args.disable_shuffle,
).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
@@ -1600,7 +1614,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"spec_bench": lambda: SpecBench(
dataset_path=args.dataset_path, category=args.spec_bench_category
dataset_path=args.dataset_path,
category=args.spec_bench_category,
disable_shuffle=args.disable_shuffle,
).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
@@ -1609,7 +1625,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"sharegpt": lambda: ShareGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
@@ -1618,7 +1636,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
@@ -1626,7 +1646,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"random": lambda: RandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
@@ -1639,7 +1661,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"random-mm": lambda: RandomMultiModalDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
@@ -1655,7 +1679,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
no_oversample=args.no_oversample,
),
"prefix_repetition": lambda: PrefixRepetitionRandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
@@ -1733,7 +1759,8 @@ class CustomDataset(BenchmarkDataset):
)
random.seed(self.random_seed)
random.shuffle(self.data)
if not getattr(self, "disable_shuffle", False):
random.shuffle(self.data)
def sample(
self,
@@ -1825,7 +1852,8 @@ class SpecBench(CustomDataset):
self.data.append({"prompt": prompt})
random.seed(self.random_seed)
random.shuffle(self.data)
if not getattr(self, "disable_shuffle", False):
random.shuffle(self.data)
def sample(self, **kwargs) -> list:
# leverage CustomDataset sample
@@ -2033,7 +2061,8 @@ class HuggingFaceDataset(BenchmarkDataset):
split=self.dataset_split,
streaming=self.load_stream,
)
self.data = self.data.shuffle(seed=self.random_seed)
if not getattr(self, "disable_shuffle", False):
self.data = self.data.shuffle(seed=self.random_seed)
# -----------------------------------------------------------------------------
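The HuggingFace path is the one place where the guard wraps a reassignment rather than an in-place random.shuffle, because the datasets library's shuffle returns a new dataset object. A hedged sketch with toy data (not the benchmark's actual loader):

from datasets import Dataset

ds = Dataset.from_dict({"prompt": ["a", "b", "c", "d"]})

disable_shuffle = True
if not disable_shuffle:
    # Dataset.shuffle returns a shuffled copy, hence the reassignment in the diff.
    ds = ds.shuffle(seed=0)

print(ds["prompt"])  # ['a', 'b', 'c', 'd'] -- original order preserved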
@@ -2849,7 +2878,8 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
abs(token_mismatch_total),
sign,
)
random.shuffle(requests)
if not getattr(self, "disable_shuffle", False):
random.shuffle(requests)
return requests
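With --disable-shuffle set, every dataset path touched above serves samples in their original order; the random seed is still set in each loader, so other seeded behavior is unchanged.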