mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 00:35:35 +08:00
[Bugfix] Dictionary MM embeddings for online chat (#30507)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
fdc135d768
commit
b09806e28f
@ -796,9 +796,13 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
|
||||
"content": "<|image_1|>\nWhat's in this image?",
|
||||
}
|
||||
]
|
||||
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert mm_data["image"] is None
|
||||
assert isinstance(mm_data["image"], list)
|
||||
assert len(mm_data["image"]) == 1
|
||||
assert mm_data["image"][0] is None
|
||||
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||
|
||||
|
||||
@ -825,10 +829,11 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
|
||||
# Should have audio in mm_data as None (UUID provided)
|
||||
assert mm_data is not None
|
||||
assert "audio" in mm_data
|
||||
assert mm_data["audio"] is None
|
||||
assert isinstance(mm_data["audio"], list)
|
||||
assert len(mm_data["audio"]) == 1
|
||||
assert mm_data["audio"][0] is None
|
||||
|
||||
# UUID should be recorded
|
||||
assert mm_uuids is not None
|
||||
assert "audio" in mm_uuids
|
||||
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])
|
||||
|
||||
|
||||
@ -1121,10 +1126,105 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
|
||||
mm_data = await mm_future
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert mm_data["image"] is None
|
||||
assert isinstance(mm_data["image"], list)
|
||||
assert len(mm_data["image"]) == 1
|
||||
assert mm_data["image"][0] is None
|
||||
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||
|
||||
|
||||
def test_parse_chat_messages_empty_dict_image_embeds(
|
||||
phi3v_model_config_image_embeds,
|
||||
):
|
||||
"""Test that empty dictionary for image_embeds is handled without errors."""
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_embeds", "image_embeds": {}},
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Verify conversation structure
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\nWhat's in this image?",
|
||||
}
|
||||
]
|
||||
|
||||
# Verify mm_data contains an empty dictionary of embeddings
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert isinstance(mm_data["image"], dict)
|
||||
assert len(mm_data["image"]) == 0
|
||||
|
||||
# Verify UUIDs (None since we didn't provide any)
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[None])
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiple_dict_image_embeds(
|
||||
phi3v_model_config_image_embeds,
|
||||
):
|
||||
"""Test that multiple dictionaries for image_embeds is handled without errors."""
|
||||
# Create two sample image embedding tensors
|
||||
batch_size = 2
|
||||
image_embedding_1 = torch.randn(batch_size, 256, 1024)
|
||||
image_embedding_2 = torch.randn(batch_size, 3)
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_embeds",
|
||||
"image_embeds": {
|
||||
"image_embedding_1": tensor2base64(p),
|
||||
"image_embedding_2": tensor2base64(i),
|
||||
},
|
||||
}
|
||||
for p, i in zip(image_embedding_1, image_embedding_2)
|
||||
]
|
||||
+ [
|
||||
{"type": "text", "text": "Describe these two images."},
|
||||
],
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Verify conversation structure
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
|
||||
}
|
||||
]
|
||||
|
||||
# Verify mm_data contains a dictionary of multi-embeddings
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert isinstance(mm_data["image"], dict)
|
||||
assert len(mm_data["image"]) == batch_size
|
||||
|
||||
# Verify each embedding has the correct shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape
|
||||
|
||||
# Verify UUIDs (None since we didn't provide any)
|
||||
_assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_async(
|
||||
phi3v_model_config,
|
||||
|
||||
@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
@ -53,7 +53,14 @@ from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -620,6 +627,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _extract_embeds(tensors: list[torch.Tensor]):
|
||||
if len(tensors) == 0:
|
||||
return tensors
|
||||
|
||||
if len(tensors) == 1:
|
||||
tensors[0]._is_single_item = True # type: ignore
|
||||
return tensors[0] # To keep backwards compatibility for single item input
|
||||
|
||||
first_shape = tensors[0].shape
|
||||
if all(t.shape == first_shape for t in tensors):
|
||||
return torch.stack(tensors)
|
||||
|
||||
return tensors
|
||||
|
||||
|
||||
def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
|
||||
embeds_key = f"{modality}_embeds"
|
||||
embeds = items_by_modality[embeds_key]
|
||||
|
||||
if len(embeds) == 0:
|
||||
return embeds
|
||||
if is_list_of(embeds, torch.Tensor):
|
||||
return _extract_embeds(embeds)
|
||||
if is_list_of(embeds, dict):
|
||||
if not embeds:
|
||||
return {}
|
||||
|
||||
first_keys = set(embeds[0].keys())
|
||||
if any(set(item.keys()) != first_keys for item in embeds[1:]):
|
||||
raise ValueError(
|
||||
"All dictionaries in the list of embeddings must have the same keys."
|
||||
)
|
||||
|
||||
return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys}
|
||||
|
||||
return embeds
|
||||
|
||||
|
||||
class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
"""
|
||||
Tracks multi-modal items in a given request and ensures that the number
|
||||
@ -688,11 +733,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def all_mm_uuids(self) -> MultiModalUUIDDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_uuids = {}
|
||||
|
||||
uuids_by_modality = dict(self._uuids_by_modality)
|
||||
if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_uuids = {}
|
||||
if "image_embeds" in uuids_by_modality:
|
||||
mm_uuids["image"] = uuids_by_modality["image_embeds"]
|
||||
if "image" in uuids_by_modality:
|
||||
@ -703,6 +751,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
|
||||
if "video" in uuids_by_modality:
|
||||
mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos
|
||||
|
||||
return mm_uuids
|
||||
|
||||
@abstractmethod
|
||||
@ -714,29 +763,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||
def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_inputs = {}
|
||||
|
||||
items_by_modality = dict(self._items_by_modality)
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_inputs = {}
|
||||
if "image_embeds" in items_by_modality:
|
||||
image_embeds_lst = items_by_modality["image_embeds"]
|
||||
mm_inputs["image"] = (
|
||||
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
if "audio_embeds" in items_by_modality:
|
||||
audio_embeds_lst = items_by_modality["audio_embeds"]
|
||||
mm_inputs["audio"] = (
|
||||
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
|
||||
if "audio" in items_by_modality:
|
||||
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
|
||||
if "video" in items_by_modality:
|
||||
mm_inputs["video"] = items_by_modality["video"] # A list of videos
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
@ -747,38 +792,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
||||
async def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_inputs = {}
|
||||
items_by_modality = {}
|
||||
for modality, items in self._items_by_modality.items():
|
||||
coros = []
|
||||
for item in items:
|
||||
if item is not None:
|
||||
coros.append(item)
|
||||
else:
|
||||
coros.append(asyncio.sleep(0))
|
||||
items_by_modality[modality] = await asyncio.gather(*coros)
|
||||
|
||||
coros_by_modality = {
|
||||
modality: [item or asyncio.sleep(0) for item in items]
|
||||
for modality, items in self._items_by_modality.items()
|
||||
}
|
||||
items_by_modality: dict[str, list[object | None]] = {
|
||||
modality: await asyncio.gather(*coros)
|
||||
for modality, coros in coros_by_modality.items()
|
||||
}
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_inputs = {}
|
||||
if "image_embeds" in items_by_modality:
|
||||
image_embeds_lst = items_by_modality["image_embeds"]
|
||||
mm_inputs["image"] = (
|
||||
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
if "audio_embeds" in items_by_modality:
|
||||
audio_embeds_lst = items_by_modality["audio_embeds"]
|
||||
mm_inputs["audio"] = (
|
||||
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
|
||||
if "audio" in items_by_modality:
|
||||
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
|
||||
if "video" in items_by_modality:
|
||||
mm_inputs["video"] = items_by_modality["video"] # A list of videos
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
|
||||
@ -188,29 +188,39 @@ class InputProcessor:
|
||||
def _validate_single_prompt(single_prompt: dict | str) -> None:
|
||||
if not isinstance(single_prompt, dict):
|
||||
return
|
||||
|
||||
mm_data = single_prompt.get("multi_modal_data")
|
||||
mm_uuids = single_prompt.get("multi_modal_uuids")
|
||||
if not mm_data or not mm_uuids:
|
||||
return
|
||||
|
||||
import torch
|
||||
|
||||
def _get_len(items: object):
|
||||
if isinstance(items, dict): # Embedding inputs
|
||||
return _get_len(next(iter(items.values()))) if items else 1
|
||||
|
||||
if isinstance(items, list):
|
||||
return len(items)
|
||||
if isinstance(items, torch.Tensor):
|
||||
# To keep backwards compatibility for single item embedding input
|
||||
return 1 if getattr(items, "_is_single_item", False) else len(items)
|
||||
|
||||
return 1
|
||||
|
||||
for modality, items in mm_data.items():
|
||||
if modality in mm_uuids:
|
||||
data_len = len(items) if isinstance(items, list) else 1
|
||||
uuid_len = (
|
||||
len(mm_uuids[modality])
|
||||
if isinstance(mm_uuids[modality], list)
|
||||
else 1
|
||||
)
|
||||
data_len = _get_len(items)
|
||||
uuid_len = _get_len(mm_uuids[modality])
|
||||
if uuid_len != data_len:
|
||||
raise ValueError(
|
||||
f"multi_modal_uuids for modality '{modality}' "
|
||||
f"multi_modal_uuids for modality {modality!r} "
|
||||
"must have same length as data: got "
|
||||
f"{uuid_len} uuids vs "
|
||||
f"{data_len} items."
|
||||
f"{uuid_len} uuids vs {data_len} items."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"multi_modal_uuids for modality '{modality}' must "
|
||||
f"multi_modal_uuids for modality {modality!r} must "
|
||||
"be provided if multi_modal_data is provided."
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user