mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 06:44:26 +08:00
Fix random dataset mismatched token length with config. (#24937)
Signed-off-by: Weiliang Liu <weiliangl@nvidia.com> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.io> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
6dee906d2c
commit
02ab3860a6
@ -366,11 +366,67 @@ def process_video(video: Any) -> Mapping[str, Any]:
|
|||||||
f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
|
f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_prompt_decode_to_target_len(
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
token_sequence: list[int],
|
||||||
|
target_token_len: int,
|
||||||
|
max_retry: int = 10,
|
||||||
|
add_special_tokens: bool = False,
|
||||||
|
rng: Optional[np.random.Generator] = None,
|
||||||
|
) -> tuple[str, list[int]]:
|
||||||
|
"""
|
||||||
|
Ensure decoded-then-encoded prompt length matches the target token length.
|
||||||
|
|
||||||
|
This function decodes an initial token sequence to text and re-encodes it
|
||||||
|
, iteratively adjusting the token sequence length to match a target.
|
||||||
|
This is necessary because some tokenizers do not guarantee a 1:1 mapping
|
||||||
|
between consecutive tokens and the decoded-then-encoded sequence length.
|
||||||
|
For example, for GPT2Tokenizer:
|
||||||
|
[6880, 6881] -> ['Ġcalls', 'here'] ->
|
||||||
|
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
||||||
|
|
||||||
|
Returns a tuple of the final prompt string and the adjusted token sequence.
|
||||||
|
"""
|
||||||
|
remain_num_try = max_retry
|
||||||
|
token_mismatch = 0
|
||||||
|
while True:
|
||||||
|
prompt = tokenizer.decode(token_sequence)
|
||||||
|
token_sequence = tokenizer.encode(
|
||||||
|
prompt, add_special_tokens=add_special_tokens
|
||||||
|
)
|
||||||
|
if remain_num_try <= 0:
|
||||||
|
if len(token_sequence) != target_token_len:
|
||||||
|
token_mismatch = len(token_sequence) - target_token_len
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(token_sequence) == target_token_len:
|
||||||
|
break
|
||||||
|
elif len(token_sequence) < target_token_len:
|
||||||
|
if rng is not None:
|
||||||
|
extra_tokens = rng.integers(
|
||||||
|
0,
|
||||||
|
tokenizer.vocab_size,
|
||||||
|
size=target_token_len - len(token_sequence),
|
||||||
|
).tolist()
|
||||||
|
else:
|
||||||
|
extra_tokens = np.random.randint(
|
||||||
|
0,
|
||||||
|
tokenizer.vocab_size,
|
||||||
|
size=target_token_len - len(token_sequence),
|
||||||
|
).tolist()
|
||||||
|
token_sequence.extend(extra_tokens)
|
||||||
|
elif len(token_sequence) > target_token_len:
|
||||||
|
token_sequence = token_sequence[:target_token_len]
|
||||||
|
|
||||||
|
remain_num_try -= 1
|
||||||
|
|
||||||
|
return prompt, token_sequence, token_mismatch
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Random Dataset Implementation (Synthetic Data)
|
# Random Dataset Implementation (Synthetic Data)
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class RandomDataset(BenchmarkDataset):
|
class RandomDataset(BenchmarkDataset):
|
||||||
"""
|
"""
|
||||||
Synthetic text-only dataset for serving/throughput benchmarks.
|
Synthetic text-only dataset for serving/throughput benchmarks.
|
||||||
@ -420,8 +476,9 @@ class RandomDataset(BenchmarkDataset):
|
|||||||
vocab_size = tokenizer.vocab_size
|
vocab_size = tokenizer.vocab_size
|
||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
|
token_mismatch_total = 0
|
||||||
for i in range(num_requests):
|
for i in range(num_requests):
|
||||||
prompt, total_input_len = self.generate_token_sequence(
|
prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
prefix_token_ids=prefix_token_ids,
|
prefix_token_ids=prefix_token_ids,
|
||||||
prefix_len=prefix_len,
|
prefix_len=prefix_len,
|
||||||
@ -430,6 +487,7 @@ class RandomDataset(BenchmarkDataset):
|
|||||||
offset=int(offsets[i]),
|
offset=int(offsets[i]),
|
||||||
index=i,
|
index=i,
|
||||||
)
|
)
|
||||||
|
token_mismatch_total += token_mismatch
|
||||||
requests.append(
|
requests.append(
|
||||||
SampleRequest(
|
SampleRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -453,6 +511,18 @@ class RandomDataset(BenchmarkDataset):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
requests = batch_requests
|
requests = batch_requests
|
||||||
|
|
||||||
|
if token_mismatch_total != 0:
|
||||||
|
sign = "more" if token_mismatch_total > 0 else "fewer"
|
||||||
|
logger.warning(
|
||||||
|
"Across all generated prompts, there were %d %s tokens "
|
||||||
|
"than expected after decoding and re-encoding. This is "
|
||||||
|
"expected due to the imperfect nature of the sampling "
|
||||||
|
"procedure.",
|
||||||
|
abs(token_mismatch_total),
|
||||||
|
sign,
|
||||||
|
)
|
||||||
|
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
def get_prefix(
|
def get_prefix(
|
||||||
@ -530,7 +600,7 @@ class RandomDataset(BenchmarkDataset):
|
|||||||
input_len: int,
|
input_len: int,
|
||||||
offset: int,
|
offset: int,
|
||||||
index: int,
|
index: int,
|
||||||
) -> tuple[str, int]:
|
) -> tuple[str, int, int]:
|
||||||
"""
|
"""
|
||||||
Returns (prompt, total_input_len).
|
Returns (prompt, total_input_len).
|
||||||
|
|
||||||
@ -549,15 +619,16 @@ class RandomDataset(BenchmarkDataset):
|
|||||||
token_sequence = prefix_token_ids + inner_seq
|
token_sequence = prefix_token_ids + inner_seq
|
||||||
|
|
||||||
# Decode, then re-encode and truncate to preserve token count invariants
|
# Decode, then re-encode and truncate to preserve token count invariants
|
||||||
prompt = tokenizer.decode(token_sequence)
|
|
||||||
total_input_len = prefix_len + int(input_len)
|
total_input_len = prefix_len + int(input_len)
|
||||||
|
prompt, adjusted_token_sequence, token_mismatch = gen_prompt_decode_to_target_len( # noqa: E501
|
||||||
re_encoded_sequence = tokenizer.encode(
|
tokenizer=tokenizer,
|
||||||
prompt, add_special_tokens=False)[:total_input_len]
|
token_sequence=token_sequence,
|
||||||
prompt = tokenizer.decode(re_encoded_sequence)
|
target_token_len=total_input_len,
|
||||||
total_input_len = len(re_encoded_sequence)
|
add_special_tokens=False,
|
||||||
|
rng=self._rng,
|
||||||
return prompt, total_input_len
|
)
|
||||||
|
total_input_len = len(adjusted_token_sequence)
|
||||||
|
return prompt, total_input_len, token_mismatch
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@ -873,8 +944,9 @@ class RandomMultiModalDataset(RandomDataset):
|
|||||||
vocab_size = tokenizer.vocab_size
|
vocab_size = tokenizer.vocab_size
|
||||||
# Add synthetic multimodal items to each request
|
# Add synthetic multimodal items to each request
|
||||||
mm_requests = []
|
mm_requests = []
|
||||||
|
token_mismatch_total = 0
|
||||||
for i in range(num_requests):
|
for i in range(num_requests):
|
||||||
prompt, total_input_len = self.generate_token_sequence(
|
prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
prefix_token_ids=prefix_token_ids,
|
prefix_token_ids=prefix_token_ids,
|
||||||
prefix_len=prefix_len,
|
prefix_len=prefix_len,
|
||||||
@ -883,6 +955,7 @@ class RandomMultiModalDataset(RandomDataset):
|
|||||||
offset=int(offsets[i]),
|
offset=int(offsets[i]),
|
||||||
index=i,
|
index=i,
|
||||||
)
|
)
|
||||||
|
token_mismatch_total += token_mismatch
|
||||||
# Get multimodal item iterator for a given request
|
# Get multimodal item iterator for a given request
|
||||||
mm_item_iterator = self.get_mm_item_iterator(
|
mm_item_iterator = self.get_mm_item_iterator(
|
||||||
min_num_mm_items,
|
min_num_mm_items,
|
||||||
@ -918,6 +991,18 @@ class RandomMultiModalDataset(RandomDataset):
|
|||||||
request_id=request_id_prefix + str(i),
|
request_id=request_id_prefix + str(i),
|
||||||
)
|
)
|
||||||
mm_requests.append(sample_request)
|
mm_requests.append(sample_request)
|
||||||
|
|
||||||
|
if token_mismatch_total != 0:
|
||||||
|
sign = "more" if token_mismatch_total > 0 else "fewer"
|
||||||
|
logger.warning(
|
||||||
|
"Across all generated prompts, there were %d %s tokens "
|
||||||
|
"than expected after decoding and re-encoding. This is "
|
||||||
|
"expected due to the imperfect nature of the sampling "
|
||||||
|
"procedure.",
|
||||||
|
abs(token_mismatch_total),
|
||||||
|
sign,
|
||||||
|
)
|
||||||
|
|
||||||
return mm_requests
|
return mm_requests
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@ -2694,27 +2779,23 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
|
|||||||
# Generate random tokens
|
# Generate random tokens
|
||||||
tokens = np.random.randint(
|
tokens = np.random.randint(
|
||||||
0, vocab_size, size=target_length).tolist()
|
0, vocab_size, size=target_length).tolist()
|
||||||
text = tokenizer.decode(tokens)
|
|
||||||
re_encoded = tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
|
|
||||||
if len(re_encoded) == target_length:
|
_, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len( # noqa: E501
|
||||||
return re_encoded
|
tokenizer=tokenizer,
|
||||||
elif len(re_encoded) < target_length:
|
token_sequence=tokens,
|
||||||
# Recursively generate additional consistent tokens
|
target_token_len=target_length,
|
||||||
needed = target_length - len(re_encoded)
|
add_special_tokens=False,
|
||||||
extra_tokens = _generate_exact_length_tokens(needed)
|
)
|
||||||
return re_encoded + extra_tokens
|
return adjusted_tokens, token_mismatch
|
||||||
else:
|
|
||||||
# Truncate to target length
|
|
||||||
return re_encoded[:target_length]
|
|
||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
|
token_mismatch_total = 0
|
||||||
for _ in range(num_prefixes):
|
for _ in range(num_prefixes):
|
||||||
prefix_tokens = _generate_exact_length_tokens(prefix_len)
|
prefix_tokens = _generate_exact_length_tokens(prefix_len)
|
||||||
|
|
||||||
for _ in range(prompts_per_prefix):
|
for _ in range(prompts_per_prefix):
|
||||||
suffix_tokens = _generate_exact_length_tokens(suffix_len)
|
suffix_tokens, token_mistmatch = _generate_exact_length_tokens(suffix_len) # noqa: E501
|
||||||
|
token_mismatch_total += token_mistmatch
|
||||||
combined_tokens = prefix_tokens + suffix_tokens
|
combined_tokens = prefix_tokens + suffix_tokens
|
||||||
prompt = tokenizer.decode(combined_tokens)
|
prompt = tokenizer.decode(combined_tokens)
|
||||||
prompt_len = len(combined_tokens)
|
prompt_len = len(combined_tokens)
|
||||||
@ -2726,6 +2807,16 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if token_mismatch_total != 0:
|
||||||
|
sign = "more" if token_mismatch_total > 0 else "fewer"
|
||||||
|
logger.warning(
|
||||||
|
"Across all generated prompts, there were %d %s tokens "
|
||||||
|
"than expected after decoding and re-encoding. This is "
|
||||||
|
"expected due to the imperfect nature of the sampling "
|
||||||
|
"procedure.",
|
||||||
|
abs(token_mismatch_total),
|
||||||
|
sign,
|
||||||
|
)
|
||||||
random.shuffle(requests)
|
random.shuffle(requests)
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user