[Benchmark] Add option to skip oversampling in benchmark (#24457)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
parent 0dc9cbb527
commit fb1a8f932a
@@ -198,8 +198,9 @@ class BenchmarkDataset(ABC):

     @abstractmethod
     def sample(self, tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               request_id_prefix: str = "") -> list[SampleRequest]:
+               num_requests: int,
+               request_id_prefix: str = "",
+               no_oversample: bool = False) -> list[SampleRequest]:
         """
         Abstract method to generate sample requests from the dataset.

@@ -224,6 +225,7 @@ class BenchmarkDataset(ABC):
         requests: list[SampleRequest],
         num_requests: int,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
     ) -> None:
         """
         Oversamples the list of requests if its size is less than the desired
@@ -236,6 +238,11 @@ class BenchmarkDataset(ABC):
             request_id_prefix (str) The prefix of the request ids.

         """
+        if no_oversample:
+            logger.info("Skipping oversampling. " \
+                        "Total samples: %d.", len(requests))
+            return
+
         if len(requests) < num_requests:
             random.seed(self.random_seed)
             additional = deepcopy(
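The behavioral core of the patch is the early return above: when no_oversample is set, maybe_oversample_requests leaves the request list at its natural size instead of padding it with duplicated samples. A minimal self-contained sketch of that logic, using a toy SampleRequest and a simplified duplication step (the re-numbering details are assumptions, not the exact vLLM code):

    import logging
    import random
    from copy import deepcopy
    from dataclasses import dataclass

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    @dataclass
    class SampleRequest:  # toy stand-in for vLLM's SampleRequest
        prompt: str
        request_id: str = ""

    def maybe_oversample_requests(requests, num_requests, request_id_prefix="",
                                  no_oversample=False, random_seed=0):
        # New behavior first: with no_oversample, leave the list untouched.
        if no_oversample:
            logger.info("Skipping oversampling. Total samples: %d.",
                        len(requests))
            return
        # Pre-existing behavior: pad with duplicated samples up to num_requests.
        if len(requests) < num_requests:
            random.seed(random_seed)
            additional = deepcopy(
                random.choices(requests, k=num_requests - len(requests)))
            for i, req in enumerate(additional):
                req.request_id = request_id_prefix + str(len(requests) + i)
            requests.extend(additional)

    reqs = [SampleRequest("p0"), SampleRequest("p1")]
    maybe_oversample_requests(reqs, num_requests=5)
    assert len(reqs) == 5          # padded with duplicates

    reqs = [SampleRequest("p0"), SampleRequest("p1")]
    maybe_oversample_requests(reqs, num_requests=5, no_oversample=True)
    assert len(reqs) == 2          # left at natural size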
@@ -405,6 +412,7 @@ class RandomDataset(BenchmarkDataset):
         tokenizer: PreTrainedTokenizerBase,
         num_requests: int,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         prefix_len: int = DEFAULT_PREFIX_LEN,
         range_ratio: float = DEFAULT_RANGE_RATIO,
         input_len: int = DEFAULT_INPUT_LEN,
@@ -832,6 +840,7 @@ class RandomMultiModalDataset(RandomDataset):
         tokenizer: PreTrainedTokenizerBase,
         num_requests: int,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN,
         range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
         input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
@@ -959,6 +968,7 @@ class ShareGPTDataset(BenchmarkDataset):
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         samples: list = []
@@ -1002,7 +1012,10 @@ class ShareGPTDataset(BenchmarkDataset):
                 request_id=request_id_prefix + str(ind),
             ))
             ind += 1
-        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
+        self.maybe_oversample_requests(samples,
+                                       num_requests,
+                                       request_id_prefix,
+                                       no_oversample)
         return samples

@@ -1036,6 +1049,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help="Path to the sharegpt/sonnet dataset. "
         "Or the huggingface dataset ID if using HF dataset.",
     )
+    parser.add_argument(
+        "--no-oversample",
+        action="store_true",
+        help="Do not oversample if the dataset has " \
+        "fewer samples than num-prompts.",
+    )

     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
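Since the flag is a plain store_true switch, downstream code reads it as args.no_oversample. A quick standalone check of how it parses, using stock argparse in place of vLLM's FlexibleArgumentParser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no-oversample",
        action="store_true",
        help="Do not oversample if the dataset has "
        "fewer samples than num-prompts.",
    )

    # argparse converts the dash to an underscore on the namespace.
    assert parser.parse_args([]).no_oversample is False
    assert parser.parse_args(["--no-oversample"]).no_oversample is True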
@@ -1322,6 +1341,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             output_len=args.custom_output_len,
             skip_chat_template=args.custom_skip_chat_template,
             request_id_prefix=args.request_id_prefix,
+            no_oversample=args.no_oversample,
         )

     elif args.dataset_name == "sonnet":
@@ -1336,6 +1356,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 tokenizer=tokenizer,
                 return_prompt_formatted=False,
                 request_id_prefix=args.request_id_prefix,
+                no_oversample=args.no_oversample,
             )
         else:
             assert tokenizer.chat_template or tokenizer.default_chat_template, (
@@ -1348,6 +1369,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 tokenizer=tokenizer,
                 return_prompt_formatted=True,
                 request_id_prefix=args.request_id_prefix,
+                no_oversample=args.no_oversample,
             )

     elif args.dataset_name == "hf":
@@ -1443,6 +1465,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             tokenizer=tokenizer,
             output_len=args.hf_output_len,
             request_id_prefix=args.request_id_prefix,
+            no_oversample=args.no_oversample,
             **hf_kwargs
         )

@@ -1456,6 +1479,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             tokenizer=tokenizer,
             output_len=args.spec_bench_output_len,
             request_id_prefix=args.request_id_prefix,
+            no_oversample=args.no_oversample,
         ),
         "sharegpt": lambda: ShareGPTDataset(
             random_seed=args.seed, dataset_path=args.dataset_path
@@ -1464,6 +1488,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             num_requests=args.num_prompts,
             output_len=args.sharegpt_output_len,
             request_id_prefix=args.request_id_prefix,
+            no_oversample=args.no_oversample,
         ),
         "burstgpt": lambda: BurstGPTDataset(
             random_seed=args.seed, dataset_path=args.dataset_path
@@ -1471,6 +1496,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             tokenizer=tokenizer,
             num_requests=args.num_prompts,
             request_id_prefix=args.request_id_prefix,
+            no_oversample=args.no_oversample,
         ),
         "random": lambda: RandomDataset(
             random_seed=args.seed, dataset_path=args.dataset_path
@@ -1483,6 +1509,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             range_ratio=args.random_range_ratio,
             request_id_prefix=args.request_id_prefix,
             batchsize=args.random_batch_size,
+            no_oversample=args.no_oversample,
         ),
         "random-mm":
         lambda: RandomMultiModalDataset(
@@ -1499,6 +1526,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
             bucket_config=args.random_mm_bucket_config,
             request_id_prefix=args.request_id_prefix,
+            no_oversample=args.no_oversample,
         ),
         "prefix_repetition":
         lambda: PrefixRepetitionRandomDataset(
@@ -1511,6 +1539,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             num_prefixes=args.prefix_repetition_num_prefixes,
             output_len=args.prefix_repetition_output_len,
             request_id_prefix=args.request_id_prefix,
+            no_oversample=args.no_oversample,
         ),
     }

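All of the get_samples call sites above forward the same args.no_oversample value, so the flag behaves uniformly across dataset types. Sketched with hypothetical numbers, a 500-row dataset benchmarked with --num-prompts 1000 yields 1000 requests by default but only 500 with the flag:

    import random

    rows = [f"prompt-{i}" for i in range(500)]  # hypothetical 500-row dataset

    def sample(num_requests, no_oversample=False):
        picked = rows[:num_requests]
        if len(picked) < num_requests and not no_oversample:
            random.seed(0)
            picked += random.choices(picked, k=num_requests - len(picked))
        return picked

    assert len(sample(1000)) == 1000                     # default: oversampled
    assert len(sample(1000, no_oversample=True)) == 500  # flag: originals only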
@@ -1592,6 +1621,7 @@ class CustomDataset(BenchmarkDataset):
         enable_multimodal_chat: bool = False,
         skip_chat_template: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         # load all data if needed
@@ -1628,7 +1658,7 @@ class CustomDataset(BenchmarkDataset):
                 request_id=request_id_prefix + str(i),
             ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)

         return sampled_requests

@@ -1719,6 +1749,7 @@ class SonnetDataset(BenchmarkDataset):
         output_len: int = DEFAULT_OUTPUT_LEN,
         return_prompt_formatted: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         # Calculate average token length for a poem line.
@@ -1814,6 +1845,7 @@ class BurstGPTDataset(BenchmarkDataset):
         max_loras: Optional[int] = None,
         lora_path: Optional[str] = None,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list[SampleRequest]:
         samples = []
@@ -1893,6 +1925,7 @@ class ConversationDataset(HuggingFaceDataset):
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs) -> list:
         # Filter examples with at least 2 conversations
         filtered_data = self.data.filter(
@@ -1934,7 +1967,7 @@ class ConversationDataset(HuggingFaceDataset):
             ))
             ind += 1
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests

@@ -1964,6 +1997,7 @@ class VisionArenaDataset(HuggingFaceDataset):
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         output_len = (output_len
@@ -1993,7 +2027,7 @@ class VisionArenaDataset(HuggingFaceDataset):
                 request_id=request_id_prefix + str(i),
             ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests

@@ -2023,6 +2057,7 @@ class InstructCoderDataset(HuggingFaceDataset):
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs) -> list:
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
@@ -2054,7 +2089,7 @@ class InstructCoderDataset(HuggingFaceDataset):
                 request_id=request_id_prefix + str(i),
             ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests

@@ -2085,6 +2120,7 @@ class MTBenchDataset(HuggingFaceDataset):
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         output_len = (output_len
@@ -2115,7 +2151,7 @@ class MTBenchDataset(HuggingFaceDataset):
                 request_id=request_id_prefix + str(i),
             ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests

@@ -2149,6 +2185,7 @@ class BlazeditDataset(HuggingFaceDataset):
         num_requests: int,
         output_len: Optional[int] = None,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         min_distance: float = 0.0,
         max_distance: float = 1.0,
         **kwargs,
@@ -2202,7 +2239,7 @@ Please generate the new code file in the "New file" section below.""" # noqa: E5
                 request_id=request_id_prefix + str(i),
             ))
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)

         return sampled_requests

@@ -2226,6 +2263,7 @@ class AIMODataset(HuggingFaceDataset):
         num_requests: int,
         output_len: Optional[int] = None,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs) -> list:
         sampled_requests = []
         ind = 0
@@ -2258,7 +2296,7 @@ class AIMODataset(HuggingFaceDataset):
             ))
             ind += 1
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests

@@ -2330,6 +2368,7 @@ class NextEditPredictionDataset(HuggingFaceDataset):

     def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
                request_id_prefix: str = "",
+               no_oversample: bool = False,
                **kwargs):
         formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.hf_name)
         if formatting_prompt_func is None:
@@ -2347,7 +2386,10 @@ class NextEditPredictionDataset(HuggingFaceDataset):
             ))
             if len(samples) >= num_requests:
                 break
-        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
+        self.maybe_oversample_requests(samples,
+                                       num_requests,
+                                       request_id_prefix,
+                                       no_oversample)
         return samples

@@ -2398,6 +2440,7 @@ class ASRDataset(HuggingFaceDataset):
         num_requests: int,
         output_len: Optional[int] = None,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list:
         output_len = (output_len
@@ -2436,7 +2479,7 @@ class ASRDataset(HuggingFaceDataset):
                 skipped,
             )
         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests

@@ -2474,6 +2517,7 @@ class MLPerfDataset(HuggingFaceDataset):
         num_requests: int,
         output_len: Optional[int] = None,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list[SampleRequest]:
         # Force dynamic output length based on reference completion.
@@ -2520,7 +2564,7 @@ class MLPerfDataset(HuggingFaceDataset):
             ind += 1

         self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+                                       request_id_prefix, no_oversample)
         return sampled_requests

@@ -2554,6 +2598,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
         num_prefixes: int = DEFAULT_NUM_PREFIXES,
         output_len: int = DEFAULT_OUTPUT_LEN,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list[SampleRequest]:
         vocab_size = tokenizer.vocab_size
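The per-class hunks above all repeat one contract, which is also what an out-of-tree BenchmarkDataset subclass now has to honor: accept no_oversample in sample() and pass it through to maybe_oversample_requests. A sketch of that contract; the import path, the SampleRequest field names, and _load_prompts are assumptions for illustration, not confirmed by this diff:

    # Sketch only: import path and SampleRequest fields are assumptions.
    from transformers import PreTrainedTokenizerBase

    from vllm.benchmarks.datasets import BenchmarkDataset, SampleRequest


    class MyDataset(BenchmarkDataset):  # hypothetical out-of-tree dataset
        def sample(self,
                   tokenizer: PreTrainedTokenizerBase,
                   num_requests: int,
                   request_id_prefix: str = "",
                   no_oversample: bool = False,
                   **kwargs) -> list[SampleRequest]:
            prompts = self._load_prompts()  # hypothetical helper
            requests = [
                SampleRequest(prompt=p,
                              prompt_len=len(tokenizer(p).input_ids),
                              expected_output_len=128,
                              request_id=request_id_prefix + str(i))
                for i, p in enumerate(prompts[:num_requests])
            ]
            # Same call shape as the hunks above: the base class decides
            # whether to pad with duplicates or, with no_oversample, return
            # the list as-is.
            self.maybe_oversample_requests(requests, num_requests,
                                           request_id_prefix, no_oversample)
            return requests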