From 48d15a32aa567dfc59ede46683b01cc2321579cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=B0=E5=85=AE?= <38908462+zhyajie@users.noreply.github.com>
Date: Tue, 2 Dec 2025 16:02:12 +0800
Subject: [PATCH] [CI] Fix Bad_words test for tokenizer encode/decode
 asymmetry (#28193)

Signed-off-by: zhyajie
Co-authored-by: zhyajie
---
 tests/v1/sample/test_sampling_params_e2e.py | 27 ++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py
index 1684252174d3d..a75a37befe0e1 100644
--- a/tests/v1/sample/test_sampling_params_e2e.py
+++ b/tests/v1/sample/test_sampling_params_e2e.py
@@ -106,6 +106,25 @@ def test_detokenize_false(llm):
 
 def test_bad_words(llm):
     """Check that we respect bad words."""
+    tokenizer = llm.get_tokenizer()
+
+    def contains_bad_word(text: str, tokens: list[int], bad_word: str) -> bool:
+        """Check if word appears in BOTH text and token sequence."""
+        if bad_word not in text:
+            return False
+
+        for add_prefix_space in [False, True]:
+            prefix = " " if add_prefix_space else ""
+            bad_words_token = tokenizer.encode(
+                prefix + bad_word.lstrip(), add_special_tokens=False
+            )
+            if not bad_words_token:
+                continue
+            for i in range(len(tokens) - len(bad_words_token) + 1):
+                if tokens[i : i + len(bad_words_token)] == bad_words_token:
+                    return True
+        return False
+
     output = llm.generate(PROMPT, SamplingParams(temperature=0))
     split_text = output[0].outputs[0].text.split()
 
@@ -113,14 +132,16 @@ def test_bad_words(llm):
     params = SamplingParams(temperature=0, bad_words=[bad_words_1])
     output = llm.generate(PROMPT, params)
     new_text = output[0].outputs[0].text
-    assert bad_words_1 not in new_text
+    new_tokens = output[0].outputs[0].token_ids
+    assert not contains_bad_word(new_text, new_tokens, bad_words_1)
 
     bad_words_2 = new_text.split()[-1]
     params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2])
     output = llm.generate(PROMPT, params)
     new_text = output[0].outputs[0].text
-    assert bad_words_1 not in new_text
-    assert bad_words_2 not in new_text
+    new_tokens = output[0].outputs[0].token_ids
+    assert not contains_bad_word(new_text, new_tokens, bad_words_1)
+    assert not contains_bad_word(new_text, new_tokens, bad_words_2)
 
 
 def test_logits_processor(llm):
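
---

Note on the asymmetry being tested (editorial, not part of the patch): BPE-style
tokenizers typically encode a word to different token ids depending on whether it
carries a leading space, yet decoding either sequence yields text that contains the
bare word. A plain `bad_word not in text` check can therefore fire even when the
banned token sequence was never generated, which is why the helper above requires a
match in both the text and the token ids, for both prefix variants. Below is a
minimal sketch of that asymmetry, assuming a Hugging Face tokenizer; "gpt2" is a
placeholder model choice, not something the patch itself specifies:

    # Illustrative sketch only: shows encode/decode asymmetry with an assumed
    # Hugging Face tokenizer. This is not code from the patch.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    word = "sample"
    ids_bare = tokenizer.encode(word, add_special_tokens=False)
    ids_spaced = tokenizer.encode(" " + word, add_special_tokens=False)

    # For a GPT-2-style BPE vocabulary the two encodings differ, because
    # " sample" maps to a space-prefixed token that "sample" does not use.
    print(ids_bare, ids_spaced)

    # Decoding the spaced ids yields " sample", so a substring check on the
    # decoded text matches even though ids_bare never appeared in the
    # generated token stream.
    text = tokenizer.decode(ids_spaced)
    assert word in text
    assert ids_bare != ids_spaced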