[Bugfix] Dictionary MM embeddings for online chat (#30507)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-12-13 15:48:56 +08:00 committed by GitHub
parent fdc135d768
commit b09806e28f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 193 additions and 44 deletions

View File

@ -796,9 +796,13 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
"content": "<|image_1|>\nWhat's in this image?", "content": "<|image_1|>\nWhat's in this image?",
} }
] ]
assert mm_data is not None assert mm_data is not None
assert "image" in mm_data assert "image" in mm_data
assert mm_data["image"] is None assert isinstance(mm_data["image"], list)
assert len(mm_data["image"]) == 1
assert mm_data["image"][0] is None
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
@ -825,10 +829,11 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
# Should have audio in mm_data as None (UUID provided) # Should have audio in mm_data as None (UUID provided)
assert mm_data is not None assert mm_data is not None
assert "audio" in mm_data assert "audio" in mm_data
assert mm_data["audio"] is None assert isinstance(mm_data["audio"], list)
assert len(mm_data["audio"]) == 1
assert mm_data["audio"][0] is None
# UUID should be recorded # UUID should be recorded
assert mm_uuids is not None
assert "audio" in mm_uuids
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid]) _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])
@ -1121,10 +1126,105 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
mm_data = await mm_future mm_data = await mm_future
assert mm_data is not None assert mm_data is not None
assert "image" in mm_data assert "image" in mm_data
assert mm_data["image"] is None assert isinstance(mm_data["image"], list)
assert len(mm_data["image"]) == 1
assert mm_data["image"][0] is None
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
def test_parse_chat_messages_empty_dict_image_embeds(
    phi3v_model_config_image_embeds,
):
    """An empty ``image_embeds`` dictionary must be parsed without raising."""
    user_message = {
        "role": "user",
        "content": [
            {"type": "image_embeds", "image_embeds": {}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }

    conversation, mm_data, mm_uuids = parse_chat_messages(
        [user_message],
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # The embeds part becomes an image placeholder ahead of the text.
    expected_conversation = [
        {"role": "user", "content": "<|image_1|>\nWhat's in this image?"},
    ]
    assert conversation == expected_conversation

    # The image entry is a dictionary of embeddings with no keys in it.
    assert mm_data is not None
    assert "image" in mm_data
    image_data = mm_data["image"]
    assert isinstance(image_data, dict)
    assert len(image_data) == 0

    # No UUID was supplied, so the single item is tracked with a None UUID.
    _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None])
def test_parse_chat_messages_multiple_dict_image_embeds(
    phi3v_model_config_image_embeds,
):
    """Test that multiple dictionaries for image_embeds are handled without errors."""
    # Create two sample image embedding tensors. Iterating over them below
    # yields one slice per item, so each content part carries one slice of
    # each tensor.
    batch_size = 2
    image_embedding_1 = torch.randn(batch_size, 256, 1024)
    image_embedding_2 = torch.randn(batch_size, 3)

    conversation, mm_data, mm_uuids = parse_chat_messages(
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_embeds",
                        "image_embeds": {
                            "image_embedding_1": tensor2base64(p),
                            "image_embedding_2": tensor2base64(i),
                        },
                    }
                    for p, i in zip(image_embedding_1, image_embedding_2)
                ]
                + [
                    {"type": "text", "text": "Describe these two images."},
                ],
            }
        ],
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # Verify conversation structure
    assert conversation == [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
        }
    ]

    # Verify mm_data contains a dictionary of multi-embeddings
    assert mm_data is not None
    assert "image" in mm_data
    assert isinstance(mm_data["image"], dict)
    # The dictionary is keyed by embedding name, so its size is the number of
    # embedding names rather than the number of items. Check the keys
    # explicitly instead of comparing the length against batch_size.
    assert set(mm_data["image"]) == {"image_embedding_1", "image_embedding_2"}

    # Verify each embedding has the correct shape (the per-item slices are
    # expected to be recombined into the original batched shape)
    assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
    assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
    assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
    assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape

    # Verify UUIDs (None since we didn't provide any)
    _assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None] * batch_size)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_async( async def test_parse_chat_messages_multiple_images_async(
phi3v_model_config, phi3v_model_config,

View File

@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
from collections.abc import Awaitable, Callable, Iterable from collections.abc import Awaitable, Callable, Iterable
from functools import cached_property, lru_cache, partial from functools import cached_property, lru_cache, partial
from pathlib import Path from pathlib import Path
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
import jinja2 import jinja2
import jinja2.ext import jinja2.ext
@ -53,7 +53,14 @@ from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.collection_utils import is_list_of
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
import torch
else:
torch = LazyLoader("torch", globals(), "torch")
logger = init_logger(__name__) logger = init_logger(__name__)
@ -620,6 +627,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
_T = TypeVar("_T") _T = TypeVar("_T")
def _extract_embeds(tensors: list[torch.Tensor]):
if len(tensors) == 0:
return tensors
if len(tensors) == 1:
tensors[0]._is_single_item = True # type: ignore
return tensors[0] # To keep backwards compatibility for single item input
first_shape = tensors[0].shape
if all(t.shape == first_shape for t in tensors):
return torch.stack(tensors)
return tensors
def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
    """Build the embedding payload for ``modality`` from the tracked items.

    Args:
        items_by_modality: Tracked items keyed by modality string. The entry
            ``f"{modality}_embeds"`` is read and must be present.
        modality: Base modality name, e.g. ``"image"`` or ``"audio"``.

    Returns:
        - the (empty) list itself when there are no embeds;
        - the result of ``_extract_embeds`` for a list of tensors;
        - for a list of dictionaries, one dictionary with the same keys, in
          the key order of the first dictionary, where each value is the
          result of ``_extract_embeds`` over that key's tensors;
        - the list itself for anything else.

    Raises:
        ValueError: If the dictionaries do not all share the same key set.
    """
    embeds = items_by_modality[f"{modality}_embeds"]

    if len(embeds) == 0:
        return embeds

    # NOTE(review): `is_list_of` is a project helper; presumably it checks the
    # element type of the list -- confirm against vllm.utils.collection_utils.
    if is_list_of(embeds, torch.Tensor):
        return _extract_embeds(embeds)

    if is_list_of(embeds, dict):
        # `embeds` is known to be non-empty here, so indexing [0] is safe.
        first = embeds[0]
        first_keys = set(first.keys())

        if any(set(item.keys()) != first_keys for item in embeds[1:]):
            raise ValueError(
                "All dictionaries in the list of embeddings must have the same keys."
            )

        # Iterate the first dictionary rather than the key *set*: set order
        # for strings varies between runs, while dict iteration follows
        # insertion order and so keeps the output key order deterministic.
        return {k: _extract_embeds([item[k] for item in embeds]) for k in first}

    # Anything else (e.g. entries that are neither tensors nor dictionaries)
    # is passed through untouched.
    return embeds
class BaseMultiModalItemTracker(ABC, Generic[_T]): class BaseMultiModalItemTracker(ABC, Generic[_T]):
""" """
Tracks multi-modal items in a given request and ensures that the number Tracks multi-modal items in a given request and ensures that the number
@ -688,11 +733,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def all_mm_uuids(self) -> MultiModalUUIDDict | None: def all_mm_uuids(self) -> MultiModalUUIDDict | None:
if not self._items_by_modality: if not self._items_by_modality:
return None return None
mm_uuids = {}
uuids_by_modality = dict(self._uuids_by_modality) uuids_by_modality = dict(self._uuids_by_modality)
if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed") raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
mm_uuids = {}
if "image_embeds" in uuids_by_modality: if "image_embeds" in uuids_by_modality:
mm_uuids["image"] = uuids_by_modality["image_embeds"] mm_uuids["image"] = uuids_by_modality["image_embeds"]
if "image" in uuids_by_modality: if "image" in uuids_by_modality:
@ -703,6 +751,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
if "video" in uuids_by_modality: if "video" in uuids_by_modality:
mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos
return mm_uuids return mm_uuids
@abstractmethod @abstractmethod
@ -714,29 +763,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> MultiModalDataDict | None: def all_mm_data(self) -> MultiModalDataDict | None:
if not self._items_by_modality: if not self._items_by_modality:
return None return None
mm_inputs = {}
items_by_modality = dict(self._items_by_modality) items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality: if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed") raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in items_by_modality and "audio_embeds" in items_by_modality: if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed") raise ValueError("Mixing raw audio and embedding inputs is not allowed")
mm_inputs = {}
if "image_embeds" in items_by_modality: if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"] mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
mm_inputs["image"] = (
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
)
if "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality: if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"] mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
mm_inputs["audio"] = (
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
)
if "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality: if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
@ -747,38 +792,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> MultiModalDataDict | None: async def all_mm_data(self) -> MultiModalDataDict | None:
if not self._items_by_modality: if not self._items_by_modality:
return None return None
mm_inputs = {}
items_by_modality = {}
for modality, items in self._items_by_modality.items():
coros = []
for item in items:
if item is not None:
coros.append(item)
else:
coros.append(asyncio.sleep(0))
items_by_modality[modality] = await asyncio.gather(*coros)
coros_by_modality = {
modality: [item or asyncio.sleep(0) for item in items]
for modality, items in self._items_by_modality.items()
}
items_by_modality: dict[str, list[object | None]] = {
modality: await asyncio.gather(*coros)
for modality, coros in coros_by_modality.items()
}
if "image" in items_by_modality and "image_embeds" in items_by_modality: if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed") raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in items_by_modality and "audio_embeds" in items_by_modality: if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed") raise ValueError("Mixing raw audio and embedding inputs is not allowed")
mm_inputs = {}
if "image_embeds" in items_by_modality: if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"] mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
mm_inputs["image"] = (
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
)
if "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality: if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"] mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
mm_inputs["audio"] = (
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
)
if "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality: if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":

View File

@ -188,29 +188,39 @@ class InputProcessor:
def _validate_single_prompt(single_prompt: dict | str) -> None: def _validate_single_prompt(single_prompt: dict | str) -> None:
if not isinstance(single_prompt, dict): if not isinstance(single_prompt, dict):
return return
mm_data = single_prompt.get("multi_modal_data") mm_data = single_prompt.get("multi_modal_data")
mm_uuids = single_prompt.get("multi_modal_uuids") mm_uuids = single_prompt.get("multi_modal_uuids")
if not mm_data or not mm_uuids: if not mm_data or not mm_uuids:
return return
import torch
def _get_len(items: object):
if isinstance(items, dict): # Embedding inputs
return _get_len(next(iter(items.values()))) if items else 1
if isinstance(items, list):
return len(items)
if isinstance(items, torch.Tensor):
# To keep backwards compatibility for single item embedding input
return 1 if getattr(items, "_is_single_item", False) else len(items)
return 1
for modality, items in mm_data.items(): for modality, items in mm_data.items():
if modality in mm_uuids: if modality in mm_uuids:
data_len = len(items) if isinstance(items, list) else 1 data_len = _get_len(items)
uuid_len = ( uuid_len = _get_len(mm_uuids[modality])
len(mm_uuids[modality])
if isinstance(mm_uuids[modality], list)
else 1
)
if uuid_len != data_len: if uuid_len != data_len:
raise ValueError( raise ValueError(
f"multi_modal_uuids for modality '{modality}' " f"multi_modal_uuids for modality {modality!r} "
"must have same length as data: got " "must have same length as data: got "
f"{uuid_len} uuids vs " f"{uuid_len} uuids vs {data_len} items."
f"{data_len} items."
) )
else: else:
raise ValueError( raise ValueError(
f"multi_modal_uuids for modality '{modality}' must " f"multi_modal_uuids for modality {modality!r} must "
"be provided if multi_modal_data is provided." "be provided if multi_modal_data is provided."
) )