diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py
index ea10661ea113..31c6c881d7b8 100644
--- a/tests/v1/sample/test_sampler.py
+++ b/tests/v1/sample/test_sampler.py
@@ -90,6 +90,35 @@ def _create_bad_words_token_ids(
     return bad_words_token_ids
 
 
+# Returns all last tokens of bad word sequences that share the same prefix
+# as `given_prefix` (excluding the last token).
+def _collect_suffixes_with_same_prefix(
+        given_prefix: list[int],
+        bad_words_token_ids: list[list[int]]) -> list[int]:
+    return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
+
+
+# Returns a random token id that does not occur in the prefix (all tokens
+# but the last) of any bad word, so it is safe to write as the final
+# output token without completing a bad-word prefix.
+def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
+                             vocab_size: int) -> int:
+    forbidden_tokens = set()
+    for bad_word in bad_words_token_ids:
+        # The final output token completes a prefix only if it equals a
+        # bad word's second-to-last token, which is the start token only
+        # for two-token bad words; exclude every prefix token as well.
+        forbidden_tokens.add(bad_word[0])
+        forbidden_tokens.update(bad_word[:-1])
+    safe_token_candidates = sorted(set(range(vocab_size)) - forbidden_tokens)
+    if not safe_token_candidates:
+        raise ValueError(
+            "No token id in the vocabulary is free of bad-word prefixes")
+    # np.random.choice returns a NumPy integer; convert to a built-in int
+    # to match the annotated return type.
+    return int(np.random.choice(safe_token_candidates))
+
+
 def _update_output_token_ids_for_bad_words(
         metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
     bad_words_last_tokens = {}
@@ -104,12 +133,17 @@ def _update_output_token_ids_for_bad_words(
                 prefix_length = len(bad_word_token_ids) - 1
                 has_bad_words = np.random.choice([True, False])
                 if has_bad_words:
-                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
-                    bad_words_last_token.append(bad_word_token_ids[-1])
+                    prefix = bad_word_token_ids[:-1]
+                    output_token_ids[-prefix_length:] = prefix
+                    # Collect all last tokens from other bad words
+                    # that share this prefix
+                    bad_words_last_token.extend(
+                        _collect_suffixes_with_same_prefix(
+                            prefix, bad_words_token_ids))
                     break  # Maximum one update to output_token_ids
                 else:
                     # Make sure no accidental match to bad words
-                    output_token_ids[-1] = (bad_word_token_ids[-2] +
-                                            1) % vocab_size
+                    output_token_ids[-1] = _generate_valid_token_id(
+                        bad_words_token_ids, vocab_size)
         bad_words_last_tokens[batch_idx] = bad_words_last_token
     return bad_words_last_tokens