[Benchmarking] Add disable_shuffle option for dataset loading (#26258)

Signed-off-by: Yasmin Moslem <48152713+ymoslem@users.noreply.github.com>
Yasmin Moslem 2025-10-06 08:05:44 +01:00 committed by GitHub
parent 039b6bade3
commit 7c2ec0fe87


@@ -96,6 +96,8 @@ class BenchmarkDataset(ABC):
         self,
         dataset_path: Optional[str] = None,
         random_seed: int = DEFAULT_SEED,
+        disable_shuffle: bool = False,
+        **kwargs,
     ) -> None:
         """
         Initialize the BenchmarkDataset with an optional dataset path and random
@@ -111,6 +113,7 @@ class BenchmarkDataset(ABC):
         # Set the random seed, ensuring that a None value is replaced with the
         # default seed.
         self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
+        self.disable_shuffle = disable_shuffle
         self.data = None

     def apply_multimodal_chat_transformation(
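For orientation, a usage sketch of the new constructor argument. The import path, model name, and dataset file below are assumptions; only the constructor and sample() parameters come from this diff:

    # Hedged usage sketch; import path and file/model names are assumptions.
    from transformers import AutoTokenizer
    from vllm.benchmarks.datasets import ShareGPTDataset  # assumed module path

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    dataset = ShareGPTDataset(
        dataset_path="ShareGPT_V3_unfiltered_cleaned_split.json",
        random_seed=0,
        disable_shuffle=True,  # keep samples in their on-disk order
    )
    requests = dataset.sample(tokenizer=tokenizer, num_requests=100)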
@@ -1044,6 +1047,7 @@ class ShareGPTDataset(BenchmarkDataset):
             if "conversations" in entry and len(entry["conversations"]) >= 2
         ]
         random.seed(self.random_seed)
-        random.shuffle(self.data)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(self.data)

     def sample(
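One reading of the getattr form used here (an observation, not stated in the commit): the False default keeps the check safe on instances that never set the attribute, e.g. subclasses that construct themselves without forwarding kwargs to BenchmarkDataset.__init__. A standalone illustration:

    # Illustration only, not vLLM code: getattr with a default avoids an
    # AttributeError when disable_shuffle was never set, so shuffling stays on.
    class LegacyDataset:
        pass

    obj = LegacyDataset()
    print(getattr(obj, "disable_shuffle", False))  # False -> shuffle still runs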
@@ -1175,6 +1179,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         action="store_true",
         help="Skip applying chat template to prompt for datasets that support it.",
     )
+    parser.add_argument(
+        "--disable-shuffle",
+        action="store_true",
+        help="Disable shuffling of dataset samples for deterministic ordering.",
+    )

     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
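The new flag is opt-in: store_true defaults to False, so existing benchmark invocations keep shuffling. On the command line it is simply appended to the usual dataset flags (e.g. --dataset-name sharegpt --dataset-path <file> --disable-shuffle). A standalone sketch of the parsing behavior, using plain argparse rather than the vLLM parser:

    # Illustration only: a store_true flag defaults to False, so shuffling
    # remains enabled unless --disable-shuffle is passed explicitly.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--disable-shuffle", action="store_true",
                        help="Disable shuffling of dataset samples.")
    print(parser.parse_args([]).disable_shuffle)                     # False
    print(parser.parse_args(["--disable-shuffle"]).disable_shuffle)  # True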
@@ -1441,7 +1450,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         args.request_id_prefix = ""

     if args.dataset_name == "custom":
-        dataset = CustomDataset(dataset_path=args.dataset_path)
+        dataset = CustomDataset(
+            dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
+        )
         input_requests = dataset.sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
@@ -1452,7 +1463,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         )
     elif args.dataset_name == "sonnet":
-        dataset = SonnetDataset(dataset_path=args.dataset_path)
+        dataset = SonnetDataset(
+            dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
+        )
         # For the "sonnet" dataset, formatting depends on the backend.
         if args.backend == "openai-chat":
             input_requests = dataset.sample(
@@ -1586,6 +1599,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             random_seed=args.seed,
             no_stream=args.no_stream,
             hf_name=args.hf_name,
+            disable_shuffle=args.disable_shuffle,
         ).sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
@@ -1600,7 +1614,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         # For datasets that follow a similar structure, use a mapping.
         dataset_mapping = {
             "spec_bench": lambda: SpecBench(
-                dataset_path=args.dataset_path, category=args.spec_bench_category
+                dataset_path=args.dataset_path,
+                category=args.spec_bench_category,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 num_requests=args.num_prompts,
                 tokenizer=tokenizer,
@@ -1609,7 +1625,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 no_oversample=args.no_oversample,
             ),
             "sharegpt": lambda: ShareGPTDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
@@ -1618,7 +1636,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 no_oversample=args.no_oversample,
             ),
             "burstgpt": lambda: BurstGPTDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
@@ -1626,7 +1646,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 no_oversample=args.no_oversample,
             ),
             "random": lambda: RandomDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
@@ -1639,7 +1661,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 no_oversample=args.no_oversample,
             ),
             "random-mm": lambda: RandomMultiModalDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
@@ -1655,7 +1679,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 no_oversample=args.no_oversample,
             ),
             "prefix_repetition": lambda: PrefixRepetitionRandomDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
@@ -1733,6 +1759,7 @@ class CustomDataset(BenchmarkDataset):
             )

         random.seed(self.random_seed)
-        random.shuffle(self.data)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(self.data)

     def sample(
@@ -1825,6 +1852,7 @@ class SpecBench(CustomDataset):
             self.data.append({"prompt": prompt})

         random.seed(self.random_seed)
-        random.shuffle(self.data)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(self.data)

     def sample(self, **kwargs) -> list:
@@ -2033,6 +2061,7 @@ class HuggingFaceDataset(BenchmarkDataset):
             split=self.dataset_split,
             streaming=self.load_stream,
         )
-        self.data = self.data.shuffle(seed=self.random_seed)
+        if not getattr(self, "disable_shuffle", False):
+            self.data = self.data.shuffle(seed=self.random_seed)
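Worth noting for the Hugging Face path: datasets' shuffle() is not in place, it returns a new seeded dataset, so skipping the reassignment leaves the load order untouched. A standalone illustration using the public datasets API (not vLLM code):

    # Illustration only: Dataset.shuffle(seed=...) returns a new dataset, so
    # guarding the reassignment preserves the original order.
    from datasets import Dataset

    data = Dataset.from_dict({"prompt": ["p1", "p2", "p3"]})
    disable_shuffle = True
    if not disable_shuffle:
        data = data.shuffle(seed=42)
    print(data["prompt"])  # ['p1', 'p2', 'p3'] -- original order preserved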
@@ -2849,6 +2878,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
                 abs(token_mismatch_total),
                 sign,
             )
-        random.shuffle(requests)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(requests)
         return requests