mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Misc] Add fully interleaved support for multimodal 'string' content format (#14047)
Signed-off-by: drobyshev.anton <drobyshev.anton@wb.ru>
Co-authored-by: drobyshev.anton <drobyshev.anton@wb.ru>
This commit is contained in:
parent 22dd9c2730
commit e601efcb10
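
The diff below (tests first, then the model-config, engine-argument, and chat-utils parsing changes) adds an `interleave_mm_strings` flag so that, with `--chat-template-content-format string`, multimodal placeholders are inserted where the corresponding items appear in the message instead of all being hoisted to the front of the prompt. A minimal sketch of the new behaviour, distilled from the tests added in this commit; `PHI3V_MODEL_ID`, `image_url`, and `phi3v_tokenizer` stand for the constant and pytest fixtures defined below, so treat this as an illustration rather than a standalone script:

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import parse_chat_messages

# Same arguments as the phi3v_model_config_mm_interleaved fixture below;
# interleave_mm_strings=True is the knob introduced by this commit.
model_config = ModelConfig(PHI3V_MODEL_ID,
                           task="generate",
                           tokenizer=PHI3V_MODEL_ID,
                           tokenizer_mode="auto",
                           trust_remote_code=True,
                           dtype="auto",
                           seed=0,
                           interleave_mm_strings=True,
                           limit_mm_per_prompt={"image": 2})

conversation, mm_data = parse_chat_messages(
    [{
        "role": "user",
        "content": [
            {"type": "text", "text": "I need you to compare this image"},
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": "and this one"},
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": "Do they have differences?"},
        ],
    }],
    model_config,
    phi3v_tokenizer,
    content_format="string",
)

# With interleaving enabled the image placeholders stay in place:
assert conversation[0]["content"] == (
    "I need you to compare this image\n<|image_1|>\nand this one\n"
    "<|image_2|>\nDo they have differences?")
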
@@ -2,11 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import warnings
-from typing import Optional
+from collections.abc import Mapping
+from typing import Literal, Optional
 
 import pytest
 
+from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset
 from vllm.config import ModelConfig
 from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
                                          parse_chat_messages,
@@ -15,7 +18,8 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
                                          resolve_hf_chat_template)
 from vllm.entrypoints.llm import apply_hf_chat_template
 from vllm.multimodal import MultiModalDataDict
-from vllm.multimodal.utils import encode_image_base64
+from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
+                                   encode_video_base64)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
 from ..models.registry import HF_EXAMPLE_MODELS
@@ -28,6 +32,7 @@ ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
 QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
 QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
+QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
 HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
@@ -48,6 +53,21 @@ def phi3v_model_config():
                        })
 
 
+@pytest.fixture(scope="function")
+def phi3v_model_config_mm_interleaved():
+    return ModelConfig(PHI3V_MODEL_ID,
+                       task="generate",
+                       tokenizer=PHI3V_MODEL_ID,
+                       tokenizer_mode="auto",
+                       trust_remote_code=True,
+                       dtype="auto",
+                       seed=0,
+                       interleave_mm_strings=True,
+                       limit_mm_per_prompt={
+                           "image": 2,
+                       })
+
+
 @pytest.fixture(scope="module")
 def phi3v_tokenizer():
     return TokenizerGroup(
@@ -58,6 +78,32 @@ def phi3v_tokenizer():
     )
 
 
+@pytest.fixture(scope="function")
+def qwen25omni_model_config_mm_interleaved():
+    return ModelConfig(QWEN25OMNI_MODEL_ID,
+                       task="generate",
+                       tokenizer=QWEN25OMNI_MODEL_ID,
+                       tokenizer_mode="auto",
+                       dtype="auto",
+                       seed=0,
+                       interleave_mm_strings=True,
+                       limit_mm_per_prompt={
+                           "image": 2,
+                           "audio": 1,
+                           "video": 1,
+                       })
+
+
+@pytest.fixture(scope="module")
+def qwen25omni_tokenizer():
+    return TokenizerGroup(
+        tokenizer_id=QWEN25OMNI_MODEL_ID,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+
+
 @pytest.fixture(scope="module")
 def mllama_model_config():
     return ModelConfig(MLLAMA_MODEL_ID,
@@ -113,6 +159,20 @@ def image_url():
     return f"data:image/jpeg;base64,{base64}"
 
 
+@pytest.fixture(scope="module")
+def video_url():
+    video = VideoAsset('baby_reading', 1)
+    base64 = encode_video_base64(video.np_ndarrays)
+    return f"data:video/jpeg;base64,{base64}"
+
+
+@pytest.fixture(scope="module")
+def audio_url():
+    audio = AudioAsset('mary_had_lamb')
+    base64 = encode_audio_base64(*audio.audio_and_sample_rate)
+    return f"data:audio/ogg;base64,{base64}"
+
+
 def _assert_mm_data_is_image_input(
     mm_data: Optional[MultiModalDataDict],
     image_count: int,
@@ -126,6 +186,23 @@ def _assert_mm_data_is_image_input(
     assert isinstance(image_data, list) and len(image_data) == image_count
 
 
+ModalityType = Literal["image", "video", "audio"]
+MultiModalDataCounts = Mapping[ModalityType, int]
+
+
+def _assert_mm_data_inputs(
+    mm_data: Optional[MultiModalDataDict],
+    data_count: MultiModalDataCounts,
+) -> None:
+    assert mm_data is not None
+    assert set(data_count.keys()) == (set(mm_data.keys()))
+
+    for modality, n in data_count.items():
+        modality_data = mm_data.get(modality)
+        assert modality_data is not None
+        assert isinstance(modality_data, list) and len(modality_data) == n
+
+
 def test_parse_chat_messages_single_image(
     phi3v_model_config,
     phi3v_tokenizer,
@@ -637,6 +714,277 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
     _assert_mm_data_is_image_input(mm_data, 2)
 
 
+def test_parse_chat_messages_multiple_images_interleave(
+    phi3v_model_config_mm_interleaved,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [{
+                "type": "text",
+                "text": "I need you to compare this image"
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "and this one"
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "Do they have differences?"
+            }]
+        }],
+        phi3v_model_config_mm_interleaved,
+        phi3v_tokenizer,
+        content_format="string",
+    )
+
+    assert conversation == [{
+        "role":
+        "user",
+        "content":
+        "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n"  # noqa: E501
+        "Do they have differences?"
+    }]
+    _assert_mm_data_is_image_input(mm_data, 2)
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_multiple_images_interleave_async(
+    phi3v_model_config_mm_interleaved,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_data = parse_chat_messages_futures(
+        [{
+            "role":
+            "user",
+            "content": [{
+                "type": "text",
+                "text": "I need you to compare this image"
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "and this one"
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "Do they have differences?"
+            }]
+        }],
+        phi3v_model_config_mm_interleaved,
+        phi3v_tokenizer,
+        content_format="string",
+    )
+
+    assert conversation == [{
+        "role":
+        "user",
+        "content":
+        "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n"  # noqa: E501
+        "Do they have differences?"
+    }]
+    _assert_mm_data_is_image_input(await mm_data, 2)
+
+
+def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
+    phi3v_model_config_mm_interleaved,
+    phi3v_tokenizer,
+    image_url,
+):
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What's on this image?"
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url
+                    }
+                },
+                {
+                    "type": "text",
+                    "text": "Be accurate."
+                },
+            ]
+        }, {
+            "role": "assistant",
+            "content": "Some stuff."
+        }, {
+            "role":
+            "user",
+            "content": [{
+                "type": "text",
+                "text": "What's on this image?"
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }]
+        }],
+        phi3v_model_config_mm_interleaved,
+        phi3v_tokenizer,
+        content_format="string",
+    )
+
+    assert conversation == [{
+        "role":
+        "user",
+        "content":
+        "What's on this image?\n<|image_1|>\nBe accurate."
+    }, {
+        "role": "assistant",
+        "content": "Some stuff."
+    }, {
+        "role": "user",
+        "content": "What's on this image?\n<|image_2|>"
+    }]
+    _assert_mm_data_is_image_input(mm_data, 2)
+
+
+def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
+        qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer,
+        image_url, video_url, audio_url):
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What's on this image?"
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url
+                    }
+                },
+                {
+                    "type": "text",
+                    "text": "Now listen to this audio"
+                },
+                {
+                    "type": "audio_url",
+                    "audio_url": {
+                        "url": audio_url
+                    }
+                },
+            ]
+        }, {
+            "role": "assistant",
+            "content": "Some stuff."
+        }, {
+            "role":
+            "user",
+            "content": [{
+                "type": "text",
+                "text": "What's on this image?"
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "And what's in the video?"
+            }, {
+                "type": "video_url",
+                "video_url": {
+                    "url": video_url
+                }
+            }]
+        }],
+        qwen25omni_model_config_mm_interleaved,
+        qwen25omni_tokenizer,
+        content_format="string",
+    )
+
+    assert conversation == [{
+        "role":
+        "user",
+        "content":
+        "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
+        "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>"
+    }, {
+        "role": "assistant",
+        "content": "Some stuff."
+    }, {
+        "role":
+        "user",
+        "content":
+        "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
+        "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>"
+    }]
+
+    _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1})
+
+
+def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
+    phi3v_model_config_mm_interleaved,
+    phi3v_tokenizer,
+    image_url,
+):
+    with pytest.raises(
+            ValueError,
+            match=r"Found more '<|image_1|>' placeholders in input prompt "
+            "than actual multimodal data items."):
+        parse_chat_messages(
+            [{
+                "role":
+                "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": image_url
+                        }
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": image_url
+                        }
+                    },
+                    {
+                        "type":
+                        "text",
+                        "text":
+                        "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n"  # noqa: E501
+                        "Do they have differences?"
+                    },
+                ]
+            }],
+            phi3v_model_config_mm_interleaved,
+            phi3v_tokenizer,
+            content_format="string",
+        )
+
+
 ### Mllama currently wraps images / texts as interleaved dictionaries
 def test_mllama_single_image(
     mllama_model_config,
@@ -346,6 +346,9 @@ class ModelConfig:
     limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
     """Maximum number of data items per modality per prompt. Only applicable
     for multimodal models."""
+    interleave_mm_strings: bool = False
+    """Enable fully interleaved support for multimodal prompts, while using
+    --chat-template-content-format=string. Defaults to False."""
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
     """Additional args passed to process media inputs, keyed by modalities.
     For example, to set num_frames for video, set
@@ -702,7 +705,8 @@ class ModelConfig:
                 media_io_kwargs=self.media_io_kwargs,
                 mm_processor_kwargs=self.mm_processor_kwargs,
                 disable_mm_preprocessor_cache=self.
-                disable_mm_preprocessor_cache)
+                disable_mm_preprocessor_cache,
+                interleave_mm_strings=self.interleave_mm_strings)
 
         if self.limit_mm_per_prompt:
             raise ValueError("`limit_mm_per_prompt` is only supported for "
@@ -713,6 +717,9 @@ class ModelConfig:
         if self.disable_mm_preprocessor_cache:
             raise ValueError("`disable_mm_preprocessor_cache` is only "
                              "supported for multimodal models.")
+        if self.interleave_mm_strings:
+            raise ValueError("`interleave_mm_strings` is only "
+                             "supported for multimodal models.")
 
         return None
 
@@ -3126,6 +3133,11 @@ class MultiModalConfig:
     If `True`, disable caching of the processed multi-modal inputs.
     """
 
+    interleave_mm_strings: bool = False
+    """
+    Enable fully interleaved support for multimodal prompts.
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
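
As the ModelConfig and MultiModalConfig hunks above show, the new flag defaults to False, is forwarded into MultiModalConfig, and is rejected for models without multimodal support. A small, hedged sketch of that guard; the text-only model ID is just an example (it is the HERMES_MODEL_ID constant used by the tests), and the exact call site of the check is the multimodal-config initialisation shown above:

from vllm.config import ModelConfig

# Presumably raises ValueError("`interleave_mm_strings` is only "
#                              "supported for multimodal models.")
ModelConfig("NousResearch/Hermes-3-Llama-3.1-8B",
            task="generate",
            tokenizer="NousResearch/Hermes-3-Llama-3.1-8B",
            tokenizer_mode="auto",
            dtype="auto",
            seed=0,
            interleave_mm_strings=True)
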
@@ -370,6 +370,7 @@ class EngineArgs:
         get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: dict[str, int] = \
         get_field(MultiModalConfig, "limit_per_prompt")
+    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
     media_io_kwargs: dict[str, dict[str,
                                     Any]] = get_field(MultiModalConfig,
                                                       "media_io_kwargs")
@@ -763,6 +764,9 @@ class EngineArgs:
         multimodal_group.add_argument(
             "--disable-mm-preprocessor-cache",
             **multimodal_kwargs["disable_mm_preprocessor_cache"])
+        multimodal_group.add_argument(
+            "--interleave-mm-strings",
+            **multimodal_kwargs["interleave_mm_strings"])
 
         # LoRA related configs
         lora_kwargs = get_kwargs(LoRAConfig)
@@ -981,6 +985,7 @@ class EngineArgs:
             enable_prompt_embeds=self.enable_prompt_embeds,
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
+            interleave_mm_strings=self.interleave_mm_strings,
             media_io_kwargs=self.media_io_kwargs,
             use_async_output_proc=not self.disable_async_output_proc,
             config_format=self.config_format,
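
The EngineArgs changes above expose the same switch as an engine argument and as a new --interleave-mm-strings CLI flag, registered next to --disable-mm-preprocessor-cache. A brief sketch, assuming the usual plumbing that forwards engine arguments into ModelConfig; the Qwen2.5-Omni model ID and per-prompt limits mirror the test fixture:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="Qwen/Qwen2.5-Omni-7B",
    interleave_mm_strings=True,
    limit_mm_per_prompt={"image": 2, "audio": 1, "video": 1},
)
# On the serving side the equivalent would be the new command-line flag, e.g.
#   vllm serve Qwen/Qwen2.5-Omni-7B --interleave-mm-strings \
#       --chat-template-content-format string
# (--chat-template-content-format already exists; only --interleave-mm-strings
# is added by this commit).
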
@@ -4,7 +4,7 @@
 import asyncio
 import json
 from abc import ABC, abstractmethod
-from collections import defaultdict, deque
+from collections import Counter, defaultdict, deque
 from collections.abc import Awaitable, Iterable
 from functools import cached_property, lru_cache, partial
 from pathlib import Path
@@ -52,6 +52,12 @@ from vllm.utils import deprecate_kwargs, random_uuid
 
 logger = init_logger(__name__)
 
+MODALITY_PLACEHOLDERS_MAP = {
+    "image": "<##IMAGE##>",
+    "audio": "<##AUDIO##>",
+    "video": "<##VIDEO##>",
+}
+
 
 class AudioURL(TypedDict, total=False):
     url: Required[str]
@@ -354,6 +360,7 @@ def resolve_mistral_chat_template(
                        "so it will be ignored.")
     return None
 
+
 @deprecate_kwargs(
     "trust_remote_code",
     additional_message="Please use `model_config.trust_remote_code` instead.",
@@ -633,15 +640,22 @@ class BaseMultiModalContentParser(ABC):
     def __init__(self) -> None:
         super().__init__()
 
-        # multimodal placeholder_string : count
-        self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0)
+        # Stores the model-specific placeholders for each generic
+        # multimodal placeholder, e.g.:
+        # {
+        #     "<##IMAGE##>": ["<image>", "<image>", "<image>"],
+        #     "<##AUDIO##>": ["<audio>", "<audio>"]
+        # }
+        self._placeholder_storage: dict[str, list] = defaultdict(list)
 
-    def _add_placeholder(self, placeholder: Optional[str]):
+    def _add_placeholder(self, modality: ModalityStr,
+                         placeholder: Optional[str]):
+        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
         if placeholder:
-            self._placeholder_counts[placeholder] += 1
+            self._placeholder_storage[mod_placeholder].append(placeholder)
 
-    def mm_placeholder_counts(self) -> dict[str, int]:
-        return dict(self._placeholder_counts)
+    def mm_placeholder_storage(self) -> dict[str, list]:
+        return dict(self._placeholder_storage)
 
     @abstractmethod
     def parse_image(self, image_url: str) -> None:
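
With this change the content parsers no longer keep a flat placeholder count; each parse_* call records the model-specific placeholder under the generic sentinel for its modality. A tiny illustrative snippet of the bookkeeping for the two-image phi3v request from the tests (the placeholder strings are the ones asserted in those tests):

from collections import defaultdict

# What BaseMultiModalContentParser now accumulates; real entries come from
# the tracker.add(...) calls inside the parse_* methods.
storage = defaultdict(list)
storage["<##IMAGE##>"].append("<|image_1|>")  # first image_url part
storage["<##IMAGE##>"].append("<|image_2|>")  # second image_url part
assert dict(storage) == {"<##IMAGE##>": ["<|image_1|>", "<|image_2|>"]}
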
@@ -685,7 +699,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
         image = self._connector.fetch_image(image_url)
 
         placeholder = self._tracker.add("image", image)
-        self._add_placeholder(placeholder)
+        self._add_placeholder("image", placeholder)
 
     def parse_image_embeds(self,
                            image_embeds: Union[str, dict[str, str]]) -> None:
@@ -700,17 +714,17 @@ class MultiModalContentParser(BaseMultiModalContentParser):
             embedding = self._connector.fetch_image_embedding(image_embeds)
             placeholder = self._tracker.add("image_embeds", embedding)
 
-        self._add_placeholder(placeholder)
+        self._add_placeholder("image", placeholder)
 
     def parse_image_pil(self, image_pil: Image.Image) -> None:
         placeholder = self._tracker.add("image", image_pil)
-        self._add_placeholder(placeholder)
+        self._add_placeholder("image", placeholder)
 
     def parse_audio(self, audio_url: str) -> None:
         audio = self._connector.fetch_audio(audio_url)
 
         placeholder = self._tracker.add("audio", audio)
-        self._add_placeholder(placeholder)
+        self._add_placeholder("audio", placeholder)
 
     def parse_input_audio(self, input_audio: InputAudio) -> None:
         audio_data = input_audio.get("data", "")
@@ -723,7 +737,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
         video = self._connector.fetch_video(video_url=video_url)
 
         placeholder = self._tracker.add("video", video)
-        self._add_placeholder(placeholder)
+        self._add_placeholder("video", placeholder)
 
 
 class AsyncMultiModalContentParser(BaseMultiModalContentParser):
@@ -741,7 +755,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
         image_coro = self._connector.fetch_image_async(image_url)
 
         placeholder = self._tracker.add("image", image_coro)
-        self._add_placeholder(placeholder)
+        self._add_placeholder("image", placeholder)
 
     def parse_image_embeds(self,
                            image_embeds: Union[str, dict[str, str]]) -> None:
@@ -760,20 +774,20 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
             future.set_result(embedding)
 
         placeholder = self._tracker.add("image_embeds", future)
-        self._add_placeholder(placeholder)
+        self._add_placeholder("image", placeholder)
 
     def parse_image_pil(self, image_pil: Image.Image) -> None:
         future: asyncio.Future[Image.Image] = asyncio.Future()
         future.set_result(image_pil)
 
         placeholder = self._tracker.add("image", future)
-        self._add_placeholder(placeholder)
+        self._add_placeholder("image", placeholder)
 
     def parse_audio(self, audio_url: str) -> None:
         audio_coro = self._connector.fetch_audio_async(audio_url)
 
         placeholder = self._tracker.add("audio", audio_coro)
-        self._add_placeholder(placeholder)
+        self._add_placeholder("audio", placeholder)
 
     def parse_input_audio(self, input_audio: InputAudio) -> None:
         audio_data = input_audio.get("data", "")
@@ -786,7 +800,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
         video = self._connector.fetch_video_async(video_url=video_url)
 
         placeholder = self._tracker.add("video", video)
-        self._add_placeholder(placeholder)
+        self._add_placeholder("video", placeholder)
 
 
 def validate_chat_template(chat_template: Optional[Union[Path, str]]):
@@ -856,12 +870,40 @@ def load_chat_template(
     return _cached_load_chat_template(chat_template, is_literal=is_literal)
 
 
+def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
+                                 texts: list[str]) -> str:
+    for idx, elem in enumerate(texts):
+        if elem in placeholder_storage:
+            texts[idx] = placeholder_storage[elem].pop(0)
+
+    return "\n".join(texts)
+
+
 # TODO: Let user specify how to insert multimodal tokens into prompt
 # (similar to chat template)
-def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
-                                     text_prompt: str) -> str:
+def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
+                                     texts: list[str],
+                                     interleave_strings: bool
+                                     ) -> str:
     """Combine multimodal prompts for a multimodal language model."""
 
+    # Flatten the storage so that it looks like
+    # {
+    #     "<|image|>": 2,
+    #     "<|audio|>": 1
+    # }
+    placeholder_counts = Counter(
+        [v for elem in placeholder_storage.values() for v in elem]
+    )
+
+    if interleave_strings:
+        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
+    else:
+        text_prompt = "\n".join(texts)
+
+    # Pass the interleaved text on even if the user placed the placeholders
+    # manually but forgot to disable the 'interleave_strings' flag.
+
     # Look through the text prompt to check for missing placeholders
     missing_placeholders: list[str] = []
     for placeholder in placeholder_counts:
@@ -870,6 +912,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
         placeholder_counts[placeholder] -= text_prompt.count(placeholder)
 
         if placeholder_counts[placeholder] < 0:
+            logger.error(
+                "Placeholder count is negative! "
+                "Ensure that the 'interleave_strings' flag is disabled "
+                "(current value: %s) "
+                "when manually placing image placeholders.", interleave_strings
+            )
+            logger.debug("Input prompt: %s", text_prompt)
             raise ValueError(
                 f"Found more '{placeholder}' placeholders in input prompt than "
                 "actual multimodal data items.")
@@ -877,8 +926,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
         missing_placeholders.extend([placeholder] *
                                     placeholder_counts[placeholder])
 
-    # NOTE: For now we always add missing placeholders at the front of
-    # the prompt. This may change to be customizable in the future.
+    # NOTE: Default behaviour: we always add missing placeholders
+    # at the front of the prompt, if interleave_strings=False
     return "\n".join(missing_placeholders + [text_prompt])
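
A short, hedged trace of the two helpers above, using the placeholder values from the phi3v test case; both functions are module-private, so this only illustrates the intended data flow:

from vllm.entrypoints.chat_utils import _get_full_multimodal_text_prompt

storage = {"<##IMAGE##>": ["<|image_1|>", "<|image_2|>"]}
texts = [
    "I need you to compare this image", "<##IMAGE##>",
    "and this one", "<##IMAGE##>",
    "Do they have differences?",
]

# interleave_strings=True: the generic sentinels are swapped in place.
assert _get_full_multimodal_text_prompt(storage, list(texts), True) == (
    "I need you to compare this image\n<|image_1|>\nand this one\n"
    "<|image_2|>\nDo they have differences?")

# interleave_strings=False (legacy path): the parser would not have inserted
# sentinels into `texts`, and unmatched placeholders are prepended instead.
legacy_texts = ["I need you to compare this image", "and this one",
                "Do they have differences?"]
assert _get_full_multimodal_text_prompt(
    {"<##IMAGE##>": ["<|image_1|>", "<|image_2|>"]}, legacy_texts, False) == (
    "<|image_1|>\n<|image_2|>\nI need you to compare this image\n"
    "and this one\nDo they have differences?")
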
@@ -988,6 +1037,7 @@ def _parse_chat_message_content_parts(
     mm_tracker: BaseMultiModalItemTracker,
     *,
     wrap_dicts: bool,
+    interleave_strings: bool,
 ) -> list[ConversationMessage]:
     content = list[_ContentPart]()
 
@@ -998,6 +1048,7 @@
             part,
             mm_parser,
             wrap_dicts=wrap_dicts,
+            interleave_strings=interleave_strings
         )
         if parse_res:
             content.append(parse_res)
@@ -1007,11 +1058,14 @@
         return [ConversationMessage(role=role,
                                     content=content)]  # type: ignore
     texts = cast(list[str], content)
-    text_prompt = "\n".join(texts)
-    mm_placeholder_counts = mm_parser.mm_placeholder_counts()
-    if mm_placeholder_counts:
-        text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
-                                                       text_prompt)
+    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
+    if mm_placeholder_storage:
+        text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage,
+                                                       texts,
+                                                       interleave_strings)
+    else:
+        text_prompt = "\n".join(texts)
 
     return [ConversationMessage(role=role, content=text_prompt)]
 
@@ -1020,6 +1074,7 @@ def _parse_chat_message_content_part(
     mm_parser: BaseMultiModalContentParser,
     *,
     wrap_dicts: bool,
+    interleave_strings: bool,
 ) -> Optional[_ContentPart]:
     """Parses a single part of a conversation. If wrap_dicts is True,
     structured dictionary pieces for texts and images will be
@@ -1049,34 +1104,37 @@
         else:
             return str_content
 
+    modality = None
     if part_type == "image_pil":
         image_content = cast(Image.Image, content)
         mm_parser.parse_image_pil(image_content)
-        return {'type': 'image'} if wrap_dicts else None
-    if part_type == "image_url":
+        modality = "image"
+    elif part_type == "image_url":
         str_content = cast(str, content)
         mm_parser.parse_image(str_content)
-        return {'type': 'image'} if wrap_dicts else None
-    if part_type == "image_embeds":
+        modality = "image"
+    elif part_type == "image_embeds":
         content = cast(Union[str, dict[str, str]], content)
         mm_parser.parse_image_embeds(content)
-        return {'type': 'image'} if wrap_dicts else None
-    if part_type == "audio_url":
+        modality = "image"
+    elif part_type == "audio_url":
         str_content = cast(str, content)
         mm_parser.parse_audio(str_content)
-        return {'type': 'audio'} if wrap_dicts else None
-
-    if part_type == "input_audio":
+        modality = "audio"
+    elif part_type == "input_audio":
         dict_content = cast(InputAudio, content)
         mm_parser.parse_input_audio(dict_content)
-        return {'type': 'audio'} if wrap_dicts else None
-
-    if part_type == "video_url":
+        modality = "audio"
+    elif part_type == "video_url":
         str_content = cast(str, content)
         mm_parser.parse_video(str_content)
-        return {'type': 'video'} if wrap_dicts else None
+        modality = "video"
+    else:
+        raise NotImplementedError(f"Unknown part type: {part_type}")
 
-    raise NotImplementedError(f"Unknown part type: {part_type}")
+    return {'type': modality} if wrap_dicts else (
+        MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
+    )
 
 
 # No need to validate using Pydantic again
@@ -1088,6 +1146,7 @@ def _parse_chat_message_content(
     message: ChatCompletionMessageParam,
     mm_tracker: BaseMultiModalItemTracker,
     content_format: _ChatTemplateContentFormat,
+    interleave_strings: bool,
 ) -> list[ConversationMessage]:
     role = message["role"]
     content = message.get("content")
@@ -1103,6 +1162,7 @@
             content,  # type: ignore
             mm_tracker,
             wrap_dicts=(content_format == "openai"),
+            interleave_strings=interleave_strings,
         )
 
     for result_msg in result:
@@ -1155,6 +1215,11 @@ def parse_chat_messages(
             msg,
             mm_tracker,
             content_format,
+            interleave_strings=(
+                content_format == "string"
+                and model_config.multimodal_config is not None
+                and model_config.multimodal_config.interleave_mm_strings
+            )
         )
 
         conversation.extend(sub_messages)
@@ -1178,6 +1243,11 @@ def parse_chat_messages_futures(
             msg,
             mm_tracker,
             content_format,
+            interleave_strings=(
+                content_format == "string"
+                and model_config.multimodal_config is not None
+                and model_config.multimodal_config.interleave_mm_strings
+            )
         )
 
         conversation.extend(sub_messages)