[Frontend][VLM] Add support for multiple multi-modal items (#8049)

Roger Wang 2024-08-31 16:35:53 -07:00 committed by GitHub
parent 8423aef4c8
commit 5231f0898e
8 changed files with 524 additions and 136 deletions


@@ -90,6 +90,7 @@ steps:
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/openai
- pytest -v -s entrypoints/test_chat_utils.py
- label: Distributed Tests (4 GPUs) # 10min
  working_dir: "/vllm-workspace/tests"


@@ -1,7 +1,13 @@
"""An example showing how to use vLLM to serve VLMs.
Launch the vLLM server with the following command:
(single image inference with Llava)
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
(multi-image inference with Phi-3.5-vision-instruct)
vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
--trust-remote-code --limit-mm-per-prompt image=2
""" """
import base64 import base64
@@ -84,3 +90,36 @@ chat_completion_from_base64 = client.chat.completions.create(
result = chat_completion_from_base64.choices[0].message.content
print(f"Chat completion output:{result}")
# Multi-image input inference
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"

chat_completion_from_url = client.chat.completions.create(
    messages=[{
        "role":
        "user",
        "content": [
            {
                "type": "text",
                "text": "What are the animals in these images?"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url_duck
                },
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url_lion
                },
            },
        ],
    }],
    model=model,
    max_tokens=64,
)

result = chat_completion_from_url.choices[0].message.content
print(f"Chat completion output:{result}")


@@ -3,6 +3,7 @@ from contextlib import suppress
from dataclasses import dataclass
from unittest.mock import MagicMock

from vllm.config import MultiModalConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -20,6 +21,7 @@ class MockModelConfig:
    max_model_len = 100
    tokenizer_revision = None
    embedding_mode = False
    multimodal_config = MultiModalConfig()

@dataclass


@@ -6,11 +6,10 @@ import pytest_asyncio
from vllm.multimodal.utils import encode_image_base64, fetch_image

-from ...utils import VLLM_PATH, RemoteOpenAIServer
from ...utils import RemoteOpenAIServer

-MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
-LLAVA_CHAT_TEMPLATE = VLLM_PATH / "examples/template_llava.jinja"
-assert LLAVA_CHAT_TEMPLATE.exists()
MODEL_NAME = "microsoft/Phi-3.5-vision-instruct"
MAXIMUM_IMAGES = 2

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
@@ -24,13 +23,9 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module")
def server():
    args = [
-        "--dtype",
-        "bfloat16",
-        "--max-model-len",
-        "4096",
-        "--enforce-eager",
-        "--chat-template",
-        str(LLAVA_CHAT_TEMPLATE),
        "--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs",
        "5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt",
        f"image={MAXIMUM_IMAGES}"
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -84,7 +79,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=596, total_tokens=606)
        completion_tokens=10, prompt_tokens=772, total_tokens=782)

    message = choice.message
    message = chat_completion.choices[0].message
@@ -139,7 +134,7 @@ async def test_single_chat_session_image_base64encoded(
    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=596, total_tokens=606)
        completion_tokens=10, prompt_tokens=772, total_tokens=782)

    message = choice.message
    message = chat_completion.choices[0].message
@@ -217,26 +212,22 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize(
    "image_urls",
    [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))])
async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
-                                 image_url: str):
                                 image_urls: List[str]):
    messages = [{
        "role":
        "user",
        "content": [
-            {
-                "type": "image_url",
-                "image_url": {
-                    "url": image_url
-                }
-            },
-            {
-                "type": "image_url",
-                "image_url": {
-                    "url": image_url
-                }
-            },
            *({
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            } for image_url in image_urls),
            {
                "type": "text",
                "text": "What's in this image?"
@@ -244,20 +235,30 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
        ],
    }]

-    with pytest.raises(openai.BadRequestError):  # test multi-image input
-        await client.chat.completions.create(
-            model=model_name,
-            messages=messages,
-            max_tokens=10,
-            temperature=0.0,
-        )
-    # the server should still work afterwards
-    completion = await client.completions.create(
-        model=model_name,
-        prompt=[0, 0, 0, 0, 0],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    completion = completion.choices[0].text
-    assert completion is not None and len(completion) >= 0
    if len(image_urls) > MAXIMUM_IMAGES:
        with pytest.raises(openai.BadRequestError):  # test multi-image input
            await client.chat.completions.create(
                model=model_name,
                messages=messages,
                max_tokens=10,
                temperature=0.0,
            )
        # the server should still work afterwards
        completion = await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
        )
        completion = completion.choices[0].text
        assert completion is not None and len(completion) >= 0
    else:
        chat_completion = await client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_tokens=10,
            temperature=0.0,
        )
        message = chat_completion.choices[0].message
        assert message.content is not None and len(message.content) >= 0


@@ -0,0 +1,305 @@
import warnings
import pytest
from PIL import Image
from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import parse_chat_messages
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
@pytest.fixture(scope="module")
def phi3v_model_config():
return ModelConfig(PHI3V_MODEL_ID,
PHI3V_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
seed=0,
limit_mm_per_prompt={
"image": 2,
})
@pytest.fixture(scope="module")
def phi3v_tokenizer():
return TokenizerGroup(
tokenizer_id=PHI3V_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
@pytest.fixture(scope="module")
def image_url():
image = ImageAsset('cherry_blossom')
base64 = encode_image_base64(image.pil_image)
return f"data:image/jpeg;base64,{base64}"
@pytest.mark.asyncio
async def test_parse_chat_messages_with_image_url(phi3v_model_config,
phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in the image?"
}]
}], phi3v_model_config, phi3v_tokenizer)
assert conversation == [{
"role": "user",
"content": "<|image_1|>\nWhat's in the image?"
}]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert isinstance(mm_data["image"], Image.Image)
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images(phi3v_model_config,
phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in these images?"
}]
}], phi3v_model_config, phi3v_tokenizer)
assert conversation == [{
"role":
"user",
"content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert len(mm_data["image"]) == 2
@pytest.mark.asyncio
async def test_parse_chat_messages_placeholder_already_in_prompt(
phi3v_model_config, phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type":
"text",
"text":
"What's in <|image_1|> and how does it compare to <|image_2|>?"
}]
}], phi3v_model_config, phi3v_tokenizer)
assert conversation == [{
"role":
"user",
"content":
"What's in <|image_1|> and how does it compare to <|image_2|>?"
}]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert len(mm_data["image"]) == 2
@pytest.mark.asyncio
async def test_parse_chat_messages_placeholder_one_already_in_prompt(
phi3v_model_config, phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type":
"text",
"text":
"What's in <|image_1|> and how does it compare to the other one?"
}]
}], phi3v_model_config, phi3v_tokenizer)
assert conversation == [{
"role":
"user",
"content":
"<|image_2|>\nWhat's in <|image_1|> and how does it compare to the "
"other one?"
}]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert len(mm_data["image"]) == 2
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_across_messages(
phi3v_model_config, phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in this image?"
}]
}, {
"role": "assistant",
"content": "Some stuff."
}, {
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What about this one?"
}]
}], phi3v_model_config, phi3v_tokenizer)
assert conversation == [
{
"role": "user",
"content": "<|image_1|>\nWhat's in this image?"
},
{
"role": "assistant",
"content": "Some stuff."
},
{
"role": "user",
"content": "<|image_2|>\nWhat about this one?"
},
]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert len(mm_data["image"]) == 2
@pytest.mark.asyncio
async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
phi3v_model_config, phi3v_tokenizer, image_url):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="coroutine 'async_get_and_parse_image' was never awaited")
with pytest.raises(
ValueError,
match="At most 2 image\\(s\\) may be provided in one request\\."
):
parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in these images?"
}]
}], phi3v_model_config, phi3v_tokenizer)
@pytest.mark.asyncio
async def test_parse_chat_messages_rejects_too_many_images_across_messages(
phi3v_model_config, phi3v_tokenizer, image_url):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="coroutine 'async_get_and_parse_image' was never awaited")
with pytest.raises(
ValueError,
match="At most 2 image\\(s\\) may be provided in one request\\."
):
parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in this image?"
}]
}, {
"role": "assistant",
"content": "Some stuff."
}, {
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What about these two?"
}]
}], phi3v_model_config, phi3v_tokenizer)


@@ -1,9 +1,10 @@
import asyncio
import codecs
-from dataclasses import dataclass
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
-from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
-                    Union)
from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
                    Optional, Tuple, Union)

# yapf conflicts with isort for this block
# yapf: disable
@@ -80,10 +81,90 @@ class ConversationMessage(TypedDict):
    content: str

-@dataclass(frozen=True)
-class ChatMessageParseResult:
-    messages: List[ConversationMessage]
-    mm_futures: List[Awaitable[MultiModalDataDict]]
class MultiModalItemTracker:
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """
    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        self._model_config = model_config
        self._tokenizer = tokenizer
        self._allowed_items = (model_config.multimodal_config.limit_per_prompt
                               if model_config.multimodal_config else {})
        self._consumed_items = {k: 0 for k in self._allowed_items}
        self._futures: List[Awaitable[MultiModalDataDict]] = []

    @staticmethod
    @lru_cache(maxsize=None)
    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
        return tokenizer.decode(token_index)

    def add(self, modality: Literal["image", "audio"],
            mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
        """
        Adds the multi-modal item to the current prompt and returns the
        placeholder string to use, if any.
        """
        allowed_count = self._allowed_items.get(modality, 1)
        current_count = self._consumed_items.get(modality, 0) + 1
        if current_count > allowed_count:
            raise ValueError(
                f"At most {allowed_count} {modality}(s) may be provided in "
                "one request.")

        self._consumed_items[modality] = current_count
        self._futures.append(mm_future)

        # TODO: Let user specify how to insert image tokens into prompt
        # (similar to chat template)
        model_type = self._model_config.hf_config.model_type
        if modality == "image":
            if model_type == "phi3_v":
                # Workaround since this token is not defined in the tokenizer
                return f"<|image_{current_count}|>"
            if model_type == "minicpmv":
                return "(<image>./</image>)"
            if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
                # These models do not use image tokens in the prompt
                return None
            if model_type.startswith("llava"):
                return MultiModalItemTracker._cached_token_str(
                    self._tokenizer,
                    self._model_config.hf_config.image_token_index)
            if model_type in ("chameleon", "internvl_chat"):
                return "<image>"
            raise TypeError(f"Unknown model type: {model_type}")
        elif modality == "audio":
            if model_type == "ultravox":
                return "<|reserved_special_token_0|>"
            raise TypeError(f"Unknown model type: {model_type}")
        else:
            raise TypeError(f"Unknown modality: {modality}")

    @staticmethod
    async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
        mm_lists: Mapping[str, List[object]] = defaultdict(list)

        # Merge all the multi-modal items
        for single_mm_data in (await asyncio.gather(*futures)):
            for mm_key, mm_item in single_mm_data.items():
                if isinstance(mm_item, list):
                    mm_lists[mm_key].extend(mm_item)
                else:
                    mm_lists[mm_key].append(mm_item)

        # Unpack any single item lists for models that don't expect multiple.
        return {
            mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
            for mm_key, mm_list in mm_lists.items()
        }

    def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
        return MultiModalItemTracker._combine(
            self._futures) if self._futures else None
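A minimal usage sketch of the tracker above (not part of the diff; the helper name and prompt text are hypothetical, and an already-built model_config and tokenizer are assumed). It mirrors how the parsing code further down drives the tracker: one add() per item, then a single await of the combined data.

# Hypothetical helper, for illustration only.
async def _example_multi_image_prompt(model_config, tokenizer, image_urls):
    tracker = MultiModalItemTracker(model_config, tokenizer)
    placeholders = []
    for url in image_urls:  # assumed non-empty
        # add() registers the fetch coroutine and returns the placeholder string,
        # e.g. "<|image_1|>", "<|image_2|>" for phi3_v, or None for models without one.
        placeholder = tracker.add("image", async_get_and_parse_image(url))
        if placeholder:
            placeholders.append(placeholder)
    prompt = "\n".join(placeholders + ["What's in these images?"])
    # all_mm_data() gathers every registered coroutine into a single MultiModalDataDict,
    # e.g. {"image": [<PIL.Image>, <PIL.Image>]} for two images.
    mm_data = await tracker.all_mm_data()
    return prompt, mm_data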
def load_chat_template(
@@ -112,44 +193,30 @@ def load_chat_template(
    return resolved_chat_template
-@lru_cache(maxsize=None)
-def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
-                  modality: Literal["image", "audio"]) -> Optional[str]:
-    # TODO: Let user specify how to insert image tokens into prompt
-    # (similar to chat template)
-    model_type = model_config.hf_config.model_type
-    if modality == "image":
-        if model_type == "phi3_v":
-            # Workaround since this token is not defined in the tokenizer
-            return "<|image_1|>"
-        if model_type == "minicpmv":
-            return "(<image>./</image>)"
-        if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
-            # These models do not use image tokens in the prompt
-            return None
-        if model_type.startswith("llava"):
-            return tokenizer.decode(model_config.hf_config.image_token_index)
-        if model_type in ("chameleon", "internvl_chat"):
-            return "<image>"
-        raise TypeError(f"Unknown model type: {model_type}")
-    elif modality == "audio":
-        if model_type == "ultravox":
-            return "<|reserved_special_token_0|>"
-        raise TypeError(f"Unknown model type: {model_type}")
-    else:
-        raise TypeError(f"Unknown modality: {modality}")
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
-def _get_full_multimodal_text_prompt(placeholder_token_str: str,
-                                     text_prompt: str) -> str:
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
                                     text_prompt: str) -> str:
    """Combine multimodal prompts for a multimodal language model"""
-    # NOTE: For now we assume all model architectures use the same
-    # placeholder + text prompt format. This may change in the future.
-    return f"{placeholder_token_str}\n{text_prompt}"

    # Look through the text prompt to check for missing placeholders
    missing_placeholders = []
    for placeholder in placeholder_counts:

        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
                "actual multimodal data items.")

        missing_placeholders.extend([placeholder] *
                                    placeholder_counts[placeholder])

    # NOTE: For now we always add missing placeholders at the front of
    # the prompt. This may change to be customizable in the future.
    return "\n".join(missing_placeholders + [text_prompt])
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
@@ -160,12 +227,12 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
-    model_config: ModelConfig,
-    tokenizer: AnyTokenizer,
-) -> ChatMessageParseResult:
    mm_tracker: MultiModalItemTracker,
) -> List[ConversationMessage]:
    texts: List[str] = []
-    mm_futures: List[Awaitable[MultiModalDataDict]] = []
-    modality: Literal["image", "audio"] = "image"
    # multimodal placeholder_string : count
    mm_placeholder_counts: Dict[str, int] = {}

    for part in parts:
        part_type = part["type"]
@@ -173,11 +240,6 @@ def _parse_chat_message_content_parts(
            text = _TextParser.validate_python(part)["text"]
            texts.append(text)
        elif part_type == "image_url":
-            modality = "image"
-            if len(mm_futures) > 0:
-                raise NotImplementedError(
-                    "Multiple multimodal inputs is currently not supported.")
            image_url = _ImageParser.validate_python(part)["image_url"]

            if image_url.get("detail", "auto") != "auto":
@@ -185,60 +247,44 @@ def _parse_chat_message_content_parts(
                    "'image_url.detail' is currently not supported and "
                    "will be ignored.")

-            image_future = async_get_and_parse_image(image_url["url"])
-            mm_futures.append(image_future)
            image_coro = async_get_and_parse_image(image_url["url"])
            placeholder = mm_tracker.add("image", image_coro)
            if placeholder:
                mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
                    placeholder, 0) + 1
elif part_type == "audio_url": elif part_type == "audio_url":
modality = "audio"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
audio_url = _AudioParser.validate_python(part)["audio_url"] audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"]) audio_coro = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future) placeholder = mm_tracker.add("audio", audio_coro)
if placeholder:
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
placeholder, 0) + 1
        else:
            raise NotImplementedError(f"Unknown part type: {part_type}")

    text_prompt = "\n".join(texts)
-    if mm_futures:
-        placeholder_token_str = _mm_token_str(model_config, tokenizer,
-                                              modality)
-        if placeholder_token_str is not None:
-            if placeholder_token_str in text_prompt:
-                logger.warning(
-                    "Detected multi-modal token string in the text prompt. "
-                    "Skipping prompt formatting.")
-            else:
-                text_prompt = _get_full_multimodal_text_prompt(
-                    placeholder_token_str=placeholder_token_str,
-                    text_prompt=text_prompt,
-                )
-    messages = [ConversationMessage(role=role, content=text_prompt)]
-    return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
    if mm_placeholder_counts:
        text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
                                                       text_prompt)

    return [ConversationMessage(role=role, content=text_prompt)]
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
-    model_config: ModelConfig,
-    tokenizer: AnyTokenizer,
-) -> ChatMessageParseResult:
    mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
    role = message["role"]
    content = message.get("content")

    if content is None:
-        return ChatMessageParseResult(messages=[], mm_futures=[])
        return []
    if isinstance(content, str):
-        messages = [ConversationMessage(role=role, content=content)]
-        return ChatMessageParseResult(messages=messages, mm_futures=[])
        return [ConversationMessage(role=role, content=content)]

    return _parse_chat_message_content_parts(
        role,
        content,  # type: ignore
-        model_config,
-        tokenizer,
        mm_tracker,
    )
@@ -246,18 +292,16 @@ def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
-) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]:
    conversation: List[ConversationMessage] = []
-    mm_futures: List[Awaitable[MultiModalDataDict]] = []
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    for msg in messages:
-        parse_result = _parse_chat_message_content(msg, model_config,
-                                                   tokenizer)
-        conversation.extend(parse_result.messages)
-        mm_futures.extend(parse_result.mm_futures)
        sub_messages = _parse_chat_message_content(msg, mm_tracker)
        conversation.extend(sub_messages)

-    return conversation, mm_futures
    return conversation, mm_tracker.all_mm_data()

def apply_chat_template(
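For reference, a hedged sketch of the new calling convention (the wrapper name is hypothetical): parse_chat_messages now returns a single optional awaitable instead of a list of per-item futures, so callers such as OpenAIServingChat below await it once.

# Hypothetical wrapper, for illustration only.
async def _example_prepare_inputs(messages, model_config, tokenizer):
    conversation, mm_data_future = parse_chat_messages(messages, model_config, tokenizer)
    mm_data = None
    if mm_data_future:
        # Resolves to e.g. {"image": <PIL.Image>} for a single image, or
        # {"image": [<PIL.Image>, ...]} when several items were provided.
        mm_data = await mm_data_future
    return conversation, mm_data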


@@ -94,7 +94,7 @@ class OpenAIServingChat(OpenAIServing):
        tokenizer = await self.async_engine_client.get_tokenizer(
            lora_request)

-        conversation, mm_futures = parse_chat_messages(
        conversation, mm_data_future = parse_chat_messages(
            request.messages, model_config, tokenizer)

        tool_dicts = None if request.tools is None else [
@@ -116,12 +116,8 @@ class OpenAIServingChat(OpenAIServing):
        mm_data: Optional[MultiModalDataDict] = None
        try:
-            if len(mm_futures):
-                # since we support only single mm data currently
-                assert len(
-                    mm_futures
-                ) == 1, "Multiple 'image_url' input is currently not supported."
-                mm_data = await mm_futures[0]
            if mm_data_future:
                mm_data = await mm_data_future
        except Exception as e:
            logger.error("Error in loading multi-modal data: %s", e)
            return self.create_error_response(str(e))


@@ -65,10 +65,10 @@ class OpenAIServingTokenization(OpenAIServing):
        if isinstance(request, TokenizeChatRequest):
            model_config = self.model_config

-            conversation, mm_futures = parse_chat_messages(
            conversation, mm_data_future = parse_chat_messages(
                request.messages, model_config, tokenizer)

-            if mm_futures:
            if mm_data_future:
                logger.warning(
                    "Multi-modal inputs are ignored during tokenization")