[CI] Fix Bad_words test for tokenizer encode/decode asymmetry (#28193)

Signed-off-by: zhyajie <yajizhan@amd.com>
Co-authored-by: zhyajie <yajizhan@amd.com>
This commit is contained in:
杰兮 2025-12-02 16:02:12 +08:00 committed by GitHub
parent 3b221cb661
commit 48d15a32aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -106,6 +106,25 @@ def test_detokenize_false(llm):
def test_bad_words(llm):
"""Check that we respect bad words."""
tokenizer = llm.get_tokenizer()
def contains_bad_word(text: str, tokens: list[int], bad_word: str) -> bool:
"""Check if word appears in BOTH text and token sequence."""
if bad_word not in text:
return False
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
bad_words_token = tokenizer.encode(
prefix + bad_word.lstrip(), add_special_tokens=False
)
if not bad_words_token:
continue
for i in range(len(tokens) - len(bad_words_token) + 1):
if tokens[i : i + len(bad_words_token)] == bad_words_token:
return True
return False
output = llm.generate(PROMPT, SamplingParams(temperature=0))
split_text = output[0].outputs[0].text.split()
@ -113,14 +132,16 @@ def test_bad_words(llm):
params = SamplingParams(temperature=0, bad_words=[bad_words_1])
output = llm.generate(PROMPT, params)
new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text
new_tokens = output[0].outputs[0].token_ids
assert not contains_bad_word(new_text, new_tokens, bad_words_1)
bad_words_2 = new_text.split()[-1]
params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2])
output = llm.generate(PROMPT, params)
new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text
assert bad_words_2 not in new_text
new_tokens = output[0].outputs[0].token_ids
assert not contains_bad_word(new_text, new_tokens, bad_words_1)
assert not contains_bad_word(new_text, new_tokens, bad_words_2)
def test_logits_processor(llm):