[Bugfix] Fix bad words for Mistral models (#17753)

Signed-off-by: Qiong Zhou Huang <qiong@phonic.co>
This commit is contained in:
Qiong Zhou Huang 2025-05-07 23:32:10 -07:00 committed by GitHub
parent 597051e56f
commit 39956efb3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 18 deletions

View File

@@ -4,11 +4,12 @@ from typing import Callable, Union

 import torch

-from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer

-LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
-                        Callable[[list[int], list[int], torch.Tensor],
-                                 torch.Tensor]]
+LogitsProcessor = Union[
+    Callable[[list[int], torch.Tensor], torch.Tensor],
+    Callable[[list[int], list[int], torch.Tensor], torch.Tensor],
+]
 """LogitsProcessor is a function that takes a list
 of previously generated tokens, the logits tensor
 for the next token and, optionally, prompt tokens as a
@@ -29,10 +30,6 @@ def get_bad_words_logits_processors(
             prefix = " " if add_prefix_space else ""
             prompt = prefix + bad_word.lstrip()
-            if isinstance(tokenizer, MistralTokenizer):
-                # Mistral tokenizers should not add special tokens
-                prompt_token_ids = tokenizer.encode(text=prompt)
-            else:
-                prompt_token_ids = tokenizer.encode(text=prompt,
-                                                    add_special_tokens=False)
+            prompt_token_ids = tokenizer.encode(text=prompt,
+                                                add_special_tokens=False)

View File

@@ -13,7 +13,6 @@ from typing_extensions import deprecated

 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

 logger = init_logger(__name__)
@@ -491,13 +490,8 @@ class SamplingParams(
         for add_prefix_space in [False, True]:
             prefix = " " if add_prefix_space else ""
             prompt = prefix + bad_word.lstrip()
-
-            if isinstance(tokenizer, MistralTokenizer):
-                # Mistral tokenizers should not add special tokens
-                prompt_token_ids = tokenizer.encode(text=prompt)
-            else:
-                prompt_token_ids = tokenizer.encode(
-                    text=prompt, add_special_tokens=False)
+            prompt_token_ids = tokenizer.encode(text=prompt,
+                                                add_special_tokens=False)

             # If no space at the beginning
             # or if prefix space produces a new word token