mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 17:25:51 +08:00
[Bugfix] Prevent benchmark_throughput.py from using duplicated random prompts (#10753)
This commit is contained in:
parent
4c05edb33a
commit
4433195ab7
@ -294,23 +294,36 @@ def main(args: argparse.Namespace):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
if args.dataset is None:
|
||||
# Synthesize a prompt with the given input length.
|
||||
# As tokenizer may add additional tokens like BOS, we need to try
|
||||
# different lengths to get the desired input length.
|
||||
for i in range(-10, 10):
|
||||
prompt = "hi " * (args.input_len + i)
|
||||
tokenized_prompt = tokenizer(prompt).input_ids
|
||||
if len(tokenized_prompt) == args.input_len:
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Failed to synthesize a prompt with {args.input_len} tokens.")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=args.input_len,
|
||||
expected_output_len=args.output_len)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
vocab_size = tokenizer.vocab_size
|
||||
requests = []
|
||||
for _ in range(args.num_prompts):
|
||||
# Synthesize a prompt with the given input length.
|
||||
candidate_ids = [
|
||||
random.randint(0, vocab_size - 1)
|
||||
for _ in range(args.input_len)
|
||||
]
|
||||
# As tokenizer may add additional tokens like BOS, we need to try
|
||||
# different lengths to get the desired input length.
|
||||
for _ in range(5): # Max attempts to correct
|
||||
candidate_prompt = tokenizer.decode(candidate_ids)
|
||||
tokenized_len = len(tokenizer.encode(candidate_prompt))
|
||||
|
||||
if tokenized_len == args.input_len:
|
||||
break
|
||||
|
||||
# Adjust length based on difference
|
||||
diff = args.input_len - tokenized_len
|
||||
if diff > 0:
|
||||
candidate_ids.extend([
|
||||
random.randint(100, vocab_size - 100)
|
||||
for _ in range(diff)
|
||||
])
|
||||
else:
|
||||
candidate_ids = candidate_ids[:diff]
|
||||
requests.append(
|
||||
SampleRequest(prompt=candidate_prompt,
|
||||
prompt_len=args.input_len,
|
||||
expected_output_len=args.output_len))
|
||||
else:
|
||||
requests = sample_requests(tokenizer, args)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user