[Bugfix] Fix bad words for Mistral models (#17753)

Signed-off-by: Qiong Zhou Huang <qiong@phonic.co>
Author: Qiong Zhou Huang
Date: 2025-05-07 23:32:10 -07:00 (committed by GitHub)
Parent: 597051e56f
Commit: 39956efb3f
2 changed files with 9 additions and 18 deletions

vllm/logits_process.py

@@ -4,11 +4,12 @@ from typing import Callable, Union
 import torch
 
-from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
-LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
-                        Callable[[list[int], list[int], torch.Tensor],
-                                 torch.Tensor]]
+LogitsProcessor = Union[
+    Callable[[list[int], torch.Tensor], torch.Tensor],
+    Callable[[list[int], list[int], torch.Tensor], torch.Tensor],
+]
 """LogitsProcessor is a function that takes a list
 of previously generated tokens, the logits tensor
 for the next token and, optionally, prompt tokens as a
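The retained alias admits two signatures: (past_token_ids, logits) and (prompt_token_ids, past_token_ids, logits). For reference, a processor matching the first signature could look like the sketch below (not part of this commit; the function and the banned id are hypothetical):

import torch

def ban_token_zero(past_token_ids: list[int],
                   logits: torch.Tensor) -> torch.Tensor:
    # Setting a logit to -inf excludes that token id from sampling.
    logits[0] = float("-inf")
    return logits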
@@ -29,12 +30,8 @@ def get_bad_words_logits_processors(
         prefix = " " if add_prefix_space else ""
         prompt = prefix + bad_word.lstrip()
 
-        if isinstance(tokenizer, MistralTokenizer):
-            # Mistral tokenizers should not add special tokens
-            prompt_token_ids = tokenizer.encode(text=prompt)
-        else:
-            prompt_token_ids = tokenizer.encode(text=prompt,
-                                                add_special_tokens=False)
+        prompt_token_ids = tokenizer.encode(text=prompt,
+                                            add_special_tokens=False)
 
         # If no space at the beginning
         # or if prefix space produces a new word token
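The unified call assumes every AnyTokenizer, including MistralTokenizer, now honors add_special_tokens=False. A hedged sketch of why the flag matters for bad-words matching (token ids below are illustrative, not real vocabulary values):

# Illustrative only: if encode() prepends a BOS id, the banned sequence
# includes BOS and can never match tokens produced mid-generation, so
# bad_words silently fails. add_special_tokens=False keeps only the
# ids of the word itself.
with_bos = tokenizer.encode(text=" hello")  # e.g. [1, 22172]
plain = tokenizer.encode(text=" hello", add_special_tokens=False)  # e.g. [22172]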

vllm/sampling_params.py

@@ -13,7 +13,6 @@ from typing_extensions import deprecated
 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 
 logger = init_logger(__name__)
@@ -491,13 +490,8 @@ class SamplingParams(
         for add_prefix_space in [False, True]:
             prefix = " " if add_prefix_space else ""
             prompt = prefix + bad_word.lstrip()
 
-            if isinstance(tokenizer, MistralTokenizer):
-                # Mistral tokenizers should not add special tokens
-                prompt_token_ids = tokenizer.encode(text=prompt)
-            else:
-                prompt_token_ids = tokenizer.encode(
-                    text=prompt, add_special_tokens=False)
+            prompt_token_ids = tokenizer.encode(text=prompt,
+                                                add_special_tokens=False)
 
             # If no space at the beginning
             # or if prefix space produces a new word token
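With both call sites unified, bad words tokenize identically for Mistral and Hugging Face tokenizers. An end-to-end sketch of the path this commit fixes (the model name is a placeholder; bad_words is the SamplingParams field validated by the code above):

from vllm import LLM, SamplingParams

# Each bad word is encoded with add_special_tokens=False, both with and
# without a leading space, so the ban matches the word wherever it occurs.
params = SamplingParams(max_tokens=64, bad_words=["foo", "bar"])

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
          tokenizer_mode="mistral")  # exercises the MistralTokenizer path
outputs = llm.generate(["Tell me a story."], params)
print(outputs[0].outputs[0].text)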