mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:55:50 +08:00
[Frontend][VLM] Add support for multiple multi-modal items (#8049)
This commit is contained in:
parent
8423aef4c8
commit
5231f0898e
@ -90,6 +90,7 @@ steps:
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
|
||||
@ -1,7 +1,13 @@
|
||||
"""An example showing how to use vLLM to serve VLMs.
|
||||
|
||||
Launch the vLLM server with the following command:
|
||||
|
||||
(single image inference with Llava)
|
||||
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
|
||||
|
||||
(multi-image inference with Phi-3.5-vision-instruct)
|
||||
vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
|
||||
--trust-remote-code --limit-mm-per-prompt image=2
|
||||
"""
|
||||
import base64
|
||||
|
||||
@ -84,3 +90,36 @@ chat_completion_from_base64 = client.chat.completions.create(
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print(f"Chat completion output:{result}")
|
||||
|
||||
# Multi-image input inference
|
||||
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
|
||||
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are the animals in these images?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url_duck
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url_lion
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
model=model,
|
||||
max_tokens=64,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print(f"Chat completion output:{result}")
|
||||
|
||||
@ -3,6 +3,7 @@ from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
@ -20,6 +21,7 @@ class MockModelConfig:
|
||||
max_model_len = 100
|
||||
tokenizer_revision = None
|
||||
embedding_mode = False
|
||||
multimodal_config = MultiModalConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -6,11 +6,10 @@ import pytest_asyncio
|
||||
|
||||
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||
|
||||
from ...utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
LLAVA_CHAT_TEMPLATE = VLLM_PATH / "examples/template_llava.jinja"
|
||||
assert LLAVA_CHAT_TEMPLATE.exists()
|
||||
MODEL_NAME = "microsoft/Phi-3.5-vision-instruct"
|
||||
MAXIMUM_IMAGES = 2
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
TEST_IMAGE_URLS = [
|
||||
@ -24,13 +23,9 @@ TEST_IMAGE_URLS = [
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--enforce-eager",
|
||||
"--chat-template",
|
||||
str(LLAVA_CHAT_TEMPLATE),
|
||||
"--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs",
|
||||
"5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt",
|
||||
f"image={MAXIMUM_IMAGES}"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@ -84,7 +79,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=596, total_tokens=606)
|
||||
completion_tokens=10, prompt_tokens=772, total_tokens=782)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@ -139,7 +134,7 @@ async def test_single_chat_session_image_base64encoded(
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=596, total_tokens=606)
|
||||
completion_tokens=10, prompt_tokens=772, total_tokens=782)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@ -217,26 +212,22 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
@pytest.mark.parametrize(
|
||||
"image_urls",
|
||||
[TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))])
|
||||
async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
|
||||
image_url: str):
|
||||
image_urls: List[str]):
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
*({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
} for image_url in image_urls),
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
@ -244,20 +235,30 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
|
||||
],
|
||||
}]
|
||||
|
||||
with pytest.raises(openai.BadRequestError): # test multi-image input
|
||||
await client.chat.completions.create(
|
||||
if len(image_urls) > MAXIMUM_IMAGES:
|
||||
with pytest.raises(openai.BadRequestError): # test multi-image input
|
||||
await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# the server should still work afterwards
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
completion = completion.choices[0].text
|
||||
assert completion is not None and len(completion) >= 0
|
||||
else:
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# the server should still work afterwards
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
completion = completion.choices[0].text
|
||||
assert completion is not None and len(completion) >= 0
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
|
||||
305
tests/entrypoints/test_chat_utils.py
Normal file
305
tests/entrypoints/test_chat_utils.py
Normal file
@ -0,0 +1,305 @@
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import parse_chat_messages
|
||||
from vllm.multimodal.utils import encode_image_base64
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def phi3v_model_config():
|
||||
return ModelConfig(PHI3V_MODEL_ID,
|
||||
PHI3V_MODEL_ID,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
seed=0,
|
||||
limit_mm_per_prompt={
|
||||
"image": 2,
|
||||
})
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def phi3v_tokenizer():
|
||||
return TokenizerGroup(
|
||||
tokenizer_id=PHI3V_MODEL_ID,
|
||||
enable_lora=False,
|
||||
max_num_seqs=5,
|
||||
max_input_length=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def image_url():
|
||||
image = ImageAsset('cherry_blossom')
|
||||
base64 = encode_image_base64(image.pil_image)
|
||||
return f"data:image/jpeg;base64,{base64}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_with_image_url(phi3v_model_config,
|
||||
phi3v_tokenizer, image_url):
|
||||
conversation, mm_future = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in the image?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\nWhat's in the image?"
|
||||
}]
|
||||
mm_data = await mm_future
|
||||
assert set(mm_data.keys()) == {"image"}
|
||||
assert isinstance(mm_data["image"], Image.Image)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images(phi3v_model_config,
|
||||
phi3v_tokenizer, image_url):
|
||||
conversation, mm_future = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in these images?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
|
||||
}]
|
||||
mm_data = await mm_future
|
||||
assert set(mm_data.keys()) == {"image"}
|
||||
assert len(mm_data["image"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_placeholder_already_in_prompt(
|
||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
||||
conversation, mm_future = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type":
|
||||
"text",
|
||||
"text":
|
||||
"What's in <|image_1|> and how does it compare to <|image_2|>?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What's in <|image_1|> and how does it compare to <|image_2|>?"
|
||||
}]
|
||||
mm_data = await mm_future
|
||||
assert set(mm_data.keys()) == {"image"}
|
||||
assert len(mm_data["image"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_placeholder_one_already_in_prompt(
|
||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
||||
conversation, mm_future = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type":
|
||||
"text",
|
||||
"text":
|
||||
"What's in <|image_1|> and how does it compare to the other one?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"<|image_2|>\nWhat's in <|image_1|> and how does it compare to the "
|
||||
"other one?"
|
||||
}]
|
||||
mm_data = await mm_future
|
||||
assert set(mm_data.keys()) == {"image"}
|
||||
assert len(mm_data["image"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_across_messages(
|
||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
||||
conversation, mm_future = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
}]
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "Some stuff."
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What about this one?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\nWhat's in this image?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Some stuff."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_2|>\nWhat about this one?"
|
||||
},
|
||||
]
|
||||
mm_data = await mm_future
|
||||
assert set(mm_data.keys()) == {"image"}
|
||||
assert len(mm_data["image"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
|
||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="coroutine 'async_get_and_parse_image' was never awaited")
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="At most 2 image\\(s\\) may be provided in one request\\."
|
||||
):
|
||||
parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in these images?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_rejects_too_many_images_across_messages(
|
||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="coroutine 'async_get_and_parse_image' was never awaited")
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="At most 2 image\\(s\\) may be provided in one request\\."
|
||||
):
|
||||
parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
}]
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "Some stuff."
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What about these two?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import codecs
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
|
||||
Union)
|
||||
from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
|
||||
Optional, Tuple, Union)
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -80,10 +81,90 @@ class ConversationMessage(TypedDict):
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChatMessageParseResult:
|
||||
messages: List[ConversationMessage]
|
||||
mm_futures: List[Awaitable[MultiModalDataDict]]
|
||||
class MultiModalItemTracker:
|
||||
"""
|
||||
Tracks multi-modal items in a given request and ensures that the number
|
||||
of multi-modal items in a given request does not exceed the configured
|
||||
maximum per prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
|
||||
self._model_config = model_config
|
||||
self._tokenizer = tokenizer
|
||||
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
|
||||
if model_config.multimodal_config else {})
|
||||
self._consumed_items = {k: 0 for k in self._allowed_items}
|
||||
self._futures: List[Awaitable[MultiModalDataDict]] = []
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=None)
|
||||
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
|
||||
return tokenizer.decode(token_index)
|
||||
|
||||
def add(self, modality: Literal["image", "audio"],
|
||||
mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
|
||||
"""
|
||||
Adds the multi-modal item to the current prompt and returns the
|
||||
placeholder string to use, if any.
|
||||
"""
|
||||
allowed_count = self._allowed_items.get(modality, 1)
|
||||
current_count = self._consumed_items.get(modality, 0) + 1
|
||||
if current_count > allowed_count:
|
||||
raise ValueError(
|
||||
f"At most {allowed_count} {modality}(s) may be provided in "
|
||||
"one request.")
|
||||
|
||||
self._consumed_items[modality] = current_count
|
||||
self._futures.append(mm_future)
|
||||
|
||||
# TODO: Let user specify how to insert image tokens into prompt
|
||||
# (similar to chat template)
|
||||
model_type = self._model_config.hf_config.model_type
|
||||
if modality == "image":
|
||||
if model_type == "phi3_v":
|
||||
# Workaround since this token is not defined in the tokenizer
|
||||
return f"<|image_{current_count}|>"
|
||||
if model_type == "minicpmv":
|
||||
return "(<image>./</image>)"
|
||||
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
|
||||
# These models do not use image tokens in the prompt
|
||||
return None
|
||||
if model_type.startswith("llava"):
|
||||
return MultiModalItemTracker._cached_token_str(
|
||||
self._tokenizer,
|
||||
self._model_config.hf_config.image_token_index)
|
||||
if model_type in ("chameleon", "internvl_chat"):
|
||||
return "<image>"
|
||||
|
||||
raise TypeError(f"Unknown model type: {model_type}")
|
||||
elif modality == "audio":
|
||||
if model_type == "ultravox":
|
||||
return "<|reserved_special_token_0|>"
|
||||
raise TypeError(f"Unknown model type: {model_type}")
|
||||
else:
|
||||
raise TypeError(f"Unknown modality: {modality}")
|
||||
|
||||
@staticmethod
|
||||
async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
|
||||
mm_lists: Mapping[str, List[object]] = defaultdict(list)
|
||||
|
||||
# Merge all the multi-modal items
|
||||
for single_mm_data in (await asyncio.gather(*futures)):
|
||||
for mm_key, mm_item in single_mm_data.items():
|
||||
if isinstance(mm_item, list):
|
||||
mm_lists[mm_key].extend(mm_item)
|
||||
else:
|
||||
mm_lists[mm_key].append(mm_item)
|
||||
|
||||
# Unpack any single item lists for models that don't expect multiple.
|
||||
return {
|
||||
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
|
||||
for mm_key, mm_list in mm_lists.items()
|
||||
}
|
||||
|
||||
def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
|
||||
return MultiModalItemTracker._combine(
|
||||
self._futures) if self._futures else None
|
||||
|
||||
|
||||
def load_chat_template(
|
||||
@ -112,44 +193,30 @@ def load_chat_template(
|
||||
return resolved_chat_template
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
|
||||
modality: Literal["image", "audio"]) -> Optional[str]:
|
||||
# TODO: Let user specify how to insert image tokens into prompt
|
||||
# (similar to chat template)
|
||||
model_type = model_config.hf_config.model_type
|
||||
if modality == "image":
|
||||
if model_type == "phi3_v":
|
||||
# Workaround since this token is not defined in the tokenizer
|
||||
return "<|image_1|>"
|
||||
if model_type == "minicpmv":
|
||||
return "(<image>./</image>)"
|
||||
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
|
||||
# These models do not use image tokens in the prompt
|
||||
return None
|
||||
if model_type.startswith("llava"):
|
||||
return tokenizer.decode(model_config.hf_config.image_token_index)
|
||||
if model_type in ("chameleon", "internvl_chat"):
|
||||
return "<image>"
|
||||
|
||||
raise TypeError(f"Unknown model type: {model_type}")
|
||||
elif modality == "audio":
|
||||
if model_type == "ultravox":
|
||||
return "<|reserved_special_token_0|>"
|
||||
raise TypeError(f"Unknown model type: {model_type}")
|
||||
else:
|
||||
raise TypeError(f"Unknown modality: {modality}")
|
||||
|
||||
|
||||
# TODO: Let user specify how to insert multimodal tokens into prompt
|
||||
# (similar to chat template)
|
||||
def _get_full_multimodal_text_prompt(placeholder_token_str: str,
|
||||
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
||||
text_prompt: str) -> str:
|
||||
"""Combine multimodal prompts for a multimodal language model"""
|
||||
|
||||
# NOTE: For now we assume all model architectures use the same
|
||||
# placeholder + text prompt format. This may change in the future.
|
||||
return f"{placeholder_token_str}\n{text_prompt}"
|
||||
# Look through the text prompt to check for missing placeholders
|
||||
missing_placeholders = []
|
||||
for placeholder in placeholder_counts:
|
||||
|
||||
# For any existing placeholder in the text prompt, we leave it as is
|
||||
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
|
||||
|
||||
if placeholder_counts[placeholder] < 0:
|
||||
raise ValueError(
|
||||
f"Found more '{placeholder}' placeholders in input prompt than "
|
||||
"actual multimodal data items.")
|
||||
|
||||
missing_placeholders.extend([placeholder] *
|
||||
placeholder_counts[placeholder])
|
||||
|
||||
# NOTE: For now we always add missing placeholders at the front of
|
||||
# the prompt. This may change to be customizable in the future.
|
||||
return "\n".join(missing_placeholders + [text_prompt])
|
||||
|
||||
|
||||
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
|
||||
@ -160,12 +227,12 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
|
||||
def _parse_chat_message_content_parts(
|
||||
role: str,
|
||||
parts: Iterable[ChatCompletionContentPartParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> ChatMessageParseResult:
|
||||
mm_tracker: MultiModalItemTracker,
|
||||
) -> List[ConversationMessage]:
|
||||
texts: List[str] = []
|
||||
mm_futures: List[Awaitable[MultiModalDataDict]] = []
|
||||
modality: Literal["image", "audio"] = "image"
|
||||
|
||||
# multimodal placeholder_string : count
|
||||
mm_placeholder_counts: Dict[str, int] = {}
|
||||
|
||||
for part in parts:
|
||||
part_type = part["type"]
|
||||
@ -173,11 +240,6 @@ def _parse_chat_message_content_parts(
|
||||
text = _TextParser.validate_python(part)["text"]
|
||||
texts.append(text)
|
||||
elif part_type == "image_url":
|
||||
modality = "image"
|
||||
if len(mm_futures) > 0:
|
||||
raise NotImplementedError(
|
||||
"Multiple multimodal inputs is currently not supported.")
|
||||
|
||||
image_url = _ImageParser.validate_python(part)["image_url"]
|
||||
|
||||
if image_url.get("detail", "auto") != "auto":
|
||||
@ -185,60 +247,44 @@ def _parse_chat_message_content_parts(
|
||||
"'image_url.detail' is currently not supported and "
|
||||
"will be ignored.")
|
||||
|
||||
image_future = async_get_and_parse_image(image_url["url"])
|
||||
mm_futures.append(image_future)
|
||||
image_coro = async_get_and_parse_image(image_url["url"])
|
||||
placeholder = mm_tracker.add("image", image_coro)
|
||||
if placeholder:
|
||||
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
|
||||
placeholder, 0) + 1
|
||||
elif part_type == "audio_url":
|
||||
modality = "audio"
|
||||
if len(mm_futures) > 0:
|
||||
raise NotImplementedError(
|
||||
"Multiple multimodal inputs is currently not supported.")
|
||||
|
||||
audio_url = _AudioParser.validate_python(part)["audio_url"]
|
||||
audio_future = async_get_and_parse_audio(audio_url["url"])
|
||||
mm_futures.append(audio_future)
|
||||
audio_coro = async_get_and_parse_audio(audio_url["url"])
|
||||
placeholder = mm_tracker.add("audio", audio_coro)
|
||||
if placeholder:
|
||||
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
|
||||
placeholder, 0) + 1
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
|
||||
text_prompt = "\n".join(texts)
|
||||
if mm_placeholder_counts:
|
||||
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
|
||||
text_prompt)
|
||||
|
||||
if mm_futures:
|
||||
placeholder_token_str = _mm_token_str(model_config, tokenizer,
|
||||
modality)
|
||||
if placeholder_token_str is not None:
|
||||
if placeholder_token_str in text_prompt:
|
||||
logger.warning(
|
||||
"Detected multi-modal token string in the text prompt. "
|
||||
"Skipping prompt formatting.")
|
||||
else:
|
||||
text_prompt = _get_full_multimodal_text_prompt(
|
||||
placeholder_token_str=placeholder_token_str,
|
||||
text_prompt=text_prompt,
|
||||
)
|
||||
|
||||
messages = [ConversationMessage(role=role, content=text_prompt)]
|
||||
|
||||
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
|
||||
return [ConversationMessage(role=role, content=text_prompt)]
|
||||
|
||||
|
||||
def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> ChatMessageParseResult:
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
|
||||
if content is None:
|
||||
return ChatMessageParseResult(messages=[], mm_futures=[])
|
||||
return []
|
||||
if isinstance(content, str):
|
||||
messages = [ConversationMessage(role=role, content=content)]
|
||||
return ChatMessageParseResult(messages=messages, mm_futures=[])
|
||||
return [ConversationMessage(role=role, content=content)]
|
||||
|
||||
return _parse_chat_message_content_parts(
|
||||
role,
|
||||
content, # type: ignore
|
||||
model_config,
|
||||
tokenizer,
|
||||
mm_tracker,
|
||||
)
|
||||
|
||||
|
||||
@ -246,18 +292,16 @@ def parse_chat_messages(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
|
||||
) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]:
|
||||
conversation: List[ConversationMessage] = []
|
||||
mm_futures: List[Awaitable[MultiModalDataDict]] = []
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
for msg in messages:
|
||||
parse_result = _parse_chat_message_content(msg, model_config,
|
||||
tokenizer)
|
||||
sub_messages = _parse_chat_message_content(msg, mm_tracker)
|
||||
|
||||
conversation.extend(parse_result.messages)
|
||||
mm_futures.extend(parse_result.mm_futures)
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
return conversation, mm_futures
|
||||
return conversation, mm_tracker.all_mm_data()
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
|
||||
@ -94,7 +94,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
tokenizer = await self.async_engine_client.get_tokenizer(
|
||||
lora_request)
|
||||
|
||||
conversation, mm_futures = parse_chat_messages(
|
||||
conversation, mm_data_future = parse_chat_messages(
|
||||
request.messages, model_config, tokenizer)
|
||||
|
||||
tool_dicts = None if request.tools is None else [
|
||||
@ -116,12 +116,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
mm_data: Optional[MultiModalDataDict] = None
|
||||
try:
|
||||
if len(mm_futures):
|
||||
# since we support only single mm data currently
|
||||
assert len(
|
||||
mm_futures
|
||||
) == 1, "Multiple 'image_url' input is currently not supported."
|
||||
mm_data = await mm_futures[0]
|
||||
if mm_data_future:
|
||||
mm_data = await mm_data_future
|
||||
except Exception as e:
|
||||
logger.error("Error in loading multi-modal data: %s", e)
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@ -65,10 +65,10 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
if isinstance(request, TokenizeChatRequest):
|
||||
model_config = self.model_config
|
||||
|
||||
conversation, mm_futures = parse_chat_messages(
|
||||
conversation, mm_data_future = parse_chat_messages(
|
||||
request.messages, model_config, tokenizer)
|
||||
|
||||
if mm_futures:
|
||||
if mm_data_future:
|
||||
logger.warning(
|
||||
"Multi-modal inputs are ignored during tokenization")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user