[V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision (#11685)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Roger Wang 2025-01-06 11:58:16 -08:00 committed by GitHub
parent e20c92bb61
commit 91b361ae89
17 changed files with 633 additions and 279 deletions
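To make the scope concrete, here is a minimal usage sketch of the mixed-modality inference this change enables (one image plus one video in a single LLaVA-OneVision request). The prompt template, dummy frame format, and per-prompt limits below are illustrative assumptions, not taken from this diff.

import numpy as np
from PIL import Image
from vllm import LLM, SamplingParams

# The model name matches the docs table below; everything else is assumed.
llm = LLM(model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
          limit_mm_per_prompt={"image": 1, "video": 1})
image = Image.new("RGB", (336, 336))                   # dummy image
video = np.zeros((8, 336, 336, 3), dtype=np.uint8)     # 8 dummy frames (assumed HWC uint8)
prompt = ("<|im_start|>user <image><video>\nDescribe the image and the video."
          "<|im_end|><|im_start|>assistant\n")         # assumed chat format
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image, "video": video}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)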

View File

@ -647,7 +647,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
-
- ✅︎
-
- ✅︎
* - `MiniCPMV`
- MiniCPM-V
- T + I<sup>E+</sup>

View File

@ -2,16 +2,22 @@ import base64
import mimetypes
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Dict, Tuple
from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Tuple
import numpy as np
import pytest
from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector,
merge_and_sort_multimodal_metadata,
repeat_and_pad_placeholder_tokens)
if TYPE_CHECKING:
from vllm.multimodal.hasher import MultiModalHashDict
from vllm.multimodal.inputs import MultiModalPlaceholderDict
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
@ -191,3 +197,204 @@ def test_repeat_and_pad_placeholder_tokens(model):
assert new_prompt == expected_prompt
assert new_token_ids == expected_token_ids
assert ranges == expected_ranges
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
class TestCase(NamedTuple):
mm_positions: "MultiModalPlaceholderDict"
mm_hashes: Optional["MultiModalHashDict"]
expected_modalities: list[str]
expected_ranges: list[PlaceholderRange]
expected_hashes: Optional[list[str]]
def test_merge_and_sort_multimodal_metadata():
test_cases = [
# Single modality should return the result as-is, but flattened
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=2),
]
},
mm_hashes={"image": ["hash1", "hash2"]},
expected_modalities=["image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=2),
],
expected_hashes=["hash1", "hash2"],
),
# Single modality without hashes should return None for the mm hash.
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=2),
]
},
mm_hashes=None,
expected_modalities=["image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=2),
],
expected_hashes=None,
),
# Multiple modalities with hashes should return sorted modalities
# and flattened ranges and hashes.
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=7, length=4),
PlaceholderRange(offset=11, length=5),
],
"audio": [
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
]
},
mm_hashes={
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1", "audio_hash2"],
},
expected_modalities=["audio", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
PlaceholderRange(offset=7, length=4),
PlaceholderRange(offset=11, length=5),
],
expected_hashes=[
"audio_hash1", "audio_hash2", "image_hash1", "image_hash2"
],
),
# Multiple modalities without hashes should return sorted modalities
# and flattened ranges and None.
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=7, length=4),
PlaceholderRange(offset=11, length=5),
],
"audio": [
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
]
},
mm_hashes=None,
expected_modalities=["audio", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
PlaceholderRange(offset=7, length=4),
PlaceholderRange(offset=11, length=5),
],
expected_hashes=None,
),
# Three modalities
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=15, length=7),
PlaceholderRange(offset=22, length=8),
],
"audio": [
PlaceholderRange(offset=0, length=2),
],
"video": [
PlaceholderRange(offset=3, length=4),
PlaceholderRange(offset=7, length=5),
PlaceholderRange(offset=12, length=6),
]
},
mm_hashes={
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1"],
"video": ["video_hash1", "video_hash2", "video_hash3"]
},
expected_modalities=["audio", "video", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=4),
PlaceholderRange(offset=7, length=5),
PlaceholderRange(offset=12, length=6),
PlaceholderRange(offset=15, length=7),
PlaceholderRange(offset=22, length=8),
],
expected_hashes=[
"audio_hash1", "video_hash1", "video_hash2", "video_hash3",
"image_hash1", "image_hash2"
],
),
]
for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
expected_hashes) in test_cases:
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
mm_positions, mm_hashes)
assert modalities == expected_modalities
assert ranges == expected_ranges
assert hashes == expected_hashes
def test_merge_and_sort_multimodal_metadata_with_interleaving():
test_cases = [
# <image> <audio> <image> <audio>
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=8, length=2),
],
"audio": [
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=11, length=4),
]
},
mm_hashes={
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1", "audio_hash2"],
},
expected_modalities=[],
expected_ranges=[],
expected_hashes=None,
),
# <image> <image> <video> <audio> <image>
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
PlaceholderRange(offset=20, length=4),
],
"audio": [
PlaceholderRange(offset=5, length=2),
],
"video": [
PlaceholderRange(offset=8, length=5),
]
},
mm_hashes=None,
expected_modalities=[],
expected_ranges=[],
expected_hashes=None,
),
]
for case in test_cases:
with pytest.raises(ValueError) as ex_info:
merge_and_sort_multimodal_metadata(case.mm_positions,
case.mm_hashes)
assert "Interleaved mixed-modality" in str(ex_info.value)

View File

@ -1,6 +1,6 @@
import pytest
from vllm.inputs import token_inputs
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock,
@ -14,14 +14,18 @@ def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
if mm_positions is None:
multi_modal_inputs = None
else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
return Request(
request_id=request_id,
inputs=token_inputs(
prompt_token_ids=prompt_token_ids,
multi_modal_placeholders={"image": mm_positions}
if mm_positions else None,
multi_modal_hashes=mm_hashes,
),
prompt=None,
prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,

View File

@ -1,8 +1,7 @@
"""Compare the with and without prefix caching."""
import pytest
from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
@ -13,12 +12,18 @@ def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
if mm_positions is None:
multi_modal_inputs = None
else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
return Request(
request_id=request_id,
inputs=token_inputs(prompt_token_ids=prompt_token_ids,
multi_modal_placeholders={"image": mm_positions}
if mm_positions else None,
multi_modal_hashes=mm_hashes),
prompt=None,
prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,

View File

@ -39,8 +39,12 @@ class SupportsMultiModal(Protocol):
The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to
each input image.
each input multimodal data item (e.g., image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
NOTE: The returned multimodal embeddings must be in the same order as
their corresponding multimodal data items appear in the input prompt.
"""
...
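To illustrate this contract with hypothetical shapes: a prompt that contains an image followed by a video would return one 2D tensor per item, ordered as the placeholders appear in the prompt.

import torch

# Hypothetical sizes: 576 image tokens, 1024 video tokens, hidden size 4096.
image_embeds = torch.zeros(576, 4096)
video_embeds = torch.zeros(1024, 4096)
# Order must follow the placeholder order in the prompt: <image> ... <video>.
multimodal_embeddings = (image_embeds, video_embeds)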

View File

@ -35,6 +35,9 @@ from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
@ -223,8 +226,10 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO)
return max(max_total_frames // max(max_videos, 1), 1)
return max(max_frames_per_video, 1)
def _get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
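For intuition, a worked example of the new per-video frame cap used during the profile run (numbers are illustrative):

_MAX_FRAMES_PER_VIDEO = 16
max_total_frames, max_videos = 100, 2   # illustrative budget and request shape
# Before: 100 // 2 = 50 dummy frames per video.
# After: the per-video count is additionally capped at _MAX_FRAMES_PER_VIDEO.
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                           _MAX_FRAMES_PER_VIDEO)
assert max(max_frames_per_video, 1) == 16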
@ -558,13 +563,15 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
if "pixel_values" in kwargs:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if "pixel_values_videos" in kwargs:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
# Preserve the order of modalities (taken from the order of kwargs)
# if there are multiple of them.
for input_key in kwargs:
if input_key == "pixel_values" and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
@ -824,21 +831,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
if not modalities:
return None
# We make a tuple of each embedding with its modality string. This is a
# temporary workaround for models to handle mixed modalities when
# get_multimodal_embeddings and get_input_embeddings are called
# separately.
# TODO(ywang96): Add support for mixed-modality inference for v1.
multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings.append((vision_embeddings, "image"))
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
multimodal_embeddings.append((video_embeddings, "video"))
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += tuple(vision_embeddings)
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings
@ -850,15 +857,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
for embeddings, modality in multimodal_embeddings:
if modality == "image":
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, embeddings,
self.config.image_token_index)
if modality == "video":
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, embeddings,
self.config.video_token_index)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_index, self.config.video_token_index])
return inputs_embeds
def forward(

View File

@ -972,8 +972,6 @@ def image_input_mapper_for_molmo(
assert len(data) == 1, "Molmo supports only one image per prompt."
data = data[0]
# Remove unused dummy PIL image
data.pop('raw_mm_data', None)
return MultiModalKwargs(data)
@ -1019,7 +1017,6 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
dummy_imgdata = {
"images": out["images"],
"image_input_idx": out["image_input_idx"],
"raw_mm_data": dummy_image,
}
if "image_masks" in out:
dummy_imgdata["image_masks"] = out["image_masks"]

View File

@ -1,4 +1,5 @@
from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .hasher import MultiModalHashDict, MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict, NestedTensors)
@ -18,6 +19,8 @@ __all__ = [
"ModalityData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalHashDict",
"MultiModalHasher",
"MultiModalKwargs",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",

vllm/multimodal/hasher.py (new file, 100 lines)
View File

@ -0,0 +1,100 @@
import pickle
from typing import TYPE_CHECKING, Iterable, Mapping, Optional
import numpy as np
import torch
from blake3 import blake3
from PIL import Image
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.inputs import TokensPrompt
logger = init_logger(__name__)
MultiModalHashDict = Mapping[str, list[str]]
"""
A dictionary containing hashes for items in each modality.
"""
class MultiModalHasher:
@classmethod
def serialize_item(cls, obj: object) -> bytes:
# Simple cases
if isinstance(obj, str):
return obj.encode("utf-8")
if isinstance(obj, bytes):
return obj
if isinstance(obj, Image.Image):
return obj.tobytes()
# Convertible to NumPy arrays
if isinstance(obj, torch.Tensor):
obj = obj.numpy()
if isinstance(obj, (int, float)):
obj = np.array(obj)
if isinstance(obj, np.ndarray):
return obj.tobytes()
logger.warning(
"No serialization method found for %s. "
"Falling back to pickle.", type(obj))
return pickle.dumps(obj)
@classmethod
def item_to_bytes(
cls,
key: str,
obj: object,
) -> Iterable[tuple[bytes, bytes]]:
# Recursive cases
if isinstance(obj, (list, tuple)):
for i, elem in enumerate(obj):
yield from cls.item_to_bytes(f"{key}.{i}", elem)
elif isinstance(obj, dict):
for k, v in obj.items():
yield from cls.item_to_bytes(f"{key}.{k}", v)
else:
key_bytes = cls.serialize_item(key)
value_bytes = cls.serialize_item(obj)
yield key_bytes, value_bytes
@classmethod
def hash_kwargs(cls, **kwargs: object) -> str:
hasher = blake3()
for k, v in kwargs.items():
for k_bytes, v_bytes in cls.item_to_bytes(k, v):
hasher.update(k_bytes)
hasher.update(v_bytes)
return hasher.hexdigest()
@classmethod
def hash_prompt_mm_data(
cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]:
"""Hash multimodal data in the user input prompt if they exist."""
if "multi_modal_data" not in prompt:
return None
mm_data = prompt["multi_modal_data"]
if not mm_data:
# mm_data can be None or an empty dict.
return None
mm_items = {
modality: items if isinstance(items, list) else [items]
for modality, items in mm_data.items()
}
mm_hashes = {
modality: [cls.hash_kwargs(**{modality: item}) for item in items]
for modality, items in mm_items.items()
}
return mm_hashes
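A minimal usage sketch of the new hasher; the synthetic image and model_id are placeholders, but the calls mirror the API defined above.

from PIL import Image
from vllm.multimodal.hasher import MultiModalHasher

img = Image.new("RGB", (32, 32), color="red")
# Identical content yields identical digests, which is what makes mm-input
# caching and prefix caching possible in V1.
h1 = MultiModalHasher.hash_kwargs(model_id="some-model", image=img)
h2 = MultiModalHasher.hash_kwargs(model_id="some-model", image=img)
assert h1 == h2

# Hashing every multimodal item in a prompt dict:
hashes = MultiModalHasher.hash_prompt_mm_data(
    {"prompt_token_ids": [1, 2, 3], "multi_modal_data": {"image": [img, img]}})
# -> {"image": [<hex digest>, <hex digest>]}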

View File

@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
final)
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final)
import numpy as np
import torch
@ -14,6 +14,9 @@ from typing_extensions import NotRequired, TypeAlias
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
if TYPE_CHECKING:
from .hasher import MultiModalHashDict
_T = TypeVar("_T")
HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
@ -513,7 +516,7 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes: NotRequired[list[str]]
mm_hashes: NotRequired[Optional["MultiModalHashDict"]]
"""The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict

View File

@ -1,4 +1,3 @@
import pickle
import re
from abc import ABC, abstractmethod
from collections import defaultdict
@ -7,18 +6,16 @@ from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
import numpy as np
import torch
from blake3 import blake3
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from vllm import envs
from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
@ -486,56 +483,6 @@ class ProcessingCache:
logger.debug("ProcessingCache: hit_ratio = %.2f",
cache_stats.hit_ratio)
def _serialize_item(self, obj: object) -> bytes:
# Simple cases
if isinstance(obj, str):
return obj.encode("utf-8")
if isinstance(obj, bytes):
return obj
if isinstance(obj, Image.Image):
return obj.tobytes()
# Convertible to NumPy arrays
if isinstance(obj, torch.Tensor):
obj = obj.numpy()
if isinstance(obj, (int, float)):
obj = np.array(obj)
if isinstance(obj, np.ndarray):
return obj.tobytes()
logger.warning(
"No serialization method found for %s. "
"Falling back to pickle.", type(obj))
return pickle.dumps(obj)
def _item_to_bytes(
self,
key: str,
obj: object,
) -> Iterable[tuple[bytes, bytes]]:
# Recursive cases
if isinstance(obj, (list, tuple)):
for i, elem in enumerate(obj):
yield from self._item_to_bytes(f"{key}.{i}", elem)
elif isinstance(obj, dict):
for k, v in obj.items():
yield from self._item_to_bytes(f"{key}.{k}", v)
else:
key_bytes = self._serialize_item(key)
value_bytes = self._serialize_item(obj)
yield key_bytes, value_bytes
def _hash_kwargs(self, **kwargs: object) -> str:
hasher = blake3()
for k, v in kwargs.items():
for k_bytes, v_bytes in self._item_to_bytes(k, v):
hasher.update(k_bytes)
hasher.update(v_bytes)
return hasher.hexdigest()
def get(
self,
model_id: str,
@ -554,9 +501,9 @@ class ProcessingCache:
"""
self._maybe_log_cache_stats()
cache_key = self._hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
return self._cache.get(cache_key)
def put(
@ -571,9 +518,9 @@ class ProcessingCache:
Put a processed multi-modal item into the cache
according to its dependencies (see :meth:`get`).
"""
cache_key = self._hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
self._cache.put(cache_key, output_kwargs)
@ -1049,6 +996,24 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
mm_items = self._to_mm_items(mm_data)
# Create MM hashes (only used in V1)
# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.
if envs.VLLM_USE_V1:
model_id = self.ctx.model_config.model
mm_hashes = {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}
else:
mm_hashes = None
prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
prompt_text,
mm_items,
@ -1122,6 +1087,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
prompt=prompt_text,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges,
)
@ -1174,7 +1140,9 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"tokens.")
total_len = len(prompt_token_ids)
if total_len > seq_len:
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "

View File

@ -1,6 +1,6 @@
from functools import lru_cache
from pathlib import Path
from typing import Optional, TypeVar, Union
from typing import TYPE_CHECKING, Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse
import numpy as np
@ -25,6 +25,10 @@ cached_get_tokenizer = lru_cache(get_tokenizer)
_M = TypeVar("_M")
if TYPE_CHECKING:
from .hasher import MultiModalHashDict
from .inputs import MultiModalPlaceholderDict
class MediaConnector:
@ -437,3 +441,83 @@ def consecutive_placeholder_ranges(
PlaceholderRange(offset=initial_offset + i * item_size,
length=item_size) for i in range(num_items)
]
def merge_and_sort_multimodal_metadata(
mm_positions: "MultiModalPlaceholderDict",
mm_hashes: Optional["MultiModalHashDict"],
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
"""Given a MultiModalPlaceholderDict, merge all PlaceholderRange
objects from all available modalities into a single list of
PlaceholderRange, sorted by their offset (starting index in the input
sequence) in ascending order.
If a MultiModalHashDict is also given, the same operation is applied
to it and the sorted list of hashes is returned as well.
Raises:
ValueError: If the input prompt has interleaved placeholders from
different modalities (e.g., "<image><audio><image> Describe the
content.")
Returns:
list[str]: Sorted list of involved modalities.
list[PlaceholderRange]: Sorted list of all PlaceholderRanges from
mm_positions.
Optional[list[str]]: Sorted list of all hashes from mm_hashes if
given, None otherwise.
"""
modalities = list(mm_positions.keys())
assert len(modalities) > 0, "No modalities found in the mm_positions."
# For single modality, placeholder ranges and hashes are already sorted
# so we can return the list directly.
if len(modalities) == 1:
if mm_hashes is None:
return modalities, list(mm_positions[modalities[0]]), None
else:
return modalities, list(mm_positions[modalities[0]]), list(
mm_hashes[modalities[0]])
placeholder_lists_with_modality = [(modality, mm_positions[modality])
for modality in modalities]
if mm_hashes is None:
sorted_placeholder_lists = sorted(placeholder_lists_with_modality,
key=lambda x: x[1][0]['offset'])
sorted_hash_lists = None
else:
hashes_lists = [
mm_hashes[modality] for modality in modalities
if modality in mm_hashes
]
sorted_pairs = sorted(zip(placeholder_lists_with_modality,
hashes_lists),
key=lambda x: x[0][1][0]['offset'])
sorted_placeholder_tuple, sorted_hash_tuple = zip(*sorted_pairs)
sorted_placeholder_lists = list(sorted_placeholder_tuple)
sorted_hash_lists = list(sorted_hash_tuple)
sorted_modalities = [modality for modality, _ in sorted_placeholder_lists]
# Flatten sorted list of lists to a single list and verify there is no
# interleaving of placeholders from different modalities.
merged_placeholders: list[PlaceholderRange] = []
for modality, placeholder_list in sorted_placeholder_lists:
if merged_placeholders and placeholder_list[0][
'offset'] < merged_placeholders[-1]['offset']:
raise ValueError(
"Interleaved mixed-modality inference is currently not "
"supported.")
merged_placeholders.extend(placeholder_list)
if sorted_hash_lists is not None:
merged_hashes = []
for hash_list in sorted_hash_lists:
merged_hashes.extend(hash_list)
else:
merged_hashes = None
return sorted_modalities, merged_placeholders, merged_hashes
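A minimal call sketch (values illustrative; the tests earlier in this diff cover the behavior in depth):

from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata

modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
    mm_positions={
        "image": [PlaceholderRange(offset=5, length=4)],
        "audio": [PlaceholderRange(offset=0, length=3)],
    },
    mm_hashes={"image": ["img_h"], "audio": ["aud_h"]},
)
# modalities == ["audio", "image"]
# ranges == [PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=5, length=4)]
# hashes == ["aud_h", "img_h"]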

View File

@ -1,12 +1,14 @@
import enum
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union
import msgspec
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
@dataclass
@ -21,13 +23,13 @@ class EngineCoreRequest:
# always be tokenized?
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
mm_hashes: Optional[List[str]]
mm_placeholders: Optional[MultiModalPlaceholderDict]
sampling_params: SamplingParams
mm_placeholders: Optional[List["PlaceholderRange"]]
sampling_params: "SamplingParams"
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]
lora_request: Optional["LoRARequest"]
class EngineCoreOutput(

View File

@ -1,10 +1,6 @@
from typing import Any, Dict, List, Optional
import PIL
from blake3 import blake3
from vllm.config import ModelConfig
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
@ -144,66 +140,3 @@ class MMInputMapperServer:
full_mm_inputs.append(mm_input)
return full_mm_inputs
class MMHasher:
def __init__(self):
pass
def hash_dummy_mm_data(
self,
mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
"""Hash user-defined dummy multimodal data used for profiling."""
if mm_data is None:
return None
image_inputs = mm_data['image']
# This is a temporary workaround for models (e.g, Molmo) that
# process multimodal data in the input processor (therefore
# image_inputs is MultiModalKwargs instead of raw input format).
# `raw_mm_data` with the original input format is expected
# in this case.
if isinstance(image_inputs, dict):
assert "raw_mm_data" in image_inputs and isinstance(
image_inputs["raw_mm_data"], PIL.Image.Image)
image_inputs = image_inputs.pop("raw_mm_data")
return self.hash_images(image_inputs)
def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]:
"""Hash multimodal data in the user input prompt if they exist."""
if "multi_modal_data" not in prompt:
return None
mm_data = prompt["multi_modal_data"]
if not mm_data:
# mm_data can be None or an empty dict.
return None
image_inputs = mm_data["image"]
return self.hash_images(image_inputs)
def hash_images(self, image_inputs) -> Optional[List[str]]:
"""Hash PIL image objects to strings."""
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
assert len(image_inputs) > 0
ret = []
for image in image_inputs:
assert isinstance(image, PIL.Image.Image)
# Convert image to bytes
bytes = image.tobytes()
# Hash image bytes
hasher = blake3()
hasher.update(bytes)
ret.append(hasher.hexdigest())
return ret

View File

@ -7,14 +7,15 @@ from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalHasher,
MultiModalKwargs, MultiModalRegistry)
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
class Processor:
@ -47,7 +48,6 @@ class Processor:
# Multi-modal hasher (for images)
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
self.mm_hasher = MMHasher()
def process_inputs(
self,
@ -73,11 +73,6 @@ class Processor:
assert priority == 0, "vLLM V1 does not support priority at the moment."
assert trace_headers is None, "vLLM V1 does not support tracing yet."
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
mm_hashes = self.mm_hasher.hash_prompt_mm_data(prompt)
# Process inputs.
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
@ -108,8 +103,20 @@ class Processor:
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
# Multimodal related.
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
# Use mm_hashes from the processed inputs if the model has a merged
# input processor.
if decoder_inputs.multi_modal_hashes:
mm_hashes = decoder_inputs.multi_modal_hashes
# Fall back to using MultiModalHasher directly.
else:
mm_hashes = MultiModalHasher.hash_prompt_mm_data(prompt)
# For merged preprocessor, mm_data is already mm_inputs
precomputed_mm_inputs = None
precomputed_mm_inputs: Optional[list[MultiModalKwargs]] = None
decoder_mm_data = decoder_inputs.multi_modal_data
if isinstance(decoder_mm_data, MultiModalKwargs):
# The output of merged multi-modal processor (`decoder_mm_data`)
@ -122,27 +129,67 @@ class Processor:
for item in decoder_mm_data.get_items(modality)
]
# Apply MM mapper
mm_inputs = None
if len(decoder_mm_data) > 0:
mm_inputs = self.mm_input_mapper_client.process_inputs(
decoder_mm_data,
mm_positions = decoder_inputs.multi_modal_placeholders
# Last-mile processing of multimodal metadata and inputs.
if mm_positions:
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
# NOTE: interleaved modalities are not supported.
(
sorted_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
mm_positions,
mm_hashes,
decoder_inputs.mm_processor_kwargs,
precomputed_mm_inputs,
)
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
# modalities involved AND the model supports a merged input processor.
if len(sorted_modalities) > 1 and precomputed_mm_inputs:
modality_order_dict = {
modality: order
for order, modality in enumerate(sorted_modalities)
}
# Sanity check to make sure each multimodal input has only one
# modality key.
for mm_input in precomputed_mm_inputs:
assert len(mm_input.modalities) == 1
# Sort MultiModalKwargs to match sorted_mm_positions
precomputed_mm_inputs = sorted(
precomputed_mm_inputs,
key=lambda mm_input: modality_order_dict[list(
mm_input.modalities)[0]])
# Apply mm input cache update (and input mapper if necessary).
sorted_mm_inputs = self.mm_input_mapper_client.process_inputs(
mm_data=decoder_mm_data,
mm_hashes=sorted_mm_hashes,
mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
precomputed_mm_inputs=precomputed_mm_inputs,
)
else:
sorted_mm_inputs = None
sorted_mm_hashes = None
sorted_mm_positions = None
return EngineCoreRequest(
request_id,
decoder_inputs.prompt,
decoder_inputs.prompt_token_ids,
mm_inputs,
mm_hashes,
decoder_inputs.multi_modal_placeholders,
sampling_params,
eos_token_id,
arrival_time,
lora_request,
request_id=request_id,
prompt=decoder_inputs.prompt,
prompt_token_ids=decoder_inputs.prompt_token_ids,
mm_inputs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,
)
def _validate_model_inputs(self, inputs: ProcessorInputs):
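A small illustration of the re-ordering step above, using stand-in strings instead of real MultiModalKwargs objects:

sorted_modalities = ["audio", "video", "image"]   # from merge_and_sort_multimodal_metadata
modality_order_dict = {m: i for i, m in enumerate(sorted_modalities)}

# Each precomputed input carries exactly one modality; re-order them so they
# line up with sorted_mm_positions and sorted_mm_hashes.
precomputed = ["video", "image", "audio"]         # stand-ins for MultiModalKwargs items
precomputed = sorted(precomputed, key=lambda m: modality_order_dict[m])
assert precomputed == ["audio", "video", "image"]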

View File

@ -1,15 +1,15 @@
import enum
from typing import TYPE_CHECKING, List, Optional, Union
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.kv_cache_utils import BlockHashType
@ -18,14 +18,17 @@ class Request:
def __init__(
self,
request_id: str,
inputs: DecoderOnlyInputs,
prompt: Optional[str],
prompt_token_ids: List[int],
multi_modal_inputs: Optional[List["MultiModalKwargs"]],
multi_modal_hashes: Optional[List[str]],
multi_modal_placeholders: Optional[List["PlaceholderRange"]],
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.request_id = request_id
self.inputs = SingletonInputsAdapter(inputs)
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
@ -41,26 +44,21 @@ class Request:
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
self.prompt = self.inputs.prompt
self.prompt_token_ids = self.inputs.prompt_token_ids
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0
# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
# FIXME(woosuk): Support other modalities.
self.mm_positions = mm_positions.get("image", [])
else:
self.mm_positions = []
# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []
if self.inputs.multi_modal_inputs:
self.mm_inputs = self.inputs.multi_modal_inputs
# Multi-modal related
self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or []
self.mm_hashes: List[str] = multi_modal_hashes or []
self.mm_hashes: List[str] = self.inputs.multi_modal_hashes
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
assert len(self.mm_inputs) == len(self.mm_hashes)
# Cache the computed kv block hashes of the request to avoid
# recomputing.
@ -70,15 +68,11 @@ class Request:
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
request_id=request.request_id,
inputs=token_inputs(
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
multi_modal_data=None,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=None,
),
prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,

View File

@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@ -82,12 +82,10 @@ class GPUModelRunner:
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: mm_input_mapper_client and mm_hasher are only used for memory
# profiling.
self.mm_input_mapper_client = MMInputMapperClient(self.model_config)
self.mm_hasher = MMHasher()
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
# NOTE: The initialized input mapper is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size
@ -722,8 +720,6 @@ class GPUModelRunner:
]
# Profile with multimodal encoder & encoder cache.
# TODO (ywang96): generalize this beyond image modality since
# mm_input_mapper only supports image inputs.
if self.is_multimodal_model:
# Create dummy batch of multimodal inputs.
@ -735,15 +731,30 @@ class GPUModelRunner:
dummy_mm_data = dummy_request_data.multi_modal_data
# NOTE: Currently the model is profiled with a single non-text
# modality even when it supports multiple.
max_tokens_per_mm_item = max(
self.mm_registry.get_max_tokens_per_item_by_modality(
self.model_config).values())
# modality with the max possible input tokens even when
# it supports multiple.
max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501
self.model_config)
max_num_mm_items_encoder_budget = min(
self.max_num_encoder_input_tokens,
self.encoder_cache_size) // max_tokens_per_mm_item
dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1])
# Check how many items of this modality can be supported by
# the encoder cache budget.
encoder_cache_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)
max_num_mm_items_encoder_budget = encoder_cache_budget // \
max_tokens_per_mm_item
# TODO: Allow users to set encoder_cache_budget in case this
# happens.
assert max_num_mm_items_encoder_budget > 0, (
f"Encoder cache budget={encoder_cache_budget} is too small to "
f"support the maximum possible size of multimodal embeddings"
f"={max_tokens_per_mm_item}.")
# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req = max(
self.mm_registry.get_mm_limits_per_prompt(
self.model_config).values())
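A worked example of the encoder-cache budget check (all numbers are illustrative):

max_num_encoder_input_tokens, encoder_cache_size = 16384, 8192
max_tokens_per_mm_item = 7290   # e.g., a hypothetical worst-case video item

encoder_cache_budget = min(max_num_encoder_input_tokens, encoder_cache_size)      # 8192
max_num_mm_items_encoder_budget = encoder_cache_budget // max_tokens_per_mm_item  # 1
assert max_num_mm_items_encoder_budget > 0   # otherwise profiling cannot fit one item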
@ -763,33 +774,24 @@ class GPUModelRunner:
# they are scheduled to be processed separately.
# When models have a merged processor, their dummy data is
# already batched `MultiModalKwargs`, therefore we need to "unbatch"
# and take the first item in each batched tensor.
# TODO (ywang96): This is somewhat hacky. Refactor this to be
# consistent with the other case.
# already batched `MultiModalKwargs`, therefore we take the first
# `MultiModalKwargsItem` from the desired modality to profile on.
if isinstance(dummy_mm_data, MultiModalKwargs):
dummy_mm_kwargs = {
k: v[0].unsqueeze(0)
for k, v in dummy_mm_data.items()
}
dummy_mm_item = dummy_mm_data.get_item(
modality=dummy_data_modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
# When models have dummy data explicitly defined as a
# `MultiModalDataDict`, it needs to be processed through the input
# mapper.
# TODO (ywang96): deprecate this path once merged processor is
# supported on all models.
else:
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
mm_hashes = self.mm_hasher.hash_dummy_mm_data(
dummy_mm_data)
mm_kwargs_list = self.mm_input_mapper_client.process_inputs(
mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
mm_data=dummy_mm_data,
mm_hashes=mm_hashes,
mm_hashes=None,
mm_processor_kwargs=None,
precomputed_mm_inputs=None)
# Take the first `MultiModalKwargs`
dummy_mm_kwargs = mm_kwargs_list[0]
batched_dummy_mm_inputs = MultiModalKwargs.batch(