From 885ca6d31db8816ee08e3fa634fbb58add289898 Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Fri, 29 Aug 2025 14:58:48 +0800 Subject: [PATCH] [Misc] Fix warnings for mistral model (#23552) Signed-off-by: zjy0516 Signed-off-by: Jiangyun Zhu Co-authored-by: Patrick von Platen --- vllm/model_executor/models/pixtral.py | 12 ++++---- vllm/model_executor/models/voxtral.py | 12 ++++---- vllm/transformers_utils/tokenizers/mistral.py | 30 +++++++++++-------- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index a74e01a59697e..e7f5799a80067 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -15,7 +15,7 @@ from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image -from transformers import PixtralVisionConfig, TensorType +from transformers import BatchFeature, PixtralVisionConfig, TensorType from transformers.image_utils import ImageInput from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens as _get_pixtral_hf_num_image_tokens) @@ -163,10 +163,12 @@ class PixtralProcessorAdapter: images_processed.append(image_processed) images_tokens.append(image_tokens) - return { - "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), - "images": images_processed, - } + return BatchFeature({ + "input_ids": + torch.cat(images_tokens)[None].expand(len(text), -1), + "images": + images_processed, + }) class PixtralProcessingInfo(BaseProcessingInfo): diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index eed8d89ca4f5a..6bc748407a7d1 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -17,7 +17,7 @@ from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder -from transformers import TensorType, WhisperConfig +from transformers import BatchFeature, TensorType, WhisperConfig from transformers.tokenization_utils_base import TextInput from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig @@ -156,10 +156,12 @@ class VoxtralProcessorAdapter: audios_tokens.append(torch.tensor(audio_tokens)) audios_processed.append(torch.tensor(audio)) - return { - "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1), - "audio_arrays": audios_processed, - } + return BatchFeature({ + "input_ids": + torch.cat(audios_tokens)[None].expand(len(text), -1), + "audio_arrays": + audios_processed, + }) class VoxtralProcessingInfo(BaseProcessingInfo): diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 4dd8b2439b3f5..f545993a5a980 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -204,18 +204,16 @@ class MistralTokenizer(TokenizerBase): self.version: int = int(_mistral_version_str.split("v")[-1]) tokenizer_ = tokenizer.instruct_tokenizer.tokenizer - from mistral_common.tokens.tokenizers.tekken import ( - SpecialTokenPolicy, Tekkenizer) + from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + self.is_tekken = isinstance(tokenizer_, Tekkenizer) from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer) self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) - if self.is_tekken: - # Make sure special tokens will not raise - tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE - elif self.is_spm: - pass - else: + self._special_token_policy = (SpecialTokenPolicy.IGNORE + if self.is_tekken else None) + if not (self.is_tekken or self.is_spm): raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") self._vocab = tokenizer_.vocab() @@ -430,7 +428,8 @@ class MistralTokenizer(TokenizerBase): return self.tokenizer.unk_id ids = [_token_to_id(t) for t in tokens] - decoded = self.tokenizer.decode(ids) + decoded = self.tokenizer.decode(ids, + self._special_token_policy) else: decoded = "".join(tokens) else: @@ -444,7 +443,8 @@ class MistralTokenizer(TokenizerBase): if token in special_tokens: if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens)) + self.tokenizer.decode(regular_tokens, + self._special_token_policy)) regular_tokens = [] decoded_list.append(token) else: @@ -452,7 +452,8 @@ class MistralTokenizer(TokenizerBase): if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens)) # type: ignore + self.tokenizer.decode(regular_tokens, + self._special_token_policy)) decoded = ''.join(decoded_list) @@ -470,7 +471,7 @@ class MistralTokenizer(TokenizerBase): if isinstance(ids, int): ids = [ids] - return self.tokenizer.decode(ids) + return self.tokenizer.decode(ids, self._special_token_policy) def convert_ids_to_tokens( self, @@ -511,6 +512,9 @@ class MistralTokenizer(TokenizerBase): # See: https://github.com/vllm-project/vllm/pull/8640 # https://github.com/vllm-project/vllm/pull/9625 # if underlying tokenizeir is sentencepiece, we just add "�" - tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids] + tokens = [ + self.tokenizer.id_to_byte_piece(id, self._special_token_policy) + for id in ids + ] return tokens