mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 10:34:58 +08:00
[Benchmarking] Add disable_shuffle option for dataset loading (#26258)
Signed-off-by: Yasmin Moslem <48152713+ymoslem@users.noreply.github.com>
This commit is contained in:
parent
039b6bade3
commit
7c2ec0fe87
@ -96,6 +96,8 @@ class BenchmarkDataset(ABC):
|
||||
self,
|
||||
dataset_path: Optional[str] = None,
|
||||
random_seed: int = DEFAULT_SEED,
|
||||
disable_shuffle: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the BenchmarkDataset with an optional dataset path and random
|
||||
@ -111,6 +113,7 @@ class BenchmarkDataset(ABC):
|
||||
# Set the random seed, ensuring that a None value is replaced with the
|
||||
# default seed.
|
||||
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
|
||||
self.disable_shuffle = disable_shuffle
|
||||
self.data = None
|
||||
|
||||
def apply_multimodal_chat_transformation(
|
||||
@ -1044,7 +1047,8 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
if "conversations" in entry and len(entry["conversations"]) >= 2
|
||||
]
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.data)
|
||||
if not getattr(self, "disable_shuffle", False):
|
||||
random.shuffle(self.data)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@ -1175,6 +1179,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
||||
action="store_true",
|
||||
help="Skip applying chat template to prompt for datasets that support it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-shuffle",
|
||||
action="store_true",
|
||||
help="Disable shuffling of dataset samples for deterministic ordering.",
|
||||
)
|
||||
|
||||
# group for dataset specific arguments
|
||||
custom_group = parser.add_argument_group("custom dataset options")
|
||||
@ -1441,7 +1450,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
args.request_id_prefix = ""
|
||||
|
||||
if args.dataset_name == "custom":
|
||||
dataset = CustomDataset(dataset_path=args.dataset_path)
|
||||
dataset = CustomDataset(
|
||||
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
|
||||
)
|
||||
input_requests = dataset.sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
@ -1452,7 +1463,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sonnet":
|
||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||
dataset = SonnetDataset(
|
||||
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
|
||||
)
|
||||
# For the "sonnet" dataset, formatting depends on the backend.
|
||||
if args.backend == "openai-chat":
|
||||
input_requests = dataset.sample(
|
||||
@ -1586,6 +1599,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
random_seed=args.seed,
|
||||
no_stream=args.no_stream,
|
||||
hf_name=args.hf_name,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
@ -1600,7 +1614,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
# For datasets that follow a similar structure, use a mapping.
|
||||
dataset_mapping = {
|
||||
"spec_bench": lambda: SpecBench(
|
||||
dataset_path=args.dataset_path, category=args.spec_bench_category
|
||||
dataset_path=args.dataset_path,
|
||||
category=args.spec_bench_category,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
@ -1609,7 +1625,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
no_oversample=args.no_oversample,
|
||||
),
|
||||
"sharegpt": lambda: ShareGPTDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
random_seed=args.seed,
|
||||
dataset_path=args.dataset_path,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
@ -1618,7 +1636,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
no_oversample=args.no_oversample,
|
||||
),
|
||||
"burstgpt": lambda: BurstGPTDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
random_seed=args.seed,
|
||||
dataset_path=args.dataset_path,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
@ -1626,7 +1646,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
no_oversample=args.no_oversample,
|
||||
),
|
||||
"random": lambda: RandomDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
random_seed=args.seed,
|
||||
dataset_path=args.dataset_path,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
@ -1639,7 +1661,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
no_oversample=args.no_oversample,
|
||||
),
|
||||
"random-mm": lambda: RandomMultiModalDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
random_seed=args.seed,
|
||||
dataset_path=args.dataset_path,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
@ -1655,7 +1679,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
no_oversample=args.no_oversample,
|
||||
),
|
||||
"prefix_repetition": lambda: PrefixRepetitionRandomDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
random_seed=args.seed,
|
||||
dataset_path=args.dataset_path,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
@ -1733,7 +1759,8 @@ class CustomDataset(BenchmarkDataset):
|
||||
)
|
||||
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.data)
|
||||
if not getattr(self, "disable_shuffle", False):
|
||||
random.shuffle(self.data)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@ -1825,7 +1852,8 @@ class SpecBench(CustomDataset):
|
||||
self.data.append({"prompt": prompt})
|
||||
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.data)
|
||||
if not getattr(self, "disable_shuffle", False):
|
||||
random.shuffle(self.data)
|
||||
|
||||
def sample(self, **kwargs) -> list:
|
||||
# leverage CustomDataset sample
|
||||
@ -2033,7 +2061,8 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
split=self.dataset_split,
|
||||
streaming=self.load_stream,
|
||||
)
|
||||
self.data = self.data.shuffle(seed=self.random_seed)
|
||||
if not getattr(self, "disable_shuffle", False):
|
||||
self.data = self.data.shuffle(seed=self.random_seed)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -2849,7 +2878,8 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
|
||||
abs(token_mismatch_total),
|
||||
sign,
|
||||
)
|
||||
random.shuffle(requests)
|
||||
if not getattr(self, "disable_shuffle", False):
|
||||
random.shuffle(requests)
|
||||
return requests
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user