[Bugfix] Dictionary MM embeddings for online chat (#30507)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent fdc135d768
commit b09806e28f
@@ -796,9 +796,13 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
             "content": "<|image_1|>\nWhat's in this image?",
         }
     ]

     assert mm_data is not None
     assert "image" in mm_data
-    assert mm_data["image"] is None
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 1
+    assert mm_data["image"][0] is None

     _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])

@@ -825,10 +829,11 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
     # Should have audio in mm_data as None (UUID provided)
     assert mm_data is not None
     assert "audio" in mm_data
-    assert mm_data["audio"] is None
+    assert isinstance(mm_data["audio"], list)
+    assert len(mm_data["audio"]) == 1
+    assert mm_data["audio"][0] is None

     # UUID should be recorded
-    assert mm_uuids is not None
-    assert "audio" in mm_uuids
     _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])

@@ -1121,10 +1126,105 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
     mm_data = await mm_future
     assert mm_data is not None
     assert "image" in mm_data
-    assert mm_data["image"] is None
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 1
+    assert mm_data["image"][0] is None

     _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])


+def test_parse_chat_messages_empty_dict_image_embeds(
+    phi3v_model_config_image_embeds,
+):
+    """Test that an empty dictionary for image_embeds is handled without errors."""
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image_embeds", "image_embeds": {}},
+                    {"type": "text", "text": "What's in this image?"},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\nWhat's in this image?",
+        }
+    ]
+
+    # Verify mm_data contains an empty dictionary of embeddings
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], dict)
+    assert len(mm_data["image"]) == 0
+
+    # Verify UUIDs (None since we didn't provide any)
+    _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None])
+
+
+def test_parse_chat_messages_multiple_dict_image_embeds(
+    phi3v_model_config_image_embeds,
+):
+    """Test that multiple dictionaries for image_embeds are handled without errors."""
+    # Create two sample image embedding tensors
+    batch_size = 2
+    image_embedding_1 = torch.randn(batch_size, 256, 1024)
+    image_embedding_2 = torch.randn(batch_size, 3)
+
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": {
+                            "image_embedding_1": tensor2base64(p),
+                            "image_embedding_2": tensor2base64(i),
+                        },
+                    }
+                    for p, i in zip(image_embedding_1, image_embedding_2)
+                ]
+                + [
+                    {"type": "text", "text": "Describe these two images."},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
+        }
+    ]
+
+    # Verify mm_data contains a dictionary of multi-embeddings
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], dict)
+    assert len(mm_data["image"]) == batch_size
+
+    # Verify each embedding has the correct shape
+    assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
+    assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
+    assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
+    assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape
+
+    # Verify UUIDs (None since we didn't provide any)
+    _assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])
+
+
 @pytest.mark.asyncio
 async def test_parse_chat_messages_multiple_images_async(
     phi3v_model_config,

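The two new tests above exercise the dictionary form of image_embeds end to end. For context, a minimal client-side sketch of the same payload sent to a running OpenAI-compatible server. The server URL, model name, and the tensor_to_base64 helper below are assumptions; the tests' tensor2base64 helper is presumed to wrap torch.save plus base64 encoding, so treat the exact wire format as illustrative:

import base64
import io

import requests
import torch


def tensor_to_base64(t: torch.Tensor) -> str:
    # Hypothetical helper mirroring the tests' tensor2base64:
    # serialize with torch.save, then base64-encode the bytes.
    buf = io.BytesIO()
    torch.save(t, buf)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


payload = {
    "model": "microsoft/Phi-3.5-vision-instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    # Dictionary form fixed by this PR: named sub-tensors
                    # for a single image item.
                    "type": "image_embeds",
                    "image_embeds": {
                        "image_embedding_1": tensor_to_base64(torch.randn(256, 1024)),
                        "image_embedding_2": tensor_to_base64(torch.randn(3)),
                    },
                },
                {"type": "text", "text": "What's in this image?"},
            ],
        }
    ],
}
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(response.json()["choices"][0]["message"]["content"])
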
@@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
 from collections.abc import Awaitable, Callable, Iterable
 from functools import cached_property, lru_cache, partial
 from pathlib import Path
-from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast

 import jinja2
 import jinja2.ext

@@ -53,7 +53,14 @@ from vllm.tokenizers import MistralTokenizer, TokenizerLike
 from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import random_uuid
+from vllm.utils.collection_utils import is_list_of
 from vllm.utils.func_utils import supports_kw
+from vllm.utils.import_utils import LazyLoader
+
+if TYPE_CHECKING:
+    import torch
+else:
+    torch = LazyLoader("torch", globals(), "torch")

 logger = init_logger(__name__)

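The TYPE_CHECKING/LazyLoader split above gives type checkers a real torch import while deferring the runtime import until the module is first touched. A self-contained sketch of the general pattern; the _LazyModule class below is a stand-in for illustration, not vLLM's actual LazyLoader:

import importlib
import types


class _LazyModule(types.ModuleType):
    """Stand-in lazy loader: imports the target on first attribute access."""

    def __init__(self, name: str):
        super().__init__(name)
        self._target_name = name
        self._target = None

    def __getattr__(self, attr: str):
        # Only reached when `attr` is not an instance attribute, i.e. for
        # real module members; triggers the deferred import exactly once.
        if self._target is None:
            self._target = importlib.import_module(self._target_name)
        return getattr(self._target, attr)


torch = _LazyModule("torch")  # no import cost until torch.<attr> is first used
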
@@ -620,6 +627,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
 _T = TypeVar("_T")


+def _extract_embeds(tensors: list[torch.Tensor]):
+    if len(tensors) == 0:
+        return tensors
+
+    if len(tensors) == 1:
+        tensors[0]._is_single_item = True  # type: ignore
+        return tensors[0]  # To keep backwards compatibility for single item input
+
+    first_shape = tensors[0].shape
+    if all(t.shape == first_shape for t in tensors):
+        return torch.stack(tensors)
+
+    return tensors
+
+
+def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
+    embeds_key = f"{modality}_embeds"
+    embeds = items_by_modality[embeds_key]
+
+    if len(embeds) == 0:
+        return embeds
+    if is_list_of(embeds, torch.Tensor):
+        return _extract_embeds(embeds)
+    if is_list_of(embeds, dict):
+        if not embeds:
+            return {}
+
+        first_keys = set(embeds[0].keys())
+        if any(set(item.keys()) != first_keys for item in embeds[1:]):
+            raise ValueError(
+                "All dictionaries in the list of embeddings must have the same keys."
+            )
+
+        return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys}
+
+    return embeds
+
+
 class BaseMultiModalItemTracker(ABC, Generic[_T]):
     """
     Tracks multi-modal items in a given request and ensures that the number

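A quick demonstration of the three shapes _extract_embeds can return; the helper body is copied from the hunk above so the snippet runs standalone:

import torch


def _extract_embeds(tensors: list[torch.Tensor]):
    # Copied from the hunk above for a standalone demonstration.
    if len(tensors) == 0:
        return tensors
    if len(tensors) == 1:
        tensors[0]._is_single_item = True  # type: ignore
        return tensors[0]  # single item stays unwrapped for backwards compat
    first_shape = tensors[0].shape
    if all(t.shape == first_shape for t in tensors):
        return torch.stack(tensors)
    return tensors


same = [torch.zeros(2, 4), torch.ones(2, 4)]
ragged = [torch.zeros(2, 4), torch.ones(3, 4)]
lone = [torch.zeros(2, 4)]

assert _extract_embeds(same).shape == (2, 2, 4)   # same shapes -> stacked tensor
assert isinstance(_extract_embeds(ragged), list)  # mixed shapes -> list kept as-is
unwrapped = _extract_embeds(lone)
assert unwrapped.shape == (2, 4)                  # one item -> bare tensor,
assert unwrapped._is_single_item                  # tagged for later validation
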
@@ -688,11 +733,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
     def all_mm_uuids(self) -> MultiModalUUIDDict | None:
         if not self._items_by_modality:
             return None
-        mm_uuids = {}
+
         uuids_by_modality = dict(self._uuids_by_modality)
         if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
             raise ValueError("Mixing raw image and embedding inputs is not allowed")
+        if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality:
+            raise ValueError("Mixing raw audio and embedding inputs is not allowed")
+
+        mm_uuids = {}
         if "image_embeds" in uuids_by_modality:
             mm_uuids["image"] = uuids_by_modality["image_embeds"]
         if "image" in uuids_by_modality:

@@ -703,6 +751,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
             mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
         if "video" in uuids_by_modality:
             mm_uuids["video"] = uuids_by_modality["video"]  # UUIDs of videos
+
         return mm_uuids

     @abstractmethod

@@ -714,29 +763,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
     def all_mm_data(self) -> MultiModalDataDict | None:
         if not self._items_by_modality:
             return None
-        mm_inputs = {}
+
         items_by_modality = dict(self._items_by_modality)
         if "image" in items_by_modality and "image_embeds" in items_by_modality:
             raise ValueError("Mixing raw image and embedding inputs is not allowed")
         if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
             raise ValueError("Mixing raw audio and embedding inputs is not allowed")
+
+        mm_inputs = {}
         if "image_embeds" in items_by_modality:
-            image_embeds_lst = items_by_modality["image_embeds"]
-            mm_inputs["image"] = (
-                image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
-            )
+            mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio_embeds" in items_by_modality:
-            audio_embeds_lst = items_by_modality["audio_embeds"]
-            mm_inputs["audio"] = (
-                audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
-            )
+            mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
             mm_inputs["video"] = items_by_modality["video"]  # A list of videos
+
         return mm_inputs

     def create_parser(self) -> "BaseMultiModalContentParser":

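With _get_embeds_data in place, a lone embedding item still reaches mm_inputs as a bare tensor (matching the old len(...) != 1 special case), while several same-shape items now arrive stacked rather than as a Python list. A rough check, assuming the private helper lands in vllm.entrypoints.chat_utils as the hunk context suggests:

import torch

# Assumption: this commit's private helper lives in vllm.entrypoints.chat_utils.
from vllm.entrypoints.chat_utils import _get_embeds_data

one = {"image_embeds": [torch.randn(4, 8)]}
assert _get_embeds_data(one, "image").shape == (4, 8)     # unwrapped, as before

two = {"image_embeds": [torch.randn(4, 8), torch.randn(4, 8)]}
assert _get_embeds_data(two, "image").shape == (2, 4, 8)  # stacked batch now
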
@@ -747,38 +792,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
     async def all_mm_data(self) -> MultiModalDataDict | None:
         if not self._items_by_modality:
             return None
-        mm_inputs = {}
-        items_by_modality = {}
-        for modality, items in self._items_by_modality.items():
-            coros = []
-            for item in items:
-                if item is not None:
-                    coros.append(item)
-                else:
-                    coros.append(asyncio.sleep(0))
-            items_by_modality[modality] = await asyncio.gather(*coros)
+
+        coros_by_modality = {
+            modality: [item or asyncio.sleep(0) for item in items]
+            for modality, items in self._items_by_modality.items()
+        }
+        items_by_modality: dict[str, list[object | None]] = {
+            modality: await asyncio.gather(*coros)
+            for modality, coros in coros_by_modality.items()
+        }
         if "image" in items_by_modality and "image_embeds" in items_by_modality:
             raise ValueError("Mixing raw image and embedding inputs is not allowed")
         if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
             raise ValueError("Mixing raw audio and embedding inputs is not allowed")
+
+        mm_inputs = {}
         if "image_embeds" in items_by_modality:
-            image_embeds_lst = items_by_modality["image_embeds"]
-            mm_inputs["image"] = (
-                image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
-            )
+            mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio_embeds" in items_by_modality:
-            audio_embeds_lst = items_by_modality["audio_embeds"]
-            mm_inputs["audio"] = (
-                audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
-            )
+            mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
         if "audio" in items_by_modality:
             mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
         if "video" in items_by_modality:
             mm_inputs["video"] = items_by_modality["video"]  # A list of videos
+
         return mm_inputs

     def create_parser(self) -> "BaseMultiModalContentParser":

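The rewritten async path above swaps the accumulating loop for two dict comprehensions: None placeholders (items supplied by UUID alone, with no coroutine to await) become asyncio.sleep(0), which gathers back to None. A self-contained sketch of that pattern:

import asyncio


async def demo() -> None:
    async def fetch(x: int) -> int:
        return x * 2

    # None marks an item provided only by UUID, with nothing to await.
    items = {"image": [fetch(1), None, fetch(3)]}

    coros = {m: [c or asyncio.sleep(0) for c in cs] for m, cs in items.items()}
    resolved = {m: await asyncio.gather(*cs) for m, cs in coros.items()}

    # asyncio.sleep(0) resolves to None, preserving the placeholder slots.
    assert resolved == {"image": [2, None, 6]}


asyncio.run(demo())
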
@@ -188,29 +188,39 @@ class InputProcessor:
         def _validate_single_prompt(single_prompt: dict | str) -> None:
             if not isinstance(single_prompt, dict):
                 return

             mm_data = single_prompt.get("multi_modal_data")
             mm_uuids = single_prompt.get("multi_modal_uuids")
             if not mm_data or not mm_uuids:
                 return
+
+            import torch
+
+            def _get_len(items: object):
+                if isinstance(items, dict):  # Embedding inputs
+                    return _get_len(next(iter(items.values()))) if items else 1
+
+                if isinstance(items, list):
+                    return len(items)
+                if isinstance(items, torch.Tensor):
+                    # To keep backwards compatibility for single item embedding input
+                    return 1 if getattr(items, "_is_single_item", False) else len(items)
+
+                return 1
+
             for modality, items in mm_data.items():
                 if modality in mm_uuids:
-                    data_len = len(items) if isinstance(items, list) else 1
-                    uuid_len = (
-                        len(mm_uuids[modality])
-                        if isinstance(mm_uuids[modality], list)
-                        else 1
-                    )
+                    data_len = _get_len(items)
+                    uuid_len = _get_len(mm_uuids[modality])
                     if uuid_len != data_len:
                         raise ValueError(
-                            f"multi_modal_uuids for modality '{modality}' "
+                            f"multi_modal_uuids for modality {modality!r} "
                             "must have same length as data: got "
-                            f"{uuid_len} uuids vs "
-                            f"{data_len} items."
+                            f"{uuid_len} uuids vs {data_len} items."
                         )
                 else:
                     raise ValueError(
-                        f"multi_modal_uuids for modality '{modality}' must "
+                        f"multi_modal_uuids for modality {modality!r} must "
                         "be provided if multi_modal_data is provided."
                     )

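How the nested _get_len helper above sizes each kind of input; the helper body is copied out to module scope so the snippet runs standalone:

import torch


def _get_len(items: object):
    if isinstance(items, dict):  # Embedding inputs
        return _get_len(next(iter(items.values()))) if items else 1
    if isinstance(items, list):
        return len(items)
    if isinstance(items, torch.Tensor):
        # Single-item embeddings are tagged by _extract_embeds in chat_utils
        return 1 if getattr(items, "_is_single_item", False) else len(items)
    return 1


batched = torch.randn(2, 256, 1024)  # batch of two embeddings
single = torch.randn(256, 1024)
single._is_single_item = True        # tag set by _extract_embeds for lone items

assert _get_len(["a.jpg", "b.jpg"]) == 2   # list -> element count
assert _get_len(batched) == 2              # tensor -> leading (batch) dim
assert _get_len(single) == 1               # tagged single item -> 1
assert _get_len({"k": batched}) == 2       # dict -> length of first value
assert _get_len({}) == 1                   # empty dict -> 1
assert _get_len("raw data") == 1           # anything else -> 1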