[Misc] Fix warnings for mistral model (#23552)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

parent: 2d0afcc9dc
commit: 885ca6d31d
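
In short: the Pixtral and Voxtral processor adapters now return a transformers BatchFeature instead of a plain dict, and MistralTokenizer stops mutating the underlying tokenizer's global special_token_policy at init time, instead resolving a _special_token_policy once and passing it explicitly to every decode and id_to_byte_piece call. Both changes line up with the warnings named in the title: Hugging Face-style processors are expected to return BatchFeature, and recent mistral_common releases warn when special-token handling is left implicit during decoding.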
vllm/model_executor/models/pixtral.py

@@ -15,7 +15,7 @@ from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
 from PIL import Image
-from transformers import PixtralVisionConfig, TensorType
+from transformers import BatchFeature, PixtralVisionConfig, TensorType
 from transformers.image_utils import ImageInput
 from transformers.models.pixtral.image_processing_pixtral import (
     _num_image_tokens as _get_pixtral_hf_num_image_tokens)
@@ -163,10 +163,12 @@ class PixtralProcessorAdapter:
             images_processed.append(image_processed)
             images_tokens.append(image_tokens)
 
-        return {
-            "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
-            "images": images_processed,
-        }
+        return BatchFeature({
+            "input_ids":
+            torch.cat(images_tokens)[None].expand(len(text), -1),
+            "images":
+            images_processed,
+        })
 
 
 class PixtralProcessingInfo(BaseProcessingInfo):
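For context on the processor change (the Voxtral hunks below apply the identical fix on the audio side): BatchFeature is the standard return type of Hugging Face processors, a dict subclass that keeps plain key access while adding tensor conveniences. A minimal sketch of the behavior the adapter now relies on; the keys and tensor shapes are illustrative stand-ins, not real Pixtral outputs:

    import torch
    from transformers import BatchFeature

    # BatchFeature wraps a dict and stays dict-like ...
    feats = BatchFeature({
        "input_ids": torch.zeros(1, 4, dtype=torch.long),  # placeholder ids
        "images": [torch.zeros(3, 16, 16)],  # placeholder image tensors
    })
    assert list(feats.keys()) == ["input_ids", "images"]
    assert feats["input_ids"].shape == (1, 4)

    # ... while adding processor conveniences such as moving tensor
    # members to a device in one call.
    feats = feats.to("cpu")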
vllm/model_executor/models/voxtral.py

@@ -17,7 +17,7 @@ from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio,
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.protocol.transcription.request import TranscriptionRequest
 from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
-from transformers import TensorType, WhisperConfig
+from transformers import BatchFeature, TensorType, WhisperConfig
 from transformers.tokenization_utils_base import TextInput
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
@@ -156,10 +156,12 @@ class VoxtralProcessorAdapter:
             audios_tokens.append(torch.tensor(audio_tokens))
             audios_processed.append(torch.tensor(audio))
 
-        return {
-            "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
-            "audio_arrays": audios_processed,
-        }
+        return BatchFeature({
+            "input_ids":
+            torch.cat(audios_tokens)[None].expand(len(text), -1),
+            "audio_arrays":
+            audios_processed,
+        })
 
 
 class VoxtralProcessingInfo(BaseProcessingInfo):
vllm/transformers_utils/tokenizers/mistral.py

@@ -204,18 +204,16 @@ class MistralTokenizer(TokenizerBase):
         self.version: int = int(_mistral_version_str.split("v")[-1])
 
         tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
-        from mistral_common.tokens.tokenizers.tekken import (
-            SpecialTokenPolicy, Tekkenizer)
+        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
+        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
 
         self.is_tekken = isinstance(tokenizer_, Tekkenizer)
         from mistral_common.tokens.tokenizers.sentencepiece import (
             SentencePieceTokenizer)
         self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
-        if self.is_tekken:
-            # Make sure special tokens will not raise
-            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
-        elif self.is_spm:
-            pass
-        else:
+        self._special_token_policy = (SpecialTokenPolicy.IGNORE
+                                      if self.is_tekken else None)
+        if not (self.is_tekken or self.is_spm):
             raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
         self._vocab = tokenizer_.vocab()
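The heart of the tokenizer change: rather than flipping the tokenizer-wide special_token_policy attribute once at init (which silently affected every later decode), the policy is resolved into _special_token_policy and threaded through each call below. SpecialTokenPolicy is a small enum from mistral_common.tokens.tokenizers.base; a sketch of its semantics, assuming mistral_common is installed, with resolve_policy as a hypothetical stand-in for the init logic above:

    from typing import Optional

    from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

    # What each policy does when decoding hits a special token id:
    #   IGNORE - drop special tokens from the decoded string
    #   KEEP   - render special tokens into the decoded string
    #   RAISE  - raise an error on any special token
    print([p.name for p in SpecialTokenPolicy])  # ['IGNORE', 'KEEP', 'RAISE']

    # Hypothetical helper mirroring the init logic above: Tekken
    # tokenizers get IGNORE, SentencePiece ones keep the library default.
    def resolve_policy(is_tekken: bool) -> Optional[SpecialTokenPolicy]:
        return SpecialTokenPolicy.IGNORE if is_tekken else None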
@@ -430,7 +428,8 @@ class MistralTokenizer(TokenizerBase):
                         return self.tokenizer.unk_id
 
                 ids = [_token_to_id(t) for t in tokens]
-                decoded = self.tokenizer.decode(ids)
+                decoded = self.tokenizer.decode(ids,
+                                                self._special_token_policy)
             else:
                 decoded = "".join(tokens)
         else:
@@ -444,7 +443,8 @@ class MistralTokenizer(TokenizerBase):
                 if token in special_tokens:
                     if regular_tokens:
                         decoded_list.append(
-                            self.tokenizer.decode(regular_tokens))
+                            self.tokenizer.decode(regular_tokens,
+                                                  self._special_token_policy))
                         regular_tokens = []
                     decoded_list.append(token)
                 else:
@@ -452,7 +452,8 @@ class MistralTokenizer(TokenizerBase):
 
             if regular_tokens:
                 decoded_list.append(
-                    self.tokenizer.decode(regular_tokens))  # type: ignore
+                    self.tokenizer.decode(regular_tokens,
+                                          self._special_token_policy))
 
             decoded = ''.join(decoded_list)
 
@@ -470,7 +471,7 @@ class MistralTokenizer(TokenizerBase):
 
         if isinstance(ids, int):
             ids = [ids]
-        return self.tokenizer.decode(ids)
+        return self.tokenizer.decode(ids, self._special_token_policy)
 
     def convert_ids_to_tokens(
         self,
@@ -511,6 +512,9 @@ class MistralTokenizer(TokenizerBase):
         # See: https://github.com/vllm-project/vllm/pull/8640
         #      https://github.com/vllm-project/vllm/pull/9625
         # if underlying tokenizeir is sentencepiece, we just add "<22>"
-        tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
+        tokens = [
+            self.tokenizer.id_to_byte_piece(id, self._special_token_policy)
+            for id in ids
+        ]
 
         return tokens
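
With that, every decode path has the same call shape. A hedged end-to-end sketch, assuming a local tekken.json (the path is a placeholder) and a mistral_common version whose decode accepts the policy as its second argument, which is what this diff relies on:

    from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer

    # Placeholder path: point this at a real tekken.json to run.
    tok = Tekkenizer.from_file("tekken.json")

    ids = tok.encode("hello world", bos=True, eos=True)

    # Old call shape, which leaned on tokenizer-global state:
    #     tok.decode(ids)
    # New call shape, as used throughout the diff:
    text = tok.decode(ids, SpecialTokenPolicy.IGNORE)
    print(text)  # BOS/EOS are dropped under IGNORE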