[Benchmark] Add option to skip oversampling in benchmark (#24457)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Author: Ekagra Ranjan
Date:   2025-09-09 18:00:17 -04:00 (committed by GitHub)
Commit: fb1a8f932a
Parent: 0dc9cbb527

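In short: `BenchmarkDataset.sample()` and every dataset subclass gain a `no_oversample` flag, threaded through `get_samples()` from a new `--no-oversample` CLI option; when set, `maybe_oversample_requests()` returns early instead of duplicating requests to pad the sample list up to `--num-prompts`. A minimal driver-side sketch of the new behavior (the import paths, dataset file, and model name are illustrative assumptions, not part of this commit):

# Sketch only: module paths and file names below are assumptions.
from transformers import AutoTokenizer
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
add_dataset_parser(parser)
args = parser.parse_args([
    "--dataset-name", "sharegpt",
    "--dataset-path", "ShareGPT_V3_unfiltered_cleaned_split.json",
    "--num-prompts", "2000",
    "--no-oversample",  # new flag: never duplicate requests to reach 2000
])
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
# With --no-oversample, len(requests) may be smaller than args.num_prompts.
requests = get_samples(args, tokenizer)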

@@ -198,8 +198,9 @@ class BenchmarkDataset(ABC):
     @abstractmethod
     def sample(self, tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               request_id_prefix: str = "") -> list[SampleRequest]:
+               num_requests: int,
+               request_id_prefix: str = "",
+               no_oversample: bool = False) -> list[SampleRequest]:
         """
         Abstract method to generate sample requests from the dataset.
@@ -224,6 +225,7 @@ class BenchmarkDataset(ABC):
         requests: list[SampleRequest],
         num_requests: int,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
     ) -> None:
         """
         Oversamples the list of requests if its size is less than the desired
@@ -236,6 +238,11 @@
             request_id_prefix (str): The prefix of the request ids.
         """
+        if no_oversample:
+            logger.info("Skipping oversampling. " \
+                        "Total samples: %d.", len(requests))
+            return
+
         if len(requests) < num_requests:
             random.seed(self.random_seed)
             additional = deepcopy(
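The resulting helper semantics, as a standalone sketch (plain Python mirroring the logic above; the real method also re-assigns request ids on the duplicated entries, and `random.choices` is assumed for the resampling call elided by the diff):

import random
from copy import deepcopy

def maybe_oversample(requests: list, num_requests: int,
                     no_oversample: bool = False, seed: int = 0) -> list:
    if no_oversample:
        return requests  # new: keep the dataset at its natural size
    if len(requests) < num_requests:
        random.seed(seed)  # deterministic padding across runs
        additional = deepcopy(
            random.choices(requests, k=num_requests - len(requests)))
        requests.extend(additional)
    return requests

short = ["a", "b", "c"]
assert len(maybe_oversample(list(short), 5)) == 5        # padded by resampling
assert len(maybe_oversample(list(short), 5, True)) == 3  # left as-is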
@@ -405,6 +412,7 @@ class RandomDataset(BenchmarkDataset):
         tokenizer: PreTrainedTokenizerBase,
         num_requests: int,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         prefix_len: int = DEFAULT_PREFIX_LEN,
         range_ratio: float = DEFAULT_RANGE_RATIO,
         input_len: int = DEFAULT_INPUT_LEN,
@@ -832,6 +840,7 @@ class RandomMultiModalDataset(RandomDataset):
         tokenizer: PreTrainedTokenizerBase,
         num_requests: int,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN,
         range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
         input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
@@ -959,6 +968,7 @@ class ShareGPTDataset(BenchmarkDataset):
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         samples: list = []
@@ -1002,7 +1012,10 @@ class ShareGPTDataset(BenchmarkDataset):
                     request_id=request_id_prefix + str(ind),
                 ))
             ind += 1
-        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
+        self.maybe_oversample_requests(samples,
+                                       num_requests,
+                                       request_id_prefix,
+                                       no_oversample)
         return samples
@@ -1036,6 +1049,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help="Path to the sharegpt/sonnet dataset. "
         "Or the huggingface dataset ID if using HF dataset.",
     )
+    parser.add_argument(
+        "--no-oversample",
+        action="store_true",
+        help="Do not oversample if the dataset has " \
+        "fewer samples than num-prompts.",
+    )
 
     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
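Because the option uses `action="store_true"`, it defaults to `False`, so existing invocations keep the oversampling behavior; a quick stdlib-only check (plain `argparse` standing in for vLLM's `FlexibleArgumentParser`):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-oversample",
                    action="store_true",
                    help="Do not oversample if the dataset has "
                    "fewer samples than num-prompts.")

assert parser.parse_args([]).no_oversample is False  # default: oversample
assert parser.parse_args(["--no-oversample"]).no_oversample is True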
@@ -1322,6 +1341,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             output_len=args.custom_output_len,
             skip_chat_template=args.custom_skip_chat_template,
             request_id_prefix=args.request_id_prefix,
+            no_oversample=args.no_oversample,
         )
 
     elif args.dataset_name == "sonnet":
@@ -1336,6 +1356,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 tokenizer=tokenizer,
                 return_prompt_formatted=False,
                 request_id_prefix=args.request_id_prefix,
+                no_oversample=args.no_oversample,
             )
         else:
             assert tokenizer.chat_template or tokenizer.default_chat_template, (
@@ -1348,6 +1369,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 tokenizer=tokenizer,
                 return_prompt_formatted=True,
                 request_id_prefix=args.request_id_prefix,
+                no_oversample=args.no_oversample,
             )
 
     elif args.dataset_name == "hf":
@@ -1443,6 +1465,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             tokenizer=tokenizer,
             output_len=args.hf_output_len,
             request_id_prefix=args.request_id_prefix,
+            no_oversample=args.no_oversample,
             **hf_kwargs
         )
@@ -1456,6 +1479,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 tokenizer=tokenizer,
                 output_len=args.spec_bench_output_len,
                 request_id_prefix=args.request_id_prefix,
+                no_oversample=args.no_oversample,
             ),
             "sharegpt": lambda: ShareGPTDataset(
                 random_seed=args.seed, dataset_path=args.dataset_path
@@ -1464,6 +1488,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 num_requests=args.num_prompts,
                 output_len=args.sharegpt_output_len,
                 request_id_prefix=args.request_id_prefix,
+                no_oversample=args.no_oversample,
             ),
             "burstgpt": lambda: BurstGPTDataset(
                 random_seed=args.seed, dataset_path=args.dataset_path
@@ -1471,6 +1496,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
                 request_id_prefix=args.request_id_prefix,
+                no_oversample=args.no_oversample,
             ),
             "random": lambda: RandomDataset(
                 random_seed=args.seed, dataset_path=args.dataset_path
@@ -1483,6 +1509,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 range_ratio=args.random_range_ratio,
                 request_id_prefix=args.request_id_prefix,
                 batchsize=args.random_batch_size,
+                no_oversample=args.no_oversample,
             ),
             "random-mm":
             lambda: RandomMultiModalDataset(
@@ -1499,6 +1526,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
                 bucket_config=args.random_mm_bucket_config,
                 request_id_prefix=args.request_id_prefix,
+                no_oversample=args.no_oversample,
             ),
             "prefix_repetition":
             lambda: PrefixRepetitionRandomDataset(
@@ -1511,6 +1539,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 num_prefixes=args.prefix_repetition_num_prefixes,
                 output_len=args.prefix_repetition_output_len,
                 request_id_prefix=args.request_id_prefix,
+                no_oversample=args.no_oversample,
             ),
         }
@@ -1592,6 +1621,7 @@ class CustomDataset(BenchmarkDataset):
         enable_multimodal_chat: bool = False,
         skip_chat_template: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         # load all data if needed
@@ -1628,7 +1658,7 @@
                     request_id=request_id_prefix + str(i),
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests
@@ -1719,6 +1749,7 @@ class SonnetDataset(BenchmarkDataset):
         output_len: int = DEFAULT_OUTPUT_LEN,
         return_prompt_formatted: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         # Calculate average token length for a poem line.
@@ -1814,6 +1845,7 @@ class BurstGPTDataset(BenchmarkDataset):
         max_loras: Optional[int] = None,
         lora_path: Optional[str] = None,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list[SampleRequest]:
         samples = []
@@ -1893,6 +1925,7 @@ class ConversationDataset(HuggingFaceDataset):
                output_len: Optional[int] = None,
                enable_multimodal_chat: bool = False,
                request_id_prefix: str = "",
+               no_oversample: bool = False,
                **kwargs) -> list:
         # Filter examples with at least 2 conversations
         filtered_data = self.data.filter(
@@ -1934,7 +1967,7 @@ class ConversationDataset(HuggingFaceDataset):
                 ))
             ind += 1
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests
@@ -1964,6 +1997,7 @@ class VisionArenaDataset(HuggingFaceDataset):
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         output_len = (output_len
@@ -1993,7 +2027,7 @@ class VisionArenaDataset(HuggingFaceDataset):
                     request_id=request_id_prefix + str(i),
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests
@@ -2023,6 +2057,7 @@ class InstructCoderDataset(HuggingFaceDataset):
                output_len: Optional[int] = None,
                enable_multimodal_chat: bool = False,
                request_id_prefix: str = "",
+               no_oversample: bool = False,
                **kwargs) -> list:
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
@@ -2054,7 +2089,7 @@ class InstructCoderDataset(HuggingFaceDataset):
                     request_id=request_id_prefix + str(i),
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests
@@ -2085,6 +2120,7 @@ class MTBenchDataset(HuggingFaceDataset):
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         output_len = (output_len
@@ -2115,7 +2151,7 @@ class MTBenchDataset(HuggingFaceDataset):
                     request_id=request_id_prefix + str(i),
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests
@@ -2149,6 +2185,7 @@ class BlazeditDataset(HuggingFaceDataset):
         num_requests: int,
         output_len: Optional[int] = None,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         min_distance: float = 0.0,
         max_distance: float = 1.0,
         **kwargs,
@@ -2202,7 +2239,7 @@ Please generate the new code file in the "New file" section below.""" # noqa: E501
                     request_id=request_id_prefix + str(i),
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests
@@ -2226,6 +2263,7 @@ class AIMODataset(HuggingFaceDataset):
                num_requests: int,
                output_len: Optional[int] = None,
                request_id_prefix: str = "",
+               no_oversample: bool = False,
                **kwargs) -> list:
         sampled_requests = []
         ind = 0
@@ -2258,7 +2296,7 @@ class AIMODataset(HuggingFaceDataset):
                 ))
             ind += 1
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests
@@ -2330,6 +2368,7 @@ class NextEditPredictionDataset(HuggingFaceDataset):
     def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
                request_id_prefix: str = "",
+               no_oversample: bool = False,
                **kwargs):
         formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.hf_name)
         if formatting_prompt_func is None:
@@ -2347,7 +2386,10 @@ class NextEditPredictionDataset(HuggingFaceDataset):
                 ))
             if len(samples) >= num_requests:
                 break
-        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
+        self.maybe_oversample_requests(samples,
+                                       num_requests,
+                                       request_id_prefix,
+                                       no_oversample)
         return samples
@@ -2398,6 +2440,7 @@ class ASRDataset(HuggingFaceDataset):
         num_requests: int,
         output_len: Optional[int] = None,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         output_len = (output_len
@@ -2436,7 +2479,7 @@ class ASRDataset(HuggingFaceDataset):
                 skipped,
             )
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests
@@ -2474,6 +2517,7 @@ class MLPerfDataset(HuggingFaceDataset):
         num_requests: int,
         output_len: Optional[int] = None,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list[SampleRequest]:
         # Force dynamic output length based on reference completion.
@@ -2520,7 +2564,7 @@ class MLPerfDataset(HuggingFaceDataset):
             ind += 1
 
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests
@@ -2554,6 +2598,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
         num_prefixes: int = DEFAULT_NUM_PREFIXES,
         output_len: int = DEFAULT_OUTPUT_LEN,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list[SampleRequest]:
         vocab_size = tokenizer.vocab_size