From dce6d229f7d405a4757aa8f0a76ba62f0e39eaa4 Mon Sep 17 00:00:00 2001
From: jeremyteboul <80506730+jeremyteboul@users.noreply.github.com>
Date: Sat, 6 Dec 2025 20:34:24 -0800
Subject: [PATCH] Support multiple image/audio embeddings per request (#29988)

Signed-off-by: Jeremy Teboul
Co-authored-by: Jeremy Teboul
---
 docs/features/multimodal_inputs.md   |  10 +-
 tests/entrypoints/test_chat_utils.py | 178 +++++++++++++++++++++++++++
 vllm/entrypoints/chat_utils.py       |  30 ++---
 3 files changed, 198 insertions(+), 20 deletions(-)

diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md
index 0adb32a7ac33c..c3fd726e9938c 100644
--- a/docs/features/multimodal_inputs.md
+++ b/docs/features/multimodal_inputs.md
@@ -445,7 +445,7 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
 
 For Qwen3-VL, the `image_embeds` should contain both the base image embedding and deepstack features.
 
-#### Audio Embeddings
+#### Audio Embedding Inputs
 
 You can pass pre-computed audio embeddings similar to image embeddings:
 
@@ -892,5 +892,11 @@ For Online Serving, you can also skip sending media if you expect cache hits wit
 ```
 
 !!! note
-    Only one message can contain `{"type": "image_embeds"}`.
+    A request may include multiple `{"type": "image_embeds"}` items, in the same message or across messages, so you can pass multiple image embeddings per request just like regular images; see the example below. The number of embeddings is limited by `--limit-mm-per-prompt`.
+
+    **Important**: The expected embedding shape differs based on the number of embeddings:
+
+    - **Single embedding**: a 3D tensor of shape `(1, feature_size, hidden_size)`
+    - **Multiple embeddings**: a list of 2D tensors, each of shape `(feature_size, hidden_size)`
 
     If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
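+
+    For example, two image embeddings can be sent in one request as base64-encoded tensors. This is a minimal sketch that assumes a running vLLM server whose model accepts `image_embeds` inputs; the model name, tensor shapes, and encoding helper are illustrative (the encoding follows the image embedding examples earlier in this document):
+
+    ```python
+    import base64
+    import io
+
+    import torch
+    from openai import OpenAI
+
+    def tensor_to_base64(tensor: torch.Tensor) -> str:
+        # Serialize the tensor with torch.save, then base64-encode the bytes.
+        buffer = io.BytesIO()
+        torch.save(tensor, buffer)
+        return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+    # With multiple embeddings, each one is a 2D (feature_size, hidden_size) tensor.
+    embeds = [torch.randn(256, 1024), torch.randn(128, 1024)]
+
+    chat_completion = client.chat.completions.create(
+        model="<your-model>",
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": tensor_to_base64(embeds[0]),
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": tensor_to_base64(embeds[1]),
+                    },
+                    {"type": "text", "text": "Describe these two images."},
+                ],
+            }
+        ],
+    )
+    print(chat_completion.choices[0].message.content)
+    ```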
diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index 75be34820bcd7..527322c71ae4b 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -6,6 +6,7 @@ from collections.abc import Mapping
 from typing import Literal
 
 import pytest
+import torch
 from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
 
 from vllm.assets.audio import AudioAsset
@@ -915,6 +916,183 @@ async def test_parse_chat_messages_audio_embeds_async(
     _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
 
 
+def test_parse_chat_messages_multiple_image_embeds(
+    phi3v_model_config_image_embeds,
+):
+    """Test that multiple image_embeds in a single message are now supported.
+
+    This test validates the fix for the limitation that previously only allowed
+    one message with {'type': 'image_embeds'}. Now multiple image embeddings
+    can be provided in a single request, similar to regular images.
+    """
+    # Create two sample image embedding tensors
+    image_embedding_1 = torch.randn(256, 1024)
+    image_embedding_2 = torch.randn(128, 1024)
+
+    # Encode them as base64 using the convenience function
+    base64_image_embedding_1 = tensor2base64(image_embedding_1)
+    base64_image_embedding_2 = tensor2base64(image_embedding_2)
+
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_1,
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_2,
+                    },
+                    {"type": "text", "text": "Describe these two images."},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
+        }
+    ]
+
+    # Verify mm_data contains a list of embeddings (not a single embedding)
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 2
+
+    # Verify each embedding has the correct shape
+    assert isinstance(mm_data["image"][0], torch.Tensor)
+    assert mm_data["image"][0].shape == image_embedding_1.shape
+    assert isinstance(mm_data["image"][1], torch.Tensor)
+    assert mm_data["image"][1].shape == image_embedding_2.shape
+
+    # Verify UUIDs (None since we didn't provide any)
+    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
+
+
+def test_parse_chat_messages_multiple_image_embeds_with_uuids(
+    phi3v_model_config_image_embeds,
+):
+    """Test multiple image_embeds with UUIDs.
+
+    This validates that UUIDs are properly tracked for multiple embeddings.
+    """
+    uuid1 = "image-uuid-1"
+    uuid2 = "image-uuid-2"
+
+    conversation, mm_data, mm_uuids = parse_chat_messages(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": None,
+                        "uuid": uuid1,
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": None,
+                        "uuid": uuid2,
+                    },
+                    {"type": "text", "text": "Compare these images."},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nCompare these images.",
+        }
+    ]
+
+    # Verify mm_data contains a list with None values (UUID references)
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 2
+    assert mm_data["image"][0] is None
+    assert mm_data["image"][1] is None
+
+    # Verify UUIDs are correctly tracked
+    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[uuid1, uuid2])
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_multiple_image_embeds_async(
+    phi3v_model_config_image_embeds,
+):
+    """Test multiple image_embeds with async parsing.
+
+    This validates that AsyncMultiModalItemTracker also supports
+    multiple embeddings.
+    """
+    # Create two sample image embedding tensors
+    image_embedding_1 = torch.randn(200, 768)
+    image_embedding_2 = torch.randn(150, 768)
+
+    # Encode them as base64 using the convenience function
+    base64_image_embedding_1 = tensor2base64(image_embedding_1)
+    base64_image_embedding_2 = tensor2base64(image_embedding_2)
+
+    conversation, mm_future, mm_uuids = parse_chat_messages_futures(
+        [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_1,
+                    },
+                    {
+                        "type": "image_embeds",
+                        "image_embeds": base64_image_embedding_2,
+                    },
+                    {"type": "text", "text": "What do these images show?"},
+                ],
+            }
+        ],
+        phi3v_model_config_image_embeds,
+        content_format="string",
+    )
+
+    # Verify conversation structure
+    assert conversation == [
+        {
+            "role": "user",
+            "content": "<|image_1|>\n<|image_2|>\nWhat do these images show?",
+        }
+    ]
+
+    # Await the future and verify mm_data
+    mm_data = await mm_future
+    assert mm_data is not None
+    assert "image" in mm_data
+    assert isinstance(mm_data["image"], list)
+    assert len(mm_data["image"]) == 2
+
+    # Verify each embedding has the correct shape
+    assert isinstance(mm_data["image"][0], torch.Tensor)
+    assert mm_data["image"][0].shape == image_embedding_1.shape
+    assert isinstance(mm_data["image"][1], torch.Tensor)
+    assert mm_data["image"][1].shape == image_embedding_2.shape
+
+    # Verify UUIDs
+    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
+
+
 @pytest.mark.asyncio
 async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
     phi3v_model_config_image_embeds,
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 077fe681bc5b8..aceaa8bd45b81 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -694,16 +694,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
             raise ValueError("Mixing raw image and embedding inputs is not allowed")
 
         if "image_embeds" in uuids_by_modality:
-            image_embeds_uuids = uuids_by_modality["image_embeds"]
-            if len(image_embeds_uuids) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
             mm_uuids["image"] = uuids_by_modality["image_embeds"]
         if "image" in uuids_by_modality:
             mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
         if "audio_embeds" in uuids_by_modality:
-            audio_embeds_uuids = uuids_by_modality["audio_embeds"]
-            if len(audio_embeds_uuids) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
             mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
         if "audio" in uuids_by_modality:
             mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
@@ -729,16 +723,16 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
 
         if "image_embeds" in items_by_modality:
             image_embeds_lst = items_by_modality["image_embeds"]
-            if len(image_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
-            mm_inputs["image"] = image_embeds_lst[0]
+            # A single embedding is unwrapped to preserve the legacy
+            # single-tensor input; multiple embeddings are passed through
+            # as a list of tensors.
+            mm_inputs["image"] = (
+                image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
+            )
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio_embeds" in items_by_modality:
             audio_embeds_lst = items_by_modality["audio_embeds"]
-            if len(audio_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
-            mm_inputs["audio"] = audio_embeds_lst[0]
+            mm_inputs["audio"] = (
+                audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
+            )
         if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: @@ -771,16 +765,16 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] - if len(image_embeds_lst) > 1: - raise ValueError("Only one message can have {'type': 'image_embeds'}") - mm_inputs["image"] = image_embeds_lst[0] + mm_inputs["image"] = ( + image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0] + ) if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio_embeds" in items_by_modality: audio_embeds_lst = items_by_modality["audio_embeds"] - if len(audio_embeds_lst) > 1: - raise ValueError("Only one message can have {'type': 'audio_embeds'}") - mm_inputs["audio"] = audio_embeds_lst[0] + mm_inputs["audio"] = ( + audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0] + ) if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: