[Bugfix] Prevent benchmark_throughput.py from using duplicated random prompts (#10753)

This commit is contained in:
Michael Goin 2024-12-02 21:26:15 -05:00 committed by GitHub
parent 4c05edb33a
commit 4433195ab7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -294,23 +294,36 @@ def main(args: argparse.Namespace):
     tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer, trust_remote_code=args.trust_remote_code)
     if args.dataset is None:
-        # Synthesize a prompt with the given input length.
-        # As tokenizer may add additional tokens like BOS, we need to try
-        # different lengths to get the desired input length.
-        for i in range(-10, 10):
-            prompt = "hi " * (args.input_len + i)
-            tokenized_prompt = tokenizer(prompt).input_ids
-            if len(tokenized_prompt) == args.input_len:
-                break
-        else:
-            raise ValueError(
-                f"Failed to synthesize a prompt with {args.input_len} tokens.")
-        requests = [
-            SampleRequest(prompt=prompt,
-                          prompt_len=args.input_len,
-                          expected_output_len=args.output_len)
-            for _ in range(args.num_prompts)
-        ]
+        vocab_size = tokenizer.vocab_size
+        requests = []
+        for _ in range(args.num_prompts):
+            # Synthesize a prompt with the given input length.
+            candidate_ids = [
+                random.randint(0, vocab_size - 1)
+                for _ in range(args.input_len)
+            ]
+            # As tokenizer may add additional tokens like BOS, we need to try
+            # different lengths to get the desired input length.
+            for _ in range(5):  # Max attempts to correct
+                candidate_prompt = tokenizer.decode(candidate_ids)
+                tokenized_len = len(tokenizer.encode(candidate_prompt))
+
+                if tokenized_len == args.input_len:
+                    break
+
+                # Adjust length based on difference
+                diff = args.input_len - tokenized_len
+                if diff > 0:
+                    candidate_ids.extend([
+                        random.randint(100, vocab_size - 100)
+                        for _ in range(diff)
+                    ])
+                else:
+                    candidate_ids = candidate_ids[:diff]
+            requests.append(
+                SampleRequest(prompt=candidate_prompt,
+                              prompt_len=args.input_len,
+                              expected_output_len=args.output_len))
     else:
         requests = sample_requests(tokenizer, args)