From 39956efb3f2a3457d29725916e0ba9cbb69841a3 Mon Sep 17 00:00:00 2001
From: Qiong Zhou Huang
Date: Wed, 7 May 2025 23:32:10 -0700
Subject: [PATCH] [Bugfix] Fix bad words for Mistral models (#17753)

Signed-off-by: Qiong Zhou Huang
---
 vllm/logits_process.py  | 17 +++++++----------
 vllm/sampling_params.py | 10 ++--------
 2 files changed, 9 insertions(+), 18 deletions(-)

diff --git a/vllm/logits_process.py b/vllm/logits_process.py
index e3faf20029ec9..29a73656bf65e 100644
--- a/vllm/logits_process.py
+++ b/vllm/logits_process.py
@@ -4,11 +4,12 @@
 from typing import Callable, Union
 
 import torch
 
-from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
-LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
-                        Callable[[list[int], list[int], torch.Tensor],
-                                 torch.Tensor]]
+LogitsProcessor = Union[
+    Callable[[list[int], torch.Tensor], torch.Tensor],
+    Callable[[list[int], list[int], torch.Tensor], torch.Tensor],
+]
 """LogitsProcessor is a function that takes a list of previously generated
 tokens, the logits tensor for the next token and, optionally, prompt tokens as a
@@ -29,12 +30,8 @@ def get_bad_words_logits_processors(
             prefix = " " if add_prefix_space else ""
             prompt = prefix + bad_word.lstrip()
 
-            if isinstance(tokenizer, MistralTokenizer):
-                # Mistral tokenizers should not add special tokens
-                prompt_token_ids = tokenizer.encode(text=prompt)
-            else:
-                prompt_token_ids = tokenizer.encode(text=prompt,
-                                                    add_special_tokens=False)
+            prompt_token_ids = tokenizer.encode(text=prompt,
+                                                add_special_tokens=False)
 
             # If no space at the beginning
             # or if prefix space produces a new word token
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 66a77681be9a3..affc5c64b9416 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -13,7 +13,6 @@
 from typing_extensions import deprecated
 
 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 
 logger = init_logger(__name__)
@@ -491,13 +490,8 @@ class SamplingParams(
             for add_prefix_space in [False, True]:
                 prefix = " " if add_prefix_space else ""
                 prompt = prefix + bad_word.lstrip()
-
-                if isinstance(tokenizer, MistralTokenizer):
-                    # Mistral tokenizers should not add special tokens
-                    prompt_token_ids = tokenizer.encode(text=prompt)
-                else:
-                    prompt_token_ids = tokenizer.encode(
-                        text=prompt, add_special_tokens=False)
+                prompt_token_ids = tokenizer.encode(text=prompt,
+                                                    add_special_tokens=False)
 
                 # If no space at the beginning
                 # or if prefix space produces a new word token