mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 20:25:01 +08:00
[Benchmarking] Add disable_shuffle option for dataset loading (#26258)
Signed-off-by: Yasmin Moslem <48152713+ymoslem@users.noreply.github.com>
This commit is contained in:
parent
039b6bade3
commit
7c2ec0fe87
@ -96,6 +96,8 @@ class BenchmarkDataset(ABC):
|
|||||||
self,
|
self,
|
||||||
dataset_path: Optional[str] = None,
|
dataset_path: Optional[str] = None,
|
||||||
random_seed: int = DEFAULT_SEED,
|
random_seed: int = DEFAULT_SEED,
|
||||||
|
disable_shuffle: bool = False,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the BenchmarkDataset with an optional dataset path and random
|
Initialize the BenchmarkDataset with an optional dataset path and random
|
||||||
@ -111,6 +113,7 @@ class BenchmarkDataset(ABC):
|
|||||||
# Set the random seed, ensuring that a None value is replaced with the
|
# Set the random seed, ensuring that a None value is replaced with the
|
||||||
# default seed.
|
# default seed.
|
||||||
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
|
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
|
||||||
|
self.disable_shuffle = disable_shuffle
|
||||||
self.data = None
|
self.data = None
|
||||||
|
|
||||||
def apply_multimodal_chat_transformation(
|
def apply_multimodal_chat_transformation(
|
||||||
@ -1044,6 +1047,7 @@ class ShareGPTDataset(BenchmarkDataset):
|
|||||||
if "conversations" in entry and len(entry["conversations"]) >= 2
|
if "conversations" in entry and len(entry["conversations"]) >= 2
|
||||||
]
|
]
|
||||||
random.seed(self.random_seed)
|
random.seed(self.random_seed)
|
||||||
|
if not getattr(self, "disable_shuffle", False):
|
||||||
random.shuffle(self.data)
|
random.shuffle(self.data)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
@ -1175,6 +1179,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Skip applying chat template to prompt for datasets that support it.",
|
help="Skip applying chat template to prompt for datasets that support it.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-shuffle",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable shuffling of dataset samples for deterministic ordering.",
|
||||||
|
)
|
||||||
|
|
||||||
# group for dataset specific arguments
|
# group for dataset specific arguments
|
||||||
custom_group = parser.add_argument_group("custom dataset options")
|
custom_group = parser.add_argument_group("custom dataset options")
|
||||||
@ -1441,7 +1450,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
args.request_id_prefix = ""
|
args.request_id_prefix = ""
|
||||||
|
|
||||||
if args.dataset_name == "custom":
|
if args.dataset_name == "custom":
|
||||||
dataset = CustomDataset(dataset_path=args.dataset_path)
|
dataset = CustomDataset(
|
||||||
|
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
|
||||||
|
)
|
||||||
input_requests = dataset.sample(
|
input_requests = dataset.sample(
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -1452,7 +1463,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif args.dataset_name == "sonnet":
|
elif args.dataset_name == "sonnet":
|
||||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
dataset = SonnetDataset(
|
||||||
|
dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
|
||||||
|
)
|
||||||
# For the "sonnet" dataset, formatting depends on the backend.
|
# For the "sonnet" dataset, formatting depends on the backend.
|
||||||
if args.backend == "openai-chat":
|
if args.backend == "openai-chat":
|
||||||
input_requests = dataset.sample(
|
input_requests = dataset.sample(
|
||||||
@ -1586,6 +1599,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
random_seed=args.seed,
|
random_seed=args.seed,
|
||||||
no_stream=args.no_stream,
|
no_stream=args.no_stream,
|
||||||
hf_name=args.hf_name,
|
hf_name=args.hf_name,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
).sample(
|
).sample(
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -1600,7 +1614,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
# For datasets that follow a similar structure, use a mapping.
|
# For datasets that follow a similar structure, use a mapping.
|
||||||
dataset_mapping = {
|
dataset_mapping = {
|
||||||
"spec_bench": lambda: SpecBench(
|
"spec_bench": lambda: SpecBench(
|
||||||
dataset_path=args.dataset_path, category=args.spec_bench_category
|
dataset_path=args.dataset_path,
|
||||||
|
category=args.spec_bench_category,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
).sample(
|
).sample(
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -1609,7 +1625,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
no_oversample=args.no_oversample,
|
no_oversample=args.no_oversample,
|
||||||
),
|
),
|
||||||
"sharegpt": lambda: ShareGPTDataset(
|
"sharegpt": lambda: ShareGPTDataset(
|
||||||
random_seed=args.seed, dataset_path=args.dataset_path
|
random_seed=args.seed,
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
).sample(
|
).sample(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
@ -1618,7 +1636,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
no_oversample=args.no_oversample,
|
no_oversample=args.no_oversample,
|
||||||
),
|
),
|
||||||
"burstgpt": lambda: BurstGPTDataset(
|
"burstgpt": lambda: BurstGPTDataset(
|
||||||
random_seed=args.seed, dataset_path=args.dataset_path
|
random_seed=args.seed,
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
).sample(
|
).sample(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
@ -1626,7 +1646,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
no_oversample=args.no_oversample,
|
no_oversample=args.no_oversample,
|
||||||
),
|
),
|
||||||
"random": lambda: RandomDataset(
|
"random": lambda: RandomDataset(
|
||||||
random_seed=args.seed, dataset_path=args.dataset_path
|
random_seed=args.seed,
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
).sample(
|
).sample(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
@ -1639,7 +1661,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
no_oversample=args.no_oversample,
|
no_oversample=args.no_oversample,
|
||||||
),
|
),
|
||||||
"random-mm": lambda: RandomMultiModalDataset(
|
"random-mm": lambda: RandomMultiModalDataset(
|
||||||
random_seed=args.seed, dataset_path=args.dataset_path
|
random_seed=args.seed,
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
).sample(
|
).sample(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
@ -1655,7 +1679,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
no_oversample=args.no_oversample,
|
no_oversample=args.no_oversample,
|
||||||
),
|
),
|
||||||
"prefix_repetition": lambda: PrefixRepetitionRandomDataset(
|
"prefix_repetition": lambda: PrefixRepetitionRandomDataset(
|
||||||
random_seed=args.seed, dataset_path=args.dataset_path
|
random_seed=args.seed,
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
).sample(
|
).sample(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
@ -1733,6 +1759,7 @@ class CustomDataset(BenchmarkDataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
random.seed(self.random_seed)
|
random.seed(self.random_seed)
|
||||||
|
if not getattr(self, "disable_shuffle", False):
|
||||||
random.shuffle(self.data)
|
random.shuffle(self.data)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
@ -1825,6 +1852,7 @@ class SpecBench(CustomDataset):
|
|||||||
self.data.append({"prompt": prompt})
|
self.data.append({"prompt": prompt})
|
||||||
|
|
||||||
random.seed(self.random_seed)
|
random.seed(self.random_seed)
|
||||||
|
if not getattr(self, "disable_shuffle", False):
|
||||||
random.shuffle(self.data)
|
random.shuffle(self.data)
|
||||||
|
|
||||||
def sample(self, **kwargs) -> list:
|
def sample(self, **kwargs) -> list:
|
||||||
@ -2033,6 +2061,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
|||||||
split=self.dataset_split,
|
split=self.dataset_split,
|
||||||
streaming=self.load_stream,
|
streaming=self.load_stream,
|
||||||
)
|
)
|
||||||
|
if not getattr(self, "disable_shuffle", False):
|
||||||
self.data = self.data.shuffle(seed=self.random_seed)
|
self.data = self.data.shuffle(seed=self.random_seed)
|
||||||
|
|
||||||
|
|
||||||
@ -2849,6 +2878,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
|
|||||||
abs(token_mismatch_total),
|
abs(token_mismatch_total),
|
||||||
sign,
|
sign,
|
||||||
)
|
)
|
||||||
|
if not getattr(self, "disable_shuffle", False):
|
||||||
random.shuffle(requests)
|
random.shuffle(requests)
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user