mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:55:40 +08:00
[Bugfix] Fix bad words for Mistral models (#17753)
Signed-off-by: Qiong Zhou Huang <qiong@phonic.co>
This commit is contained in:
parent
597051e56f
commit
39956efb3f
@@ -4,11 +4,12 @@ from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
|
||||
Callable[[list[int], list[int], torch.Tensor],
|
||||
torch.Tensor]]
|
||||
LogitsProcessor = Union[
|
||||
Callable[[list[int], torch.Tensor], torch.Tensor],
|
||||
Callable[[list[int], list[int], torch.Tensor], torch.Tensor],
|
||||
]
|
||||
"""LogitsProcessor is a function that takes a list
|
||||
of previously generated tokens, the logits tensor
|
||||
for the next token and, optionally, prompt tokens as a
|
||||
@@ -29,12 +30,8 @@ def get_bad_words_logits_processors(
|
||||
prefix = " " if add_prefix_space else ""
|
||||
prompt = prefix + bad_word.lstrip()
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# Mistral tokenizers should not add special tokens
|
||||
prompt_token_ids = tokenizer.encode(text=prompt)
|
||||
else:
|
||||
prompt_token_ids = tokenizer.encode(text=prompt,
|
||||
add_special_tokens=False)
|
||||
prompt_token_ids = tokenizer.encode(text=prompt,
|
||||
add_special_tokens=False)
|
||||
|
||||
# If no space at the beginning
|
||||
# or if prefix space produces a new word token
|
||||
|
||||
@@ -13,7 +13,6 @@ from typing_extensions import deprecated
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import LogitsProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -491,13 +490,8 @@ class SamplingParams(
|
||||
for add_prefix_space in [False, True]:
|
||||
prefix = " " if add_prefix_space else ""
|
||||
prompt = prefix + bad_word.lstrip()
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# Mistral tokenizers should not add special tokens
|
||||
prompt_token_ids = tokenizer.encode(text=prompt)
|
||||
else:
|
||||
prompt_token_ids = tokenizer.encode(
|
||||
text=prompt, add_special_tokens=False)
|
||||
prompt_token_ids = tokenizer.encode(text=prompt,
|
||||
add_special_tokens=False)
|
||||
|
||||
# If no space at the beginning
|
||||
# or if prefix space produces a new word token
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user