mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 04:54:58 +08:00
[Frontend] Multimodal support in offline chat (#8098)
This commit is contained in:
parent
2be8ec6e71
commit
855c262a6b
@ -6,6 +6,7 @@ import pytest
|
|||||||
from vllm import LLM, RequestOutput, SamplingParams
|
from vllm import LLM, RequestOutput, SamplingParams
|
||||||
|
|
||||||
from ...conftest import cleanup
|
from ...conftest import cleanup
|
||||||
|
from ..openai.test_vision import TEST_IMAGE_URLS
|
||||||
|
|
||||||
MODEL_NAME = "facebook/opt-125m"
|
MODEL_NAME = "facebook/opt-125m"
|
||||||
|
|
||||||
@ -159,3 +160,36 @@ def test_chat():
|
|||||||
]
|
]
|
||||||
outputs = llm.chat(messages)
|
outputs = llm.chat(messages)
|
||||||
assert len(outputs) == 1
|
assert len(outputs) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("image_urls",
|
||||||
|
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
|
||||||
|
def test_chat_multi_image(image_urls: List[str]):
|
||||||
|
llm = LLM(
|
||||||
|
model="microsoft/Phi-3.5-vision-instruct",
|
||||||
|
dtype="bfloat16",
|
||||||
|
max_model_len=4096,
|
||||||
|
max_num_seqs=5,
|
||||||
|
enforce_eager=True,
|
||||||
|
trust_remote_code=True,
|
||||||
|
limit_mm_per_prompt={"image": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
*({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": image_url
|
||||||
|
}
|
||||||
|
} for image_url in image_urls),
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in this image?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}]
|
||||||
|
outputs = llm.chat(messages)
|
||||||
|
assert len(outputs) >= 0
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
import warnings
|
import warnings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.entrypoints.chat_utils import parse_chat_messages
|
from vllm.entrypoints.chat_utils import (parse_chat_messages,
|
||||||
|
parse_chat_messages_futures)
|
||||||
|
from vllm.multimodal import MultiModalDataDict
|
||||||
from vllm.multimodal.utils import encode_image_base64
|
from vllm.multimodal.utils import encode_image_base64
|
||||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||||
|
|
||||||
@ -42,10 +45,28 @@ def image_url():
|
|||||||
return f"data:image/jpeg;base64,{base64}"
|
return f"data:image/jpeg;base64,{base64}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def _assert_mm_data_is_image_input(
|
||||||
async def test_parse_chat_messages_with_image_url(phi3v_model_config,
|
mm_data: Optional[MultiModalDataDict],
|
||||||
phi3v_tokenizer, image_url):
|
image_count: int,
|
||||||
conversation, mm_future = parse_chat_messages([{
|
) -> None:
|
||||||
|
assert mm_data is not None
|
||||||
|
assert set(mm_data.keys()) == {"image"}
|
||||||
|
|
||||||
|
image_data = mm_data.get("image")
|
||||||
|
assert image_data is not None
|
||||||
|
|
||||||
|
if image_count == 1:
|
||||||
|
assert isinstance(image_data, Image.Image)
|
||||||
|
else:
|
||||||
|
assert isinstance(image_data, list) and len(image_data) == image_count
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_single_image(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
conversation, mm_data = parse_chat_messages([{
|
||||||
"role":
|
"role":
|
||||||
"user",
|
"user",
|
||||||
"content": [{
|
"content": [{
|
||||||
@ -63,15 +84,42 @@ async def test_parse_chat_messages_with_image_url(phi3v_model_config,
|
|||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "<|image_1|>\nWhat's in the image?"
|
"content": "<|image_1|>\nWhat's in the image?"
|
||||||
}]
|
}]
|
||||||
mm_data = await mm_future
|
_assert_mm_data_is_image_input(mm_data, 1)
|
||||||
assert set(mm_data.keys()) == {"image"}
|
|
||||||
assert isinstance(mm_data["image"], Image.Image)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parse_chat_messages_multiple_images(phi3v_model_config,
|
async def test_parse_chat_messages_single_image_async(
|
||||||
phi3v_tokenizer, image_url):
|
phi3v_model_config,
|
||||||
conversation, mm_future = parse_chat_messages([{
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
conversation, mm_future = parse_chat_messages_futures([{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": image_url
|
||||||
|
}
|
||||||
|
}, {
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in the image?"
|
||||||
|
}]
|
||||||
|
}], phi3v_model_config, phi3v_tokenizer)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role": "user",
|
||||||
|
"content": "<|image_1|>\nWhat's in the image?"
|
||||||
|
}]
|
||||||
|
_assert_mm_data_is_image_input(await mm_future, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_multiple_images(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
conversation, mm_data = parse_chat_messages([{
|
||||||
"role":
|
"role":
|
||||||
"user",
|
"user",
|
||||||
"content": [{
|
"content": [{
|
||||||
@ -96,15 +144,49 @@ async def test_parse_chat_messages_multiple_images(phi3v_model_config,
|
|||||||
"content":
|
"content":
|
||||||
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
|
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
|
||||||
}]
|
}]
|
||||||
mm_data = await mm_future
|
_assert_mm_data_is_image_input(mm_data, 2)
|
||||||
assert set(mm_data.keys()) == {"image"}
|
|
||||||
assert len(mm_data["image"]) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parse_chat_messages_placeholder_already_in_prompt(
|
async def test_parse_chat_messages_multiple_images_async(
|
||||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
phi3v_model_config,
|
||||||
conversation, mm_future = parse_chat_messages([{
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
conversation, mm_future = parse_chat_messages_futures([{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": image_url
|
||||||
|
}
|
||||||
|
}, {
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": image_url
|
||||||
|
}
|
||||||
|
}, {
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in these images?"
|
||||||
|
}]
|
||||||
|
}], phi3v_model_config, phi3v_tokenizer)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
|
||||||
|
}]
|
||||||
|
_assert_mm_data_is_image_input(await mm_future, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_placeholder_already_in_prompt(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
conversation, mm_data = parse_chat_messages([{
|
||||||
"role":
|
"role":
|
||||||
"user",
|
"user",
|
||||||
"content": [{
|
"content": [{
|
||||||
@ -131,15 +213,15 @@ async def test_parse_chat_messages_placeholder_already_in_prompt(
|
|||||||
"content":
|
"content":
|
||||||
"What's in <|image_1|> and how does it compare to <|image_2|>?"
|
"What's in <|image_1|> and how does it compare to <|image_2|>?"
|
||||||
}]
|
}]
|
||||||
mm_data = await mm_future
|
_assert_mm_data_is_image_input(mm_data, 2)
|
||||||
assert set(mm_data.keys()) == {"image"}
|
|
||||||
assert len(mm_data["image"]) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_parse_chat_messages_placeholder_one_already_in_prompt(
|
||||||
async def test_parse_chat_messages_placeholder_one_already_in_prompt(
|
phi3v_model_config,
|
||||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
phi3v_tokenizer,
|
||||||
conversation, mm_future = parse_chat_messages([{
|
image_url,
|
||||||
|
):
|
||||||
|
conversation, mm_data = parse_chat_messages([{
|
||||||
"role":
|
"role":
|
||||||
"user",
|
"user",
|
||||||
"content": [{
|
"content": [{
|
||||||
@ -167,15 +249,15 @@ async def test_parse_chat_messages_placeholder_one_already_in_prompt(
|
|||||||
"<|image_2|>\nWhat's in <|image_1|> and how does it compare to the "
|
"<|image_2|>\nWhat's in <|image_1|> and how does it compare to the "
|
||||||
"other one?"
|
"other one?"
|
||||||
}]
|
}]
|
||||||
mm_data = await mm_future
|
_assert_mm_data_is_image_input(mm_data, 2)
|
||||||
assert set(mm_data.keys()) == {"image"}
|
|
||||||
assert len(mm_data["image"]) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_parse_chat_messages_multiple_images_across_messages(
|
||||||
async def test_parse_chat_messages_multiple_images_across_messages(
|
phi3v_model_config,
|
||||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
phi3v_tokenizer,
|
||||||
conversation, mm_future = parse_chat_messages([{
|
image_url,
|
||||||
|
):
|
||||||
|
conversation, mm_data = parse_chat_messages([{
|
||||||
"role":
|
"role":
|
||||||
"user",
|
"user",
|
||||||
"content": [{
|
"content": [{
|
||||||
@ -218,14 +300,14 @@ async def test_parse_chat_messages_multiple_images_across_messages(
|
|||||||
"content": "<|image_2|>\nWhat about this one?"
|
"content": "<|image_2|>\nWhat about this one?"
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
mm_data = await mm_future
|
_assert_mm_data_is_image_input(mm_data, 2)
|
||||||
assert set(mm_data.keys()) == {"image"}
|
|
||||||
assert len(mm_data["image"]) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_parse_chat_messages_rejects_too_many_images_in_one_message(
|
||||||
async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
|
phi3v_model_config,
|
||||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
"ignore",
|
"ignore",
|
||||||
@ -259,9 +341,11 @@ async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
|
|||||||
}], phi3v_model_config, phi3v_tokenizer)
|
}], phi3v_model_config, phi3v_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_parse_chat_messages_rejects_too_many_images_across_messages(
|
||||||
async def test_parse_chat_messages_rejects_too_many_images_across_messages(
|
phi3v_model_config,
|
||||||
phi3v_model_config, phi3v_tokenizer, image_url):
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
"ignore",
|
"ignore",
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import codecs
|
import codecs
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
|
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
|
||||||
Optional, Tuple, Union)
|
Mapping, Optional, Tuple, TypeVar, Union)
|
||||||
|
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -23,7 +24,8 @@ from vllm.config import ModelConfig
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.multimodal import MultiModalDataDict
|
from vllm.multimodal import MultiModalDataDict
|
||||||
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
||||||
async_get_and_parse_image)
|
async_get_and_parse_image,
|
||||||
|
get_and_parse_audio, get_and_parse_image)
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -81,7 +83,11 @@ class ConversationMessage(TypedDict):
|
|||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
class MultiModalItemTracker:
|
ModalityStr = Literal["image", "audio"]
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||||
"""
|
"""
|
||||||
Tracks multi-modal items in a given request and ensures that the number
|
Tracks multi-modal items in a given request and ensures that the number
|
||||||
of multi-modal items in a given request does not exceed the configured
|
of multi-modal items in a given request does not exceed the configured
|
||||||
@ -89,37 +95,28 @@ class MultiModalItemTracker:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
|
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self._model_config = model_config
|
self._model_config = model_config
|
||||||
self._tokenizer = tokenizer
|
self._tokenizer = tokenizer
|
||||||
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
|
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
|
||||||
if model_config.multimodal_config else {})
|
if model_config.multimodal_config else {})
|
||||||
self._consumed_items = {k: 0 for k in self._allowed_items}
|
self._consumed_items = {k: 0 for k in self._allowed_items}
|
||||||
self._futures: List[Awaitable[MultiModalDataDict]] = []
|
|
||||||
|
self._items: List[_T] = []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
|
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
|
||||||
return tokenizer.decode(token_index)
|
return tokenizer.decode(token_index)
|
||||||
|
|
||||||
def add(self, modality: Literal["image", "audio"],
|
def _placeholder_str(self, modality: ModalityStr,
|
||||||
mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
|
current_count: int) -> Optional[str]:
|
||||||
"""
|
|
||||||
Adds the multi-modal item to the current prompt and returns the
|
|
||||||
placeholder string to use, if any.
|
|
||||||
"""
|
|
||||||
allowed_count = self._allowed_items.get(modality, 1)
|
|
||||||
current_count = self._consumed_items.get(modality, 0) + 1
|
|
||||||
if current_count > allowed_count:
|
|
||||||
raise ValueError(
|
|
||||||
f"At most {allowed_count} {modality}(s) may be provided in "
|
|
||||||
"one request.")
|
|
||||||
|
|
||||||
self._consumed_items[modality] = current_count
|
|
||||||
self._futures.append(mm_future)
|
|
||||||
|
|
||||||
# TODO: Let user specify how to insert image tokens into prompt
|
# TODO: Let user specify how to insert image tokens into prompt
|
||||||
# (similar to chat template)
|
# (similar to chat template)
|
||||||
model_type = self._model_config.hf_config.model_type
|
hf_config = self._model_config.hf_config
|
||||||
|
model_type = hf_config.model_type
|
||||||
|
|
||||||
if modality == "image":
|
if modality == "image":
|
||||||
if model_type == "phi3_v":
|
if model_type == "phi3_v":
|
||||||
# Workaround since this token is not defined in the tokenizer
|
# Workaround since this token is not defined in the tokenizer
|
||||||
@ -130,9 +127,8 @@ class MultiModalItemTracker:
|
|||||||
# These models do not use image tokens in the prompt
|
# These models do not use image tokens in the prompt
|
||||||
return None
|
return None
|
||||||
if model_type.startswith("llava"):
|
if model_type.startswith("llava"):
|
||||||
return MultiModalItemTracker._cached_token_str(
|
return self._cached_token_str(self._tokenizer,
|
||||||
self._tokenizer,
|
hf_config.image_token_index)
|
||||||
self._model_config.hf_config.image_token_index)
|
|
||||||
if model_type in ("chameleon", "internvl_chat"):
|
if model_type in ("chameleon", "internvl_chat"):
|
||||||
return "<image>"
|
return "<image>"
|
||||||
|
|
||||||
@ -145,11 +141,11 @@ class MultiModalItemTracker:
|
|||||||
raise TypeError(f"Unknown modality: {modality}")
|
raise TypeError(f"Unknown modality: {modality}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
|
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
|
||||||
mm_lists: Mapping[str, List[object]] = defaultdict(list)
|
mm_lists: Mapping[str, List[object]] = defaultdict(list)
|
||||||
|
|
||||||
# Merge all the multi-modal items
|
# Merge all the multi-modal items
|
||||||
for single_mm_data in (await asyncio.gather(*futures)):
|
for single_mm_data in items:
|
||||||
for mm_key, mm_item in single_mm_data.items():
|
for mm_key, mm_item in single_mm_data.items():
|
||||||
if isinstance(mm_item, list):
|
if isinstance(mm_item, list):
|
||||||
mm_lists[mm_key].extend(mm_item)
|
mm_lists[mm_key].extend(mm_item)
|
||||||
@ -162,9 +158,113 @@ class MultiModalItemTracker:
|
|||||||
for mm_key, mm_list in mm_lists.items()
|
for mm_key, mm_list in mm_lists.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
|
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
||||||
return MultiModalItemTracker._combine(
|
"""
|
||||||
self._futures) if self._futures else None
|
Add a multi-modal item to the current prompt and returns the
|
||||||
|
placeholder string to use, if any.
|
||||||
|
"""
|
||||||
|
allowed_count = self._allowed_items.get(modality, 1)
|
||||||
|
current_count = self._consumed_items.get(modality, 0) + 1
|
||||||
|
if current_count > allowed_count:
|
||||||
|
raise ValueError(
|
||||||
|
f"At most {allowed_count} {modality}(s) may be provided in "
|
||||||
|
"one request.")
|
||||||
|
|
||||||
|
self._consumed_items[modality] = current_count
|
||||||
|
self._items.append(item)
|
||||||
|
|
||||||
|
return self._placeholder_str(modality, current_count)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
|
||||||
|
|
||||||
|
def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||||
|
return self._combine(self._items) if self._items else None
|
||||||
|
|
||||||
|
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||||
|
return MultiModalContentParser(self)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMultiModalItemTracker(
|
||||||
|
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
|
||||||
|
|
||||||
|
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||||
|
if self._items:
|
||||||
|
items = await asyncio.gather(*self._items)
|
||||||
|
return self._combine(items)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||||
|
return AsyncMultiModalContentParser(self)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMultiModalContentParser(ABC):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# multimodal placeholder_string : count
|
||||||
|
self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
|
||||||
|
|
||||||
|
def _add_placeholder(self, placeholder: Optional[str]):
|
||||||
|
if placeholder:
|
||||||
|
self._placeholder_counts[placeholder] += 1
|
||||||
|
|
||||||
|
def mm_placeholder_counts(self) -> Dict[str, int]:
|
||||||
|
return dict(self._placeholder_counts)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_image(self, image_url: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_audio(self, audio_url: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalContentParser(BaseMultiModalContentParser):
|
||||||
|
|
||||||
|
def __init__(self, tracker: MultiModalItemTracker) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._tracker = tracker
|
||||||
|
|
||||||
|
def parse_image(self, image_url: str) -> None:
|
||||||
|
image = get_and_parse_image(image_url)
|
||||||
|
|
||||||
|
placeholder = self._tracker.add("image", image)
|
||||||
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
|
def parse_audio(self, audio_url: str) -> None:
|
||||||
|
audio = get_and_parse_audio(audio_url)
|
||||||
|
|
||||||
|
placeholder = self._tracker.add("audio", audio)
|
||||||
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||||
|
|
||||||
|
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._tracker = tracker
|
||||||
|
|
||||||
|
def parse_image(self, image_url: str) -> None:
|
||||||
|
image_coro = async_get_and_parse_image(image_url)
|
||||||
|
|
||||||
|
placeholder = self._tracker.add("image", image_coro)
|
||||||
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
|
def parse_audio(self, audio_url: str) -> None:
|
||||||
|
audio_coro = async_get_and_parse_audio(audio_url)
|
||||||
|
|
||||||
|
placeholder = self._tracker.add("audio", audio_coro)
|
||||||
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
|
|
||||||
def load_chat_template(
|
def load_chat_template(
|
||||||
@ -197,10 +297,10 @@ def load_chat_template(
|
|||||||
# (similar to chat template)
|
# (similar to chat template)
|
||||||
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
||||||
text_prompt: str) -> str:
|
text_prompt: str) -> str:
|
||||||
"""Combine multimodal prompts for a multimodal language model"""
|
"""Combine multimodal prompts for a multimodal language model."""
|
||||||
|
|
||||||
# Look through the text prompt to check for missing placeholders
|
# Look through the text prompt to check for missing placeholders
|
||||||
missing_placeholders = []
|
missing_placeholders: List[str] = []
|
||||||
for placeholder in placeholder_counts:
|
for placeholder in placeholder_counts:
|
||||||
|
|
||||||
# For any existing placeholder in the text prompt, we leave it as is
|
# For any existing placeholder in the text prompt, we leave it as is
|
||||||
@ -227,12 +327,11 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
|
|||||||
def _parse_chat_message_content_parts(
|
def _parse_chat_message_content_parts(
|
||||||
role: str,
|
role: str,
|
||||||
parts: Iterable[ChatCompletionContentPartParam],
|
parts: Iterable[ChatCompletionContentPartParam],
|
||||||
mm_tracker: MultiModalItemTracker,
|
mm_tracker: BaseMultiModalItemTracker,
|
||||||
) -> List[ConversationMessage]:
|
) -> List[ConversationMessage]:
|
||||||
texts: List[str] = []
|
texts: List[str] = []
|
||||||
|
|
||||||
# multimodal placeholder_string : count
|
mm_parser = mm_tracker.create_parser()
|
||||||
mm_placeholder_counts: Dict[str, int] = {}
|
|
||||||
|
|
||||||
for part in parts:
|
for part in parts:
|
||||||
part_type = part["type"]
|
part_type = part["type"]
|
||||||
@ -247,22 +346,16 @@ def _parse_chat_message_content_parts(
|
|||||||
"'image_url.detail' is currently not supported and "
|
"'image_url.detail' is currently not supported and "
|
||||||
"will be ignored.")
|
"will be ignored.")
|
||||||
|
|
||||||
image_coro = async_get_and_parse_image(image_url["url"])
|
mm_parser.parse_image(image_url["url"])
|
||||||
placeholder = mm_tracker.add("image", image_coro)
|
|
||||||
if placeholder:
|
|
||||||
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
|
|
||||||
placeholder, 0) + 1
|
|
||||||
elif part_type == "audio_url":
|
elif part_type == "audio_url":
|
||||||
audio_url = _AudioParser.validate_python(part)["audio_url"]
|
audio_url = _AudioParser.validate_python(part)["audio_url"]
|
||||||
audio_coro = async_get_and_parse_audio(audio_url["url"])
|
|
||||||
placeholder = mm_tracker.add("audio", audio_coro)
|
mm_parser.parse_audio(audio_url["url"])
|
||||||
if placeholder:
|
|
||||||
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
|
|
||||||
placeholder, 0) + 1
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||||
|
|
||||||
text_prompt = "\n".join(texts)
|
text_prompt = "\n".join(texts)
|
||||||
|
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
|
||||||
if mm_placeholder_counts:
|
if mm_placeholder_counts:
|
||||||
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
|
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
|
||||||
text_prompt)
|
text_prompt)
|
||||||
@ -271,8 +364,9 @@ def _parse_chat_message_content_parts(
|
|||||||
|
|
||||||
|
|
||||||
def _parse_chat_message_content(
|
def _parse_chat_message_content(
|
||||||
message: ChatCompletionMessageParam,
|
message: ChatCompletionMessageParam,
|
||||||
mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
|
mm_tracker: BaseMultiModalItemTracker,
|
||||||
|
) -> List[ConversationMessage]:
|
||||||
role = message["role"]
|
role = message["role"]
|
||||||
content = message.get("content")
|
content = message.get("content")
|
||||||
|
|
||||||
@ -292,7 +386,7 @@ def parse_chat_messages(
|
|||||||
messages: List[ChatCompletionMessageParam],
|
messages: List[ChatCompletionMessageParam],
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]:
|
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
|
||||||
conversation: List[ConversationMessage] = []
|
conversation: List[ConversationMessage] = []
|
||||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||||
|
|
||||||
@ -304,6 +398,22 @@ def parse_chat_messages(
|
|||||||
return conversation, mm_tracker.all_mm_data()
|
return conversation, mm_tracker.all_mm_data()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_chat_messages_futures(
|
||||||
|
messages: List[ChatCompletionMessageParam],
|
||||||
|
model_config: ModelConfig,
|
||||||
|
tokenizer: AnyTokenizer,
|
||||||
|
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
|
||||||
|
conversation: List[ConversationMessage] = []
|
||||||
|
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
sub_messages = _parse_chat_message_content(msg, mm_tracker)
|
||||||
|
|
||||||
|
conversation.extend(sub_messages)
|
||||||
|
|
||||||
|
return conversation, mm_tracker.all_mm_data()
|
||||||
|
|
||||||
|
|
||||||
def apply_chat_template(
|
def apply_chat_template(
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
conversation: List[ConversationMessage],
|
conversation: List[ConversationMessage],
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
|||||||
get_cached_tokenizer)
|
get_cached_tokenizer)
|
||||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Counter, deprecate_kwargs
|
from vllm.utils import Counter, deprecate_kwargs, is_list_of
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -358,15 +358,18 @@ class LLM:
|
|||||||
add_generation_prompt: bool = True,
|
add_generation_prompt: bool = True,
|
||||||
) -> List[RequestOutput]:
|
) -> List[RequestOutput]:
|
||||||
"""
|
"""
|
||||||
Generates responses for chat messages.
|
Generate responses for a chat conversation.
|
||||||
|
|
||||||
Converts the messages to prompts using the tokenizer and calls
|
The chat conversation is converted into a text prompt using the
|
||||||
the :meth:`generate` method to generate the responses.
|
tokenizer and calls the :meth:`generate` method to generate the
|
||||||
|
responses.
|
||||||
|
|
||||||
|
Multi-modal inputs can be passed in the same way you would pass them
|
||||||
|
to the OpenAI API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: A list of messages to generate responses for. Each
|
messages: A single conversation represented as a list of messages.
|
||||||
message is a list of dictionaries with 'role' and 'content'
|
Each message is a dictionary with 'role' and 'content' keys.
|
||||||
keys.
|
|
||||||
sampling_params: The sampling parameters for text generation.
|
sampling_params: The sampling parameters for text generation.
|
||||||
If None, we use the default sampling parameters. When it
|
If None, we use the default sampling parameters. When it
|
||||||
is a single value, it is applied to every prompt. When it
|
is a single value, it is applied to every prompt. When it
|
||||||
@ -387,21 +390,25 @@ class LLM:
|
|||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
model_config = self.llm_engine.get_model_config()
|
model_config = self.llm_engine.get_model_config()
|
||||||
|
|
||||||
conversations, _ = parse_chat_messages(messages, model_config,
|
conversation, mm_data = parse_chat_messages(messages, model_config,
|
||||||
tokenizer)
|
tokenizer)
|
||||||
|
|
||||||
prompt = apply_chat_template(
|
prompt = apply_chat_template(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
conversations,
|
conversation,
|
||||||
chat_template=chat_template,
|
chat_template=chat_template,
|
||||||
add_generation_prompt=add_generation_prompt)
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
inputs: PromptInputs
|
inputs: PromptInputs
|
||||||
if isinstance(prompt, list) and isinstance(prompt[0], int):
|
if is_list_of(prompt, int):
|
||||||
inputs = TokensPrompt(prompt_token_ids=prompt)
|
inputs = TokensPrompt(prompt_token_ids=prompt)
|
||||||
else:
|
else:
|
||||||
inputs = TextPrompt(prompt=prompt)
|
inputs = TextPrompt(prompt=prompt)
|
||||||
|
|
||||||
|
if mm_data is not None:
|
||||||
|
inputs["multi_modal_data"] = mm_data
|
||||||
|
|
||||||
return self.generate(
|
return self.generate(
|
||||||
inputs,
|
inputs,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from vllm.engine.protocol import AsyncEngineClient
|
|||||||
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||||
apply_chat_template,
|
apply_chat_template,
|
||||||
load_chat_template,
|
load_chat_template,
|
||||||
parse_chat_messages)
|
parse_chat_messages_futures)
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
ChatCompletionLogProb, ChatCompletionLogProbs,
|
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||||
@ -26,7 +26,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
|||||||
TextTokensPrompt)
|
TextTokensPrompt)
|
||||||
from vllm.inputs import TokensPrompt
|
from vllm.inputs import TokensPrompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.multimodal import MultiModalDataDict
|
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sequence import Logprob
|
from vllm.sequence import Logprob
|
||||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||||
@ -94,7 +93,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
tokenizer = await self.async_engine_client.get_tokenizer(
|
tokenizer = await self.async_engine_client.get_tokenizer(
|
||||||
lora_request)
|
lora_request)
|
||||||
|
|
||||||
conversation, mm_data_future = parse_chat_messages(
|
conversation, mm_data_future = parse_chat_messages_futures(
|
||||||
request.messages, model_config, tokenizer)
|
request.messages, model_config, tokenizer)
|
||||||
|
|
||||||
tool_dicts = None if request.tools is None else [
|
tool_dicts = None if request.tools is None else [
|
||||||
@ -114,10 +113,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
logger.error("Error in applying chat template from request: %s", e)
|
logger.error("Error in applying chat template from request: %s", e)
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
mm_data: Optional[MultiModalDataDict] = None
|
|
||||||
try:
|
try:
|
||||||
if mm_data_future:
|
mm_data = await mm_data_future
|
||||||
mm_data = await mm_data_future
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error in loading multi-modal data: %s", e)
|
logger.error("Error in loading multi-modal data: %s", e)
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from vllm.config import ModelConfig
|
|||||||
from vllm.engine.protocol import AsyncEngineClient
|
from vllm.engine.protocol import AsyncEngineClient
|
||||||
from vllm.entrypoints.chat_utils import (apply_chat_template,
|
from vllm.entrypoints.chat_utils import (apply_chat_template,
|
||||||
load_chat_template,
|
load_chat_template,
|
||||||
parse_chat_messages)
|
parse_chat_messages_futures)
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -65,10 +65,11 @@ class OpenAIServingTokenization(OpenAIServing):
|
|||||||
if isinstance(request, TokenizeChatRequest):
|
if isinstance(request, TokenizeChatRequest):
|
||||||
model_config = self.model_config
|
model_config = self.model_config
|
||||||
|
|
||||||
conversation, mm_data_future = parse_chat_messages(
|
conversation, mm_data_future = parse_chat_messages_futures(
|
||||||
request.messages, model_config, tokenizer)
|
request.messages, model_config, tokenizer)
|
||||||
|
|
||||||
if mm_data_future:
|
mm_data = await mm_data_future
|
||||||
|
if mm_data:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Multi-modal inputs are ignored during tokenization")
|
"Multi-modal inputs are ignored during tokenization")
|
||||||
|
|
||||||
|
|||||||
@ -120,6 +120,16 @@ async def async_fetch_audio(
|
|||||||
return librosa.load(BytesIO(audio_bytes), sr=None)
|
return librosa.load(BytesIO(audio_bytes), sr=None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
|
||||||
|
audio, sr = fetch_audio(audio_url)
|
||||||
|
return {"audio": (audio, sr)}
|
||||||
|
|
||||||
|
|
||||||
|
def get_and_parse_image(image_url: str) -> MultiModalDataDict:
|
||||||
|
image = fetch_image(image_url)
|
||||||
|
return {"image": image}
|
||||||
|
|
||||||
|
|
||||||
async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
|
async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
|
||||||
audio, sr = await async_fetch_audio(audio_url)
|
audio, sr = await async_fetch_audio(audio_url)
|
||||||
return {"audio": (audio, sr)}
|
return {"audio": (audio, sr)}
|
||||||
|
|||||||
@ -52,12 +52,13 @@ class MistralTokenizer:
|
|||||||
assert isinstance(self.tokenizer,
|
assert isinstance(self.tokenizer,
|
||||||
(Tekkenizer, SentencePieceTokenizer)), type(
|
(Tekkenizer, SentencePieceTokenizer)), type(
|
||||||
self.tokenizer)
|
self.tokenizer)
|
||||||
self._is_tekken = isinstance(self.tokenizer, Tekkenizer)
|
|
||||||
|
|
||||||
if self._is_tekken:
|
if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
|
||||||
# Make sure special tokens will not raise
|
# Make sure special tokens will not raise
|
||||||
self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
|
self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
|
||||||
|
|
||||||
|
self._is_tekken = is_tekken
|
||||||
|
|
||||||
# the following attributes are set to fit VLLM's design
|
# the following attributes are set to fit VLLM's design
|
||||||
self.is_fast = True
|
self.is_fast = True
|
||||||
self.chat_template = True
|
self.chat_template = True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user