Support multiple image/audio embeddings per request (#29988)

Signed-off-by: Jeremy Teboul <jeremyteboul@fb.com>
Co-authored-by: Jeremy Teboul <jeremyteboul@fb.com>
jeremyteboul 2025-12-06 20:34:24 -08:00 committed by GitHub
parent cbedb703cc
commit dce6d229f7
3 changed files with 198 additions and 20 deletions

@@ -445,7 +445,7 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings
For Qwen3-VL, the `image_embeds` should contain both the base image embedding and deepstack features.
-#### Audio Embeddings
+#### Audio Embedding Inputs
You can pass pre-computed audio embeddings similar to image embeddings:
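
As a hedged illustration (not part of this diff; the tensor shape and prompt are placeholders), an `audio_embeds` content part mirrors the image form:

```python
import base64
import io

import torch

# Illustrative pre-computed audio embedding of shape
# (num_audio_features, hidden_size); real sizes depend on the model.
audio_embedding = torch.randn(300, 1024)

# Serialize with torch.save, then base64-encode the raw bytes.
buf = io.BytesIO()
torch.save(audio_embedding, buf)
base64_audio_embedding = base64.b64encode(buf.getvalue()).decode("utf-8")

message = {
    "role": "user",
    "content": [
        {"type": "audio_embeds", "audio_embeds": base64_audio_embedding},
        {"type": "text", "text": "Transcribe this audio clip."},
    ],
}
```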
@@ -892,5 +892,11 @@ For Online Serving, you can also skip sending media if you expect cache hits with
```
!!! note
-    Only one message can contain `{"type": "image_embeds"}`.
+    Multiple messages can now contain `{"type": "image_embeds"}`, enabling you to pass multiple image embeddings in a single request (similar to regular images). The number of embeddings is limited by `--limit-mm-per-prompt`.
+
+    **Important**: The embedding shape format differs based on the number of embeddings:
+
+    - **Single embedding**: 3D tensor of shape `(1, feature_size, hidden_size)`
+    - **Multiple embeddings**: list of 2D tensors, each of shape `(feature_size, hidden_size)`
+
+    If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
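
Below is a hedged end-to-end sketch of the new multi-embedding path over the OpenAI-compatible API; the server URL, model name, tensor sizes, and the `tensor_to_base64` helper are placeholders rather than anything defined in this diff:

```python
import base64
import io

import torch
from openai import OpenAI

# Assumes a running vLLM server for a model that accepts image embeddings,
# started with something like --limit-mm-per-prompt '{"image": 2}'.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


def tensor_to_base64(t: torch.Tensor) -> str:
    # Serialize the tensor with torch.save and base64-encode the bytes.
    buf = io.BytesIO()
    torch.save(t, buf)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# Two 2D embeddings of shape (feature_size, hidden_size); the sizes are
# placeholders for whatever your vision encoder actually produces.
emb_1 = tensor_to_base64(torch.randn(256, 1024))
emb_2 = tensor_to_base64(torch.randn(128, 1024))

chat_completion = client.chat.completions.create(
    model="microsoft/Phi-3.5-vision-instruct",  # example model only
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_embeds", "image_embeds": emb_1},
                {"type": "image_embeds", "image_embeds": emb_2},
                {"type": "text", "text": "Describe these two images."},
            ],
        }
    ],
)
print(chat_completion.choices[0].message.content)
```

When only one `image_embeds` part is sent, the request keeps the legacy single-tensor shape described above, so existing single-embedding clients are unaffected.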

@@ -6,6 +6,7 @@ from collections.abc import Mapping
from typing import Literal
import pytest
import torch
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.assets.audio import AudioAsset
@@ -915,6 +916,183 @@ async def test_parse_chat_messages_audio_embeds_async(
    _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])


def test_parse_chat_messages_multiple_image_embeds(
    phi3v_model_config_image_embeds,
):
    """Test that multiple image_embeds in a single message are now supported.

    This test validates the fix for the limitation that previously only allowed
    one message with {'type': 'image_embeds'}. Now multiple image embeddings
    can be provided in a single request, similar to regular images.
    """
    # Create two sample image embedding tensors
    image_embedding_1 = torch.randn(256, 1024)
    image_embedding_2 = torch.randn(128, 1024)

    # Encode them as base64 using the convenience function
    base64_image_embedding_1 = tensor2base64(image_embedding_1)
    base64_image_embedding_2 = tensor2base64(image_embedding_2)

    conversation, mm_data, mm_uuids = parse_chat_messages(
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_embeds",
                        "image_embeds": base64_image_embedding_1,
                    },
                    {
                        "type": "image_embeds",
                        "image_embeds": base64_image_embedding_2,
                    },
                    {"type": "text", "text": "Describe these two images."},
                ],
            }
        ],
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # Verify conversation structure
    assert conversation == [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
        }
    ]

    # Verify mm_data contains a list of embeddings (not a single embedding)
    assert mm_data is not None
    assert "image" in mm_data
    assert isinstance(mm_data["image"], list)
    assert len(mm_data["image"]) == 2

    # Verify each embedding has the correct shape
    assert isinstance(mm_data["image"][0], torch.Tensor)
    assert mm_data["image"][0].shape == image_embedding_1.shape
    assert isinstance(mm_data["image"][1], torch.Tensor)
    assert mm_data["image"][1].shape == image_embedding_2.shape

    # Verify UUIDs (None since we didn't provide any)
    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])


def test_parse_chat_messages_multiple_image_embeds_with_uuids(
    phi3v_model_config_image_embeds,
):
    """Test multiple image_embeds with UUIDs.

    This validates that UUIDs are properly tracked for multiple embeddings.
    """
    uuid1 = "image-uuid-1"
    uuid2 = "image-uuid-2"

    conversation, mm_data, mm_uuids = parse_chat_messages(
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_embeds",
                        "image_embeds": None,
                        "uuid": uuid1,
                    },
                    {
                        "type": "image_embeds",
                        "image_embeds": None,
                        "uuid": uuid2,
                    },
                    {"type": "text", "text": "Compare these images."},
                ],
            }
        ],
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # Verify conversation structure
    assert conversation == [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nCompare these images.",
        }
    ]

    # Verify mm_data contains a list with None values (UUID references)
    assert mm_data is not None
    assert "image" in mm_data
    assert isinstance(mm_data["image"], list)
    assert len(mm_data["image"]) == 2
    assert mm_data["image"][0] is None
    assert mm_data["image"][1] is None

    # Verify UUIDs are correctly tracked
    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[uuid1, uuid2])


@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_image_embeds_async(
    phi3v_model_config_image_embeds,
):
    """Test multiple image_embeds with async parsing.

    This validates the AsyncMultiModalItemTracker also supports multiple embeddings.
    """
    # Create two sample image embedding tensors
    image_embedding_1 = torch.randn(200, 768)
    image_embedding_2 = torch.randn(150, 768)

    # Encode them as base64 using the convenience function
    base64_image_embedding_1 = tensor2base64(image_embedding_1)
    base64_image_embedding_2 = tensor2base64(image_embedding_2)

    conversation, mm_future, mm_uuids = parse_chat_messages_futures(
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_embeds",
                        "image_embeds": base64_image_embedding_1,
                    },
                    {
                        "type": "image_embeds",
                        "image_embeds": base64_image_embedding_2,
                    },
                    {"type": "text", "text": "What do these images show?"},
                ],
            }
        ],
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # Verify conversation structure
    assert conversation == [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nWhat do these images show?",
        }
    ]

    # Await the future and verify mm_data
    mm_data = await mm_future
    assert mm_data is not None
    assert "image" in mm_data
    assert isinstance(mm_data["image"], list)
    assert len(mm_data["image"]) == 2

    # Verify each embedding has the correct shape
    assert isinstance(mm_data["image"][0], torch.Tensor)
    assert mm_data["image"][0].shape == image_embedding_1.shape
    assert isinstance(mm_data["image"][1], torch.Tensor)
    assert mm_data["image"][1].shape == image_embedding_2.shape

    # Verify UUIDs
    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])


@pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
    phi3v_model_config_image_embeds,
@@ -694,16 +694,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
            raise ValueError("Mixing raw image and embedding inputs is not allowed")
        if "image_embeds" in uuids_by_modality:
-            image_embeds_uuids = uuids_by_modality["image_embeds"]
-            if len(image_embeds_uuids) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
            mm_uuids["image"] = uuids_by_modality["image_embeds"]
        if "image" in uuids_by_modality:
            mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
        if "audio_embeds" in uuids_by_modality:
-            audio_embeds_uuids = uuids_by_modality["audio_embeds"]
-            if len(audio_embeds_uuids) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
            mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
        if "audio" in uuids_by_modality:
            mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
@@ -729,16 +723,16 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
-            if len(image_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
-            mm_inputs["image"] = image_embeds_lst[0]
+            mm_inputs["image"] = (
+                image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
+            )
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio_embeds" in items_by_modality:
            audio_embeds_lst = items_by_modality["audio_embeds"]
-            if len(audio_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
-            mm_inputs["audio"] = audio_embeds_lst[0]
+            mm_inputs["audio"] = (
+                audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
+            )
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
@@ -771,16 +765,16 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
-            if len(image_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'image_embeds'}")
-            mm_inputs["image"] = image_embeds_lst[0]
+            mm_inputs["image"] = (
+                image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
+            )
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio_embeds" in items_by_modality:
            audio_embeds_lst = items_by_modality["audio_embeds"]
-            if len(audio_embeds_lst) > 1:
-                raise ValueError("Only one message can have {'type': 'audio_embeds'}")
-            mm_inputs["audio"] = audio_embeds_lst[0]
+            mm_inputs["audio"] = (
+                audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
+            )
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality: