mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 02:09:08 +08:00
[misc] partial prefix & random input generation benchmark (#9929)
Signed-off-by: rickyx <rickyx@anyscale.com>
This commit is contained in:
parent
2298e69b5f
commit
90a6c759ca
@ -54,13 +54,30 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
|
|||||||
print(f"cost time {end_time - start_time}")
|
print(f"cost time {end_time - start_time}")
|
||||||
|
|
||||||
|
|
||||||
def sample_requests(
|
@dataclasses.dataclass
|
||||||
|
class Request:
|
||||||
|
prompt: str
|
||||||
|
prompt_len: int
|
||||||
|
output_len: int
|
||||||
|
|
||||||
|
|
||||||
|
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
# Remove the special tokens.
|
||||||
|
vocab = {
|
||||||
|
k: v
|
||||||
|
for k, v in vocab.items() if k not in tokenizer.all_special_ids
|
||||||
|
}
|
||||||
|
return random.choices(list(vocab.values()), k=length)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_requests_from_dataset(
|
||||||
dataset_path: str,
|
dataset_path: str,
|
||||||
num_requests: int,
|
num_requests: int,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
input_length_range: Tuple[int, int],
|
input_length_range: Tuple[int, int],
|
||||||
fixed_output_len: Optional[int],
|
fixed_output_len: Optional[int],
|
||||||
) -> List[Tuple[str, int, int]]:
|
) -> List[Request]:
|
||||||
if fixed_output_len is not None and fixed_output_len < 4:
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
raise ValueError("output_len too small")
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
@ -77,31 +94,55 @@ def sample_requests(
|
|||||||
random.shuffle(dataset)
|
random.shuffle(dataset)
|
||||||
|
|
||||||
min_len, max_len = input_length_range
|
min_len, max_len = input_length_range
|
||||||
|
assert min_len >= 0 and max_len >= min_len, "input_length_range too small"
|
||||||
|
|
||||||
# Filter out sequences that are too long or too short
|
# Filter out sequences that are too long or too short
|
||||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
filtered_requests: List[Request] = []
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
if len(filtered_dataset) == num_requests:
|
if len(filtered_requests) == num_requests:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Tokenize the prompts and completions.
|
# Tokenize the prompts and completions.
|
||||||
prompt = dataset[i][0]
|
prompt_token_ids = tokenizer(dataset[i][0]).input_ids
|
||||||
prompt_token_ids = tokenizer(prompt).input_ids
|
prompt = tokenizer.decode(prompt_token_ids)
|
||||||
completion = dataset[i][1]
|
completion = dataset[i][1]
|
||||||
completion_token_ids = tokenizer(completion).input_ids
|
completion_token_ids = tokenizer(completion).input_ids
|
||||||
prompt_len = len(prompt_token_ids)
|
prompt_len = len(prompt_token_ids)
|
||||||
output_len = len(completion_token_ids
|
output_len = (len(completion_token_ids)
|
||||||
) if fixed_output_len is None else fixed_output_len
|
if fixed_output_len is None else fixed_output_len)
|
||||||
if prompt_len < 4 or output_len < 4:
|
|
||||||
# Prune too short sequences.
|
|
||||||
continue
|
|
||||||
if min_len <= prompt_len <= max_len:
|
if min_len <= prompt_len <= max_len:
|
||||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
filtered_requests.append(Request(prompt, prompt_len, output_len))
|
||||||
|
|
||||||
return filtered_dataset
|
return filtered_requests
|
||||||
|
|
||||||
|
|
||||||
def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
|
def sample_requests_from_random(
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
input_length_range: Tuple[int, int],
|
||||||
|
fixed_output_len: Optional[int],
|
||||||
|
prefix_len: int,
|
||||||
|
) -> List[Request]:
|
||||||
|
|
||||||
|
requests = []
|
||||||
|
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
|
||||||
|
min_len, max_len = input_length_range
|
||||||
|
|
||||||
|
for i in range(num_requests):
|
||||||
|
unique_part_token_ids = sample_tokens(
|
||||||
|
tokenizer,
|
||||||
|
random.randint(min_len - prefix_len, max_len - prefix_len))
|
||||||
|
prompt_token_ids = prefix_token_ids + unique_part_token_ids
|
||||||
|
prompt = tokenizer.decode(prompt_token_ids)
|
||||||
|
prompt_len = len(prompt_token_ids)
|
||||||
|
assert (min_len <= prompt_len <= max_len
|
||||||
|
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
|
||||||
|
requests.append(Request(prompt, prompt_len, fixed_output_len))
|
||||||
|
return requests
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_and_sort_requests(requests: List[Request],
|
||||||
repeat_count: int,
|
repeat_count: int,
|
||||||
sort: bool = False) -> List[str]:
|
sort: bool = False) -> List[str]:
|
||||||
repeated_requests = requests * repeat_count
|
repeated_requests = requests * repeat_count
|
||||||
@ -109,7 +150,7 @@ def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
|
|||||||
repeated_requests.sort(key=lambda x: x[1])
|
repeated_requests.sort(key=lambda x: x[1])
|
||||||
else:
|
else:
|
||||||
random.shuffle(repeated_requests)
|
random.shuffle(repeated_requests)
|
||||||
return [req[0] for req in repeated_requests]
|
return [req.prompt for req in repeated_requests]
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@ -117,9 +158,12 @@ def main(args):
|
|||||||
input_length_range = tuple(map(int, args.input_length_range.split(':')))
|
input_length_range = tuple(map(int, args.input_length_range.split(':')))
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
if args.dataset_path is not None:
|
if args.dataset_path is not None:
|
||||||
print(f"Start to sample {args.num_prompts} prompts"
|
if args.prefix_len > 0:
|
||||||
|
raise ValueError("prefix-len is not supported when "
|
||||||
|
"dataset-path is provided.")
|
||||||
|
print(f"Start to sample {args.num_prompts} prompts "
|
||||||
f"from {args.dataset_path}")
|
f"from {args.dataset_path}")
|
||||||
filtered_datasets = sample_requests(
|
filtered_requests = sample_requests_from_dataset(
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -127,9 +171,22 @@ def main(args):
|
|||||||
fixed_output_len=args.output_len,
|
fixed_output_len=args.output_len,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt_len = len(tokenizer(PROMPT).input_ids)
|
print(f"Start to sample {args.num_prompts} prompts from random")
|
||||||
filtered_datasets = [(PROMPT, prompt_len, args.output_len)
|
filtered_requests = sample_requests_from_random(
|
||||||
] * args.num_prompts
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
input_length_range=input_length_range,
|
||||||
|
fixed_output_len=args.output_len,
|
||||||
|
prefix_len=args.prefix_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print some helpful stats of the requests.
|
||||||
|
print(f"Sampled {len(filtered_requests)} requests.")
|
||||||
|
prompt_lens = [req.prompt_len for req in filtered_requests]
|
||||||
|
print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}")
|
||||||
|
print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}")
|
||||||
|
print(f"Min Prompt Length: {min(prompt_lens)}")
|
||||||
|
print(f"Max Prompt Length: {max(prompt_lens)}")
|
||||||
|
|
||||||
engine_args = EngineArgs.from_cli_args(args)
|
engine_args = EngineArgs.from_cli_args(args)
|
||||||
|
|
||||||
@ -137,8 +194,8 @@ def main(args):
|
|||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||||
|
|
||||||
print("Testing filtered datasets")
|
print("Testing filtered requests")
|
||||||
prompts = repeat_and_sort_requests(filtered_datasets,
|
prompts = repeat_and_sort_requests(filtered_requests,
|
||||||
repeat_count=args.repeat_count,
|
repeat_count=args.repeat_count,
|
||||||
sort=args.sort)
|
sort=args.sort)
|
||||||
|
|
||||||
@ -161,20 +218,29 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument('--output-len', type=int, default=10)
|
parser.add_argument('--output-len', type=int, default=10)
|
||||||
parser.add_argument('--num-prompts',
|
parser.add_argument('--num-prompts',
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
required=True,
|
||||||
help="Number of the prompts sampled from dataset")
|
help="Number of the prompts sampled from dataset")
|
||||||
parser.add_argument('--repeat-count',
|
parser.add_argument('--repeat-count',
|
||||||
type=int,
|
type=int,
|
||||||
default=100,
|
default=1,
|
||||||
help='Number of times to repeat each prompt')
|
help='Number of times to repeat each prompt')
|
||||||
parser.add_argument('--sort',
|
parser.add_argument('--sort',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Sort prompts by input length')
|
help='Sort prompts by input length')
|
||||||
parser.add_argument('--input-length-range',
|
parser.add_argument('--input-length-range',
|
||||||
type=str,
|
type=str,
|
||||||
default='128:256',
|
required=True,
|
||||||
help='Range of input lengths for sampling prompts,'
|
help='Range of input lengths for sampling prompts,'
|
||||||
'specified as "min:max" (e.g., "128:256").')
|
'specified as "min:max" (e.g., "128:256").')
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefix-len",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Specifies the length of a common prefix to be "
|
||||||
|
"added to the input prompt. The input-length-range will "
|
||||||
|
"subtract this length when filtering prompts. Only used "
|
||||||
|
"when dataset-path is not provided.",
|
||||||
|
)
|
||||||
|
|
||||||
parser = EngineArgs.add_cli_args(parser)
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user