mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 06:35:01 +08:00
[Bugfix] Fix bad words for Mistral models (#17753)
Signed-off-by: Qiong Zhou Huang <qiong@phonic.co>
This commit is contained in:
parent
597051e56f
commit
39956efb3f
@ -4,11 +4,12 @@ from typing import Callable, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
|
LogitsProcessor = Union[
|
||||||
Callable[[list[int], list[int], torch.Tensor],
|
Callable[[list[int], torch.Tensor], torch.Tensor],
|
||||||
torch.Tensor]]
|
Callable[[list[int], list[int], torch.Tensor], torch.Tensor],
|
||||||
|
]
|
||||||
"""LogitsProcessor is a function that takes a list
|
"""LogitsProcessor is a function that takes a list
|
||||||
of previously generated tokens, the logits tensor
|
of previously generated tokens, the logits tensor
|
||||||
for the next token and, optionally, prompt tokens as a
|
for the next token and, optionally, prompt tokens as a
|
||||||
@ -29,10 +30,6 @@ def get_bad_words_logits_processors(
|
|||||||
prefix = " " if add_prefix_space else ""
|
prefix = " " if add_prefix_space else ""
|
||||||
prompt = prefix + bad_word.lstrip()
|
prompt = prefix + bad_word.lstrip()
|
||||||
|
|
||||||
if isinstance(tokenizer, MistralTokenizer):
|
|
||||||
# Mistral tokenizers should not add special tokens
|
|
||||||
prompt_token_ids = tokenizer.encode(text=prompt)
|
|
||||||
else:
|
|
||||||
prompt_token_ids = tokenizer.encode(text=prompt,
|
prompt_token_ids = tokenizer.encode(text=prompt,
|
||||||
add_special_tokens=False)
|
add_special_tokens=False)
|
||||||
|
|
||||||
|
|||||||
@ -13,7 +13,6 @@ from typing_extensions import deprecated
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logits_process import LogitsProcessor
|
from vllm.logits_process import LogitsProcessor
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -491,13 +490,8 @@ class SamplingParams(
|
|||||||
for add_prefix_space in [False, True]:
|
for add_prefix_space in [False, True]:
|
||||||
prefix = " " if add_prefix_space else ""
|
prefix = " " if add_prefix_space else ""
|
||||||
prompt = prefix + bad_word.lstrip()
|
prompt = prefix + bad_word.lstrip()
|
||||||
|
prompt_token_ids = tokenizer.encode(text=prompt,
|
||||||
if isinstance(tokenizer, MistralTokenizer):
|
add_special_tokens=False)
|
||||||
# Mistral tokenizers should not add special tokens
|
|
||||||
prompt_token_ids = tokenizer.encode(text=prompt)
|
|
||||||
else:
|
|
||||||
prompt_token_ids = tokenizer.encode(
|
|
||||||
text=prompt, add_special_tokens=False)
|
|
||||||
|
|
||||||
# If no space at the beginning
|
# If no space at the beginning
|
||||||
# or if prefix space produces a new word token
|
# or if prefix space produces a new word token
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user