mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 09:44:27 +08:00
[CI] Fix Bad_words test for tokenizer encode/decode asymmetry (#28193)
Signed-off-by: zhyajie <yajizhan@amd.com> Co-authored-by: zhyajie <yajizhan@amd.com>
This commit is contained in:
parent
3b221cb661
commit
48d15a32aa
@@ -106,6 +106,25 @@ def test_detokenize_false(llm):
|
||||
def test_bad_words(llm):
|
||||
"""Check that we respect bad words."""
|
||||
|
||||
tokenizer = llm.get_tokenizer()
|
||||
|
||||
def contains_bad_word(text: str, tokens: list[int], bad_word: str) -> bool:
    """Tell whether ``bad_word`` is present in the text AND in the token ids.

    A plain substring hit in ``text`` is not enough: the word must also show
    up as a contiguous run inside ``tokens`` when encoded by the enclosing
    test's ``tokenizer``. Two spellings are tried in order -- the word with
    its leading whitespace stripped, then the same word with a single leading
    space -- and the first spelling whose encoding is found in ``tokens``
    settles the answer.

    Args:
        text: Decoded output text.
        tokens: Token ids of the same output.
        bad_word: The word to look for.

    Returns:
        ``True`` only if ``bad_word`` is a substring of ``text`` and one of
        its encodings occurs as a contiguous slice of ``tokens``; ``False``
        otherwise.
    """
    if bad_word not in text:
        return False

    stripped_word = bad_word.lstrip()
    # Same order as the two prefix settings: bare word first, then " word".
    for candidate in (stripped_word, " " + stripped_word):
        needle = tokenizer.encode(candidate, add_special_tokens=False)
        if not needle:
            # Nothing to search for with this spelling.
            continue

        width = len(needle)
        start_count = len(tokens) - width + 1
        # Slide a window of the needle's length across the token ids.
        if any(
            tokens[start : start + width] == needle
            for start in range(start_count)
        ):
            return True

    return False
|
||||
|
||||
output = llm.generate(PROMPT, SamplingParams(temperature=0))
|
||||
split_text = output[0].outputs[0].text.split()
|
||||
|
||||
@@ -113,14 +132,16 @@ def test_bad_words(llm):
|
||||
params = SamplingParams(temperature=0, bad_words=[bad_words_1])
|
||||
output = llm.generate(PROMPT, params)
|
||||
new_text = output[0].outputs[0].text
|
||||
assert bad_words_1 not in new_text
|
||||
new_tokens = output[0].outputs[0].token_ids
|
||||
assert not contains_bad_word(new_text, new_tokens, bad_words_1)
|
||||
|
||||
bad_words_2 = new_text.split()[-1]
|
||||
params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2])
|
||||
output = llm.generate(PROMPT, params)
|
||||
new_text = output[0].outputs[0].text
|
||||
assert bad_words_1 not in new_text
|
||||
assert bad_words_2 not in new_text
|
||||
new_tokens = output[0].outputs[0].token_ids
|
||||
assert not contains_bad_word(new_text, new_tokens, bad_words_1)
|
||||
assert not contains_bad_word(new_text, new_tokens, bad_words_2)
|
||||
|
||||
|
||||
def test_logits_processor(llm):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user