diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index f0c0d829a393b..807e543dc0cde 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -366,11 +366,67 @@ def process_video(video: Any) -> Mapping[str, Any]:
         f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`."  # noqa: E501
     )
 
+
+def gen_prompt_decode_to_target_len(
+    tokenizer: PreTrainedTokenizerBase,
+    token_sequence: list[int],
+    target_token_len: int,
+    max_retry: int = 10,
+    add_special_tokens: bool = False,
+    rng: Optional[np.random.Generator] = None,
+) -> tuple[str, list[int], int]:
+    """
+    Ensure decoded-then-encoded prompt length matches the target token length.
+
+    This function decodes an initial token sequence to text and re-encodes it,
+    iteratively adjusting the token sequence length to match a target. This is
+    necessary because some tokenizers do not guarantee a 1:1 mapping between
+    consecutive tokens and the decoded-then-encoded sequence length.
+    For example, for GPT2Tokenizer:
+    [6880, 6881] -> ['Ġcalls', 'here'] ->
+    [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
+
+    Returns a tuple of the final prompt string, the adjusted token sequence,
+    and the token mismatch (final length minus target length; 0 on success).
+    """
+    remain_num_try = max_retry
+    token_mismatch = 0
+    while True:
+        prompt = tokenizer.decode(token_sequence)
+        token_sequence = tokenizer.encode(
+            prompt, add_special_tokens=add_special_tokens
+        )
+        if remain_num_try <= 0:
+            if len(token_sequence) != target_token_len:
+                token_mismatch = len(token_sequence) - target_token_len
+            break
+
+        if len(token_sequence) == target_token_len:
+            break
+        elif len(token_sequence) < target_token_len:
+            if rng is not None:
+                extra_tokens = rng.integers(
+                    0,
+                    tokenizer.vocab_size,
+                    size=target_token_len - len(token_sequence),
+                ).tolist()
+            else:
+                extra_tokens = np.random.randint(
+                    0,
+                    tokenizer.vocab_size,
+                    size=target_token_len - len(token_sequence),
+                ).tolist()
+            token_sequence.extend(extra_tokens)
+        elif len(token_sequence) > target_token_len:
+            token_sequence = token_sequence[:target_token_len]
+
+        remain_num_try -= 1
+
+    return prompt, token_sequence, token_mismatch
+
 # -----------------------------------------------------------------------------
 # Random Dataset Implementation (Synthetic Data)
 # -----------------------------------------------------------------------------
-
 class RandomDataset(BenchmarkDataset):
     """
     Synthetic text-only dataset for serving/throughput benchmarks.
@@ -420,8 +476,9 @@ class RandomDataset(BenchmarkDataset):
         vocab_size = tokenizer.vocab_size
 
         requests = []
+        token_mismatch_total = 0
         for i in range(num_requests):
-            prompt, total_input_len = self.generate_token_sequence(
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                 tokenizer=tokenizer,
                 prefix_token_ids=prefix_token_ids,
                 prefix_len=prefix_len,
@@ -430,6 +487,7 @@ class RandomDataset(BenchmarkDataset):
                 offset=int(offsets[i]),
                 index=i,
             )
+            token_mismatch_total += token_mismatch
             requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -453,6 +511,18 @@ class RandomDataset(BenchmarkDataset):
                 )
             )
             requests = batch_requests
+
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
+
         return requests
 
     def get_prefix(
@@ -530,7 +600,7 @@ class RandomDataset(BenchmarkDataset):
         input_len: int,
         offset: int,
         index: int,
-    ) -> tuple[str, int]:
+    ) -> tuple[str, int, int]:
         """
-        Returns (prompt, total_input_len).
+        Returns (prompt, total_input_len, token_mismatch).
 
@@ -549,15 +619,16 @@ class RandomDataset(BenchmarkDataset):
         token_sequence = prefix_token_ids + inner_seq
 
         # Decode, then re-encode and truncate to preserve token count invariants
-        prompt = tokenizer.decode(token_sequence)
         total_input_len = prefix_len + int(input_len)
-
-        re_encoded_sequence = tokenizer.encode(
-            prompt, add_special_tokens=False)[:total_input_len]
-        prompt = tokenizer.decode(re_encoded_sequence)
-        total_input_len = len(re_encoded_sequence)
-
-        return prompt, total_input_len
+        prompt, adjusted_token_sequence, token_mismatch = gen_prompt_decode_to_target_len(  # noqa: E501
+            tokenizer=tokenizer,
+            token_sequence=token_sequence,
+            target_token_len=total_input_len,
+            add_special_tokens=False,
+            rng=self._rng,
+        )
+        total_input_len = len(adjusted_token_sequence)
+        return prompt, total_input_len, token_mismatch
 
 
 # -----------------------------------------------------------------------------
@@ -873,8 +944,9 @@ class RandomMultiModalDataset(RandomDataset):
         vocab_size = tokenizer.vocab_size
         # Add synthetic multimodal items to each request
         mm_requests = []
+        token_mismatch_total = 0
         for i in range(num_requests):
-            prompt, total_input_len = self.generate_token_sequence(
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                 tokenizer=tokenizer,
                 prefix_token_ids=prefix_token_ids,
                 prefix_len=prefix_len,
@@ -883,6 +955,7 @@ class RandomMultiModalDataset(RandomDataset):
                 offset=int(offsets[i]),
                 index=i,
             )
+            token_mismatch_total += token_mismatch
             # Get multimodal item iterator for a given request
             mm_item_iterator = self.get_mm_item_iterator(
                 min_num_mm_items,
@@ -918,6 +991,18 @@ class RandomMultiModalDataset(RandomDataset):
                 request_id=request_id_prefix + str(i),
             )
             mm_requests.append(sample_request)
+
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
+
         return mm_requests
 
 # -----------------------------------------------------------------------------
@@ -2694,27 +2779,24 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
             # Generate random tokens
             tokens = np.random.randint(
                 0, vocab_size, size=target_length).tolist()
-            text = tokenizer.decode(tokens)
-            re_encoded = tokenizer.encode(text, add_special_tokens=False)
-            if len(re_encoded) == target_length:
-                return re_encoded
-            elif len(re_encoded) < target_length:
-                # Recursively generate additional consistent tokens
-                needed = target_length - len(re_encoded)
-                extra_tokens = _generate_exact_length_tokens(needed)
-                return re_encoded + extra_tokens
-            else:
-                # Truncate to target length
-                return re_encoded[:target_length]
+            _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(  # noqa: E501
+                tokenizer=tokenizer,
+                token_sequence=tokens,
+                target_token_len=target_length,
+                add_special_tokens=False,
+            )
+            return adjusted_tokens, token_mismatch
 
         requests = []
+        token_mismatch_total = 0
         for _ in range(num_prefixes):
-            prefix_tokens = _generate_exact_length_tokens(prefix_len)
+            prefix_tokens, token_mismatch = _generate_exact_length_tokens(prefix_len)  # noqa: E501
+            token_mismatch_total += token_mismatch
 
             for _ in range(prompts_per_prefix):
-                suffix_tokens = _generate_exact_length_tokens(suffix_len)
-
+                suffix_tokens, token_mismatch = _generate_exact_length_tokens(suffix_len)  # noqa: E501
+                token_mismatch_total += token_mismatch
                 combined_tokens = prefix_tokens + suffix_tokens
                 prompt = tokenizer.decode(combined_tokens)
                 prompt_len = len(combined_tokens)
@@ -2726,6 +2807,16 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
                 )
             )
 
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
         random.shuffle(requests)
         return requests
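
Reviewer note (not part of the patch): the sketch below reproduces the decode/re-encode drift that gen_prompt_decode_to_target_len corrects, using the GPT2Tokenizer example quoted in the new docstring. It assumes `transformers` is installed and this patch is applied so the helper is importable; the file name demo.py and the printed values are illustrative.

    # demo.py -- illustrative sketch, not part of the patch.
    from transformers import AutoTokenizer

    from vllm.benchmarks.datasets import gen_prompt_decode_to_target_len

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Two tokens whose decoded text re-encodes to three tokens
    # (the GPT2Tokenizer example from the new docstring).
    tokens = [6880, 6881]
    text = tokenizer.decode(tokens)
    round_trip = tokenizer.encode(text, add_special_tokens=False)
    print(len(tokens), len(round_trip))  # 2 vs. 3 -- one token gained

    # The helper retries the decode/encode round trip, padding or
    # truncating until the length matches the target, and reports any
    # residual difference as token_mismatch instead of silently
    # returning a prompt with the wrong token count.
    prompt, adjusted, token_mismatch = gen_prompt_decode_to_target_len(
        tokenizer=tokenizer,
        token_sequence=list(tokens),
        target_token_len=len(tokens),
    )
    print(len(adjusted), token_mismatch)  # 2 and 0, unless retries ran out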