[Bugfix] Fix erroneous randomly generated cases in bad word testing (#22170)
Signed-off-by: phantomlei <phantomlei3@gmail.com>
parent 8d17fa633e
commit bc8372efc3
@@ -90,6 +90,27 @@ def _create_bad_words_token_ids(
     return bad_words_token_ids
 
 
+# Returns all last tokens of bad word sequences that share the same prefix
+# as `given_prefix` (excluding the last token).
+def _collect_suffixes_with_same_prefix(
+        given_prefix: list[int],
+        bad_words_token_ids: list[list[int]]) -> list[int]:
+    return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
+
+
+# generate a valid token id that is not in bad_words_token_ids
+def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
+                             vocab_size: int) -> int:
+    forbidden_start_tokens = set()
+    for bad_word in bad_words_token_ids:
+        forbidden_start_tokens.add(bad_word[0])
+    # Get a safe token that's not in forbidden starts
+    safe_token_candidates = list(
+        set(range(vocab_size)) - forbidden_start_tokens)
+    # Pick a random safe token
+    return np.random.choice(safe_token_candidates)
+
+
 def _update_output_token_ids_for_bad_words(
         metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
     bad_words_last_tokens = {}
@@ -104,12 +125,17 @@ def _update_output_token_ids_for_bad_words(
                 prefix_length = len(bad_word_token_ids) - 1
                 has_bad_words = np.random.choice([True, False])
                 if has_bad_words:
-                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
-                    bad_words_last_token.append(bad_word_token_ids[-1])
+                    prefix = bad_word_token_ids[:-1]
+                    output_token_ids[-prefix_length:] = prefix
+                    # Collect all last tokens from other bad words
+                    # that share this prefix
+                    bad_words_last_token.extend(
+                        _collect_suffixes_with_same_prefix(
+                            prefix, bad_words_token_ids))
                     break  # Maximum one update to output_token_ids
                 else:  # Make sure no accidental match to bad words
-                    output_token_ids[-1] = (bad_word_token_ids[-2] +
-                                            1) % vocab_size
+                    output_token_ids[-1] = _generate_valid_token_id(
+                        bad_words_token_ids, vocab_size)
         bad_words_last_tokens[batch_idx] = bad_words_last_token
     return bad_words_last_tokens
 
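As a quick sanity check of the two new helpers, a hypothetical sketch (assumes the helpers above are in scope, e.g. imported from the test module; the token values are made up):

    # Hypothetical bad words; [1, 2, 3] and [1, 2, 4] share the prefix [1, 2].
    bad_words = [[1, 2, 3], [1, 2, 4], [5, 6]]

    # Once the output ends with [1, 2], the sampler masks BOTH last tokens,
    # so the test must expect 3 and 4 rather than just one of them.
    assert _collect_suffixes_with_same_prefix([1, 2], bad_words) == [3, 4]

    # The replacement token must not start any bad word: here 1 and 5 are
    # the forbidden start tokens.
    token = _generate_valid_token_id(bad_words, vocab_size=8)
    assert token not in {1, 5}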
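For context on the second change, a hypothetical illustration of why the old "no accidental match" branch was erroneous (made-up values; `_generate_valid_token_id` assumed in scope as above). The old code derived its replacement token from a single bad word, and that token could itself be the first token of another bad word, leaving the output ending in a bad-word prefix the test did not account for:

    # Hypothetical bad words: the old replacement for [3, 4] collides with [4, 5].
    bad_words = [[3, 4], [4, 5]]
    vocab_size = 8

    # Old branch computed (bad_word_token_ids[-2] + 1) % vocab_size for [3, 4]:
    old_token = (bad_words[0][-2] + 1) % vocab_size
    assert old_token == 4  # first token of [4, 5] -> accidental prefix match

    # New branch: the replacement token never starts any bad word.
    new_token = _generate_valid_token_id(bad_words, vocab_size)
    assert new_token not in {3, 4}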