diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 527322c71ae4b..40059c9041541 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -796,9 +796,13 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( "content": "<|image_1|>\nWhat's in this image?", } ] + assert mm_data is not None assert "image" in mm_data - assert mm_data["image"] is None + assert isinstance(mm_data["image"], list) + assert len(mm_data["image"]) == 1 + assert mm_data["image"][0] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) @@ -825,10 +829,11 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid( # Should have audio in mm_data as None (UUID provided) assert mm_data is not None assert "audio" in mm_data - assert mm_data["audio"] is None + assert isinstance(mm_data["audio"], list) + assert len(mm_data["audio"]) == 1 + assert mm_data["audio"][0] is None + # UUID should be recorded - assert mm_uuids is not None - assert "audio" in mm_uuids _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid]) @@ -1121,10 +1126,105 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( mm_data = await mm_future assert mm_data is not None assert "image" in mm_data - assert mm_data["image"] is None + assert isinstance(mm_data["image"], list) + assert len(mm_data["image"]) == 1 + assert mm_data["image"][0] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) +def test_parse_chat_messages_empty_dict_image_embeds( + phi3v_model_config_image_embeds, +): + """Test that empty dictionary for image_embeds is handled without errors.""" + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_embeds", "image_embeds": {}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + phi3v_model_config_image_embeds, + content_format="string", + ) + + # Verify conversation structure + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + } + ] + + # Verify mm_data contains an empty dictionary of embeddings + assert mm_data is not None + assert "image" in mm_data + assert isinstance(mm_data["image"], dict) + assert len(mm_data["image"]) == 0 + + # Verify UUIDs (None since we didn't provide any) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_dict_image_embeds( + phi3v_model_config_image_embeds, +): + """Test that multiple dictionaries for image_embeds is handled without errors.""" + # Create two sample image embedding tensors + batch_size = 2 + image_embedding_1 = torch.randn(batch_size, 256, 1024) + image_embedding_2 = torch.randn(batch_size, 3) + + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": { + "image_embedding_1": tensor2base64(p), + "image_embedding_2": tensor2base64(i), + }, + } + for p, i in zip(image_embedding_1, image_embedding_2) + ] + + [ + {"type": "text", "text": "Describe these two images."}, + ], + } + ], + phi3v_model_config_image_embeds, + content_format="string", + ) + + # Verify conversation structure + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.", + } + ] + + # Verify mm_data contains a dictionary of multi-embeddings + assert mm_data is not None + assert "image" in mm_data + assert isinstance(mm_data["image"], dict) + assert len(mm_data["image"]) == batch_size + + # Verify each embedding has the correct shape + assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor) + assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape + assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor) + assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape + + # Verify UUIDs (None since we didn't provide any) + _assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None]) + + @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_async( phi3v_model_config, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index aceaa8bd45b81..5a15dec6f84c1 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque from collections.abc import Awaitable, Callable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast import jinja2 import jinja2.ext @@ -53,7 +53,14 @@ from vllm.tokenizers import MistralTokenizer, TokenizerLike from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import random_uuid +from vllm.utils.collection_utils import is_list_of from vllm.utils.func_utils import supports_kw +from vllm.utils.import_utils import LazyLoader + +if TYPE_CHECKING: + import torch +else: + torch = LazyLoader("torch", globals(), "torch") logger = init_logger(__name__) @@ -620,6 +627,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"] _T = TypeVar("_T") +def _extract_embeds(tensors: list[torch.Tensor]): + if len(tensors) == 0: + return tensors + + if len(tensors) == 1: + tensors[0]._is_single_item = True # type: ignore + return tensors[0] # To keep backwards compatibility for single item input + + first_shape = tensors[0].shape + if all(t.shape == first_shape for t in tensors): + return torch.stack(tensors) + + return tensors + + +def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str): + embeds_key = f"{modality}_embeds" + embeds = items_by_modality[embeds_key] + + if len(embeds) == 0: + return embeds + if is_list_of(embeds, torch.Tensor): + return _extract_embeds(embeds) + if is_list_of(embeds, dict): + if not embeds: + return {} + + first_keys = set(embeds[0].keys()) + if any(set(item.keys()) != first_keys for item in embeds[1:]): + raise ValueError( + "All dictionaries in the list of embeddings must have the same keys." + ) + + return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys} + + return embeds + + class BaseMultiModalItemTracker(ABC, Generic[_T]): """ Tracks multi-modal items in a given request and ensures that the number @@ -688,11 +733,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def all_mm_uuids(self) -> MultiModalUUIDDict | None: if not self._items_by_modality: return None - mm_uuids = {} + uuids_by_modality = dict(self._uuids_by_modality) if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") + if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality: + raise ValueError("Mixing raw audio and embedding inputs is not allowed") + mm_uuids = {} if "image_embeds" in uuids_by_modality: mm_uuids["image"] = uuids_by_modality["image_embeds"] if "image" in uuids_by_modality: @@ -703,6 +751,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios if "video" in uuids_by_modality: mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos + return mm_uuids @abstractmethod @@ -714,29 +763,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): def all_mm_data(self) -> MultiModalDataDict | None: if not self._items_by_modality: return None - mm_inputs = {} + items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") if "audio" in items_by_modality and "audio_embeds" in items_by_modality: raise ValueError("Mixing raw audio and embedding inputs is not allowed") + mm_inputs = {} if "image_embeds" in items_by_modality: - image_embeds_lst = items_by_modality["image_embeds"] - mm_inputs["image"] = ( - image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0] - ) + mm_inputs["image"] = _get_embeds_data(items_by_modality, "image") if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio_embeds" in items_by_modality: - audio_embeds_lst = items_by_modality["audio_embeds"] - mm_inputs["audio"] = ( - audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0] - ) + mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio") if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -747,38 +792,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): async def all_mm_data(self) -> MultiModalDataDict | None: if not self._items_by_modality: return None - mm_inputs = {} - items_by_modality = {} - for modality, items in self._items_by_modality.items(): - coros = [] - for item in items: - if item is not None: - coros.append(item) - else: - coros.append(asyncio.sleep(0)) - items_by_modality[modality] = await asyncio.gather(*coros) + coros_by_modality = { + modality: [item or asyncio.sleep(0) for item in items] + for modality, items in self._items_by_modality.items() + } + items_by_modality: dict[str, list[object | None]] = { + modality: await asyncio.gather(*coros) + for modality, coros in coros_by_modality.items() + } if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") if "audio" in items_by_modality and "audio_embeds" in items_by_modality: raise ValueError("Mixing raw audio and embedding inputs is not allowed") + mm_inputs = {} if "image_embeds" in items_by_modality: - image_embeds_lst = items_by_modality["image_embeds"] - mm_inputs["image"] = ( - image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0] - ) + mm_inputs["image"] = _get_embeds_data(items_by_modality, "image") if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio_embeds" in items_by_modality: - audio_embeds_lst = items_by_modality["audio_embeds"] - mm_inputs["audio"] = ( - audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0] - ) + mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio") if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index a3c18464d3f52..5bd18cc064cb5 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -188,29 +188,39 @@ class InputProcessor: def _validate_single_prompt(single_prompt: dict | str) -> None: if not isinstance(single_prompt, dict): return + mm_data = single_prompt.get("multi_modal_data") mm_uuids = single_prompt.get("multi_modal_uuids") if not mm_data or not mm_uuids: return + import torch + + def _get_len(items: object): + if isinstance(items, dict): # Embedding inputs + return _get_len(next(iter(items.values()))) if items else 1 + + if isinstance(items, list): + return len(items) + if isinstance(items, torch.Tensor): + # To keep backwards compatibility for single item embedding input + return 1 if getattr(items, "_is_single_item", False) else len(items) + + return 1 + for modality, items in mm_data.items(): if modality in mm_uuids: - data_len = len(items) if isinstance(items, list) else 1 - uuid_len = ( - len(mm_uuids[modality]) - if isinstance(mm_uuids[modality], list) - else 1 - ) + data_len = _get_len(items) + uuid_len = _get_len(mm_uuids[modality]) if uuid_len != data_len: raise ValueError( - f"multi_modal_uuids for modality '{modality}' " + f"multi_modal_uuids for modality {modality!r} " "must have same length as data: got " - f"{uuid_len} uuids vs " - f"{data_len} items." + f"{uuid_len} uuids vs {data_len} items." ) else: raise ValueError( - f"multi_modal_uuids for modality '{modality}' must " + f"multi_modal_uuids for modality {modality!r} must " "be provided if multi_modal_data is provided." )