diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 475a68bc642b9..19dd242f16eb6 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -415,9 +415,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
             self._get_mm_fields_config(processed_data,
                                        hf_processor_mm_kwargs,
                                        num_image_patches),
         )
-        mm_hashes = (mm_uuids if mm_uuids is not None else self._hash_mm_items(
-            mm_items, hf_processor_mm_kwargs, tokenization_kwargs))
+        # Use overrides if provided; fallback to data-dependent hashing.
+        mm_hashes = self._hash_mm_items(mm_items,
+                                        hf_processor_mm_kwargs,
+                                        tokenization_kwargs,
+                                        mm_uuids=mm_uuids)
 
         return MultiModalInputs(
             type="multimodal",
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index e00c10fb66eeb..3539517ed45ee 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -14,7 +14,7 @@ import numpy as np
 from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
 
 from vllm.utils import LazyLoader, full_groupby, is_list_of
-from vllm.utils.jsontree import JSONTree, json_map_leaves
+from vllm.utils.jsontree import json_map_leaves
 
 if TYPE_CHECKING:
     import torch
@@ -203,7 +203,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
     return a == b
 
 
-BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
+BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
 """
 A dictionary containing nested tensors which have been batched via
 [`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
@@ -377,6 +377,7 @@ class MultiModalBatchedField(BaseMultiModalField):
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.stack(batch)`
@@ -422,6 +423,7 @@ class MultiModalFlatField(BaseMultiModalField):
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.concat(batch)`
@@ -764,6 +766,15 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
 
         return super().__getitem__(modality)  # type: ignore[return-value]
 
+    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
+        for modality, items in self.items():
+            for i, item in enumerate(items):
+                if item is None:
+                    raise RuntimeError(
+                        f"Found empty mm_items[{modality}][{i}]")
+
+        return self  # type: ignore[return-value]
+
     def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
         elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
         for modality, items in self.items():
@@ -897,15 +908,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         *,
         device: torch.types.Device,
     ) -> BatchedTensorInputs:
-        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
-
-        json_mapped = json_map_leaves(
+        return json_map_leaves(
             lambda x: x.to(device=device, non_blocking=True),
-            json_inputs,
+            batched_inputs,
         )
 
-        return cast(BatchedTensorInputs, json_mapped)
-
     def __getitem__(self, key: str):
         if key not in self:
             raise KeyError(f"Keyword argument {key!r} not found. "
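Aside: the `require_data()` helper added to `MultiModalKwargsItems` above follows a validate-then-narrow pattern — check that no item slot is `None`, then return `self` under a more precise type. A minimal self-contained sketch of the same pattern (the `Items` class and integer items below are illustrative stand-ins, not vLLM types):

```python
from collections import UserDict
from collections.abc import Sequence
from typing import Optional, TypeVar

_I = TypeVar("_I")


class Items(UserDict[str, Sequence[Optional[_I]]]):
    """Hypothetical stand-in for MultiModalKwargsItems."""

    def require_data(self) -> "Items[_I]":
        # Fail fast on missing items; after this check the caller may
        # safely assume every slot holds real data.
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")
        return self


items = Items({"image": [1, 2]})
assert items.require_data() is items  # passes: no None entries
```

Returning `self` rather than a copy keeps the check zero-cost at runtime; only the static type is narrowed.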
" diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 78e2cb7fa7334..ce671479b1ae7 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1585,7 +1585,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): *, mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalHashes: - """Create MM hashes to be returned (only used in V1). + """Create MM hashes to be returned. Note: When overrides are provided via callers of `apply`, @@ -2098,23 +2098,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): encoder_inputs: MultiModalInputs, ): tokenizer = self.info.get_tokenizer() - decoder_prompt = self.create_decoder_prompt(prompt, mm_data) - if isinstance(decoder_prompt, str): + decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data) + if isinstance(decoder_prompt_raw, str): + decoder_prompt = decoder_prompt_raw decoder_prompt_ids = encode_tokens(tokenizer, - decoder_prompt, + decoder_prompt_raw, add_special_tokens=False) else: - decoder_prompt_ids = decoder_prompt - decoder_prompt = decode_tokens(tokenizer, decoder_prompt) + decoder_prompt = decode_tokens(tokenizer, decoder_prompt_raw) + decoder_prompt_ids = decoder_prompt_raw mm_inputs = MultiModalEncDecInputs( encoder_prompt=encoder_inputs["prompt"], encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], **encoder_inputs) - mm_inputs.update({ - "prompt": decoder_prompt, - "prompt_token_ids": decoder_prompt_ids - }) + mm_inputs["prompt"] = decoder_prompt + mm_inputs["prompt_token_ids"] = decoder_prompt_ids return mm_inputs def apply( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 9b463e212bb49..26c5d188964c4 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -13,7 +13,7 @@ import vllm.envs as envs from vllm.logger import init_logger from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs, MultiModalKwargsOptionalItems, + MultiModalInputs, MultiModalKwargsItems, MultiModalPlaceholderDict) from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, EncDecMultiModalProcessor) @@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple): """Dummy data used for profiling.""" prompt_token_ids: list[int] - multi_modal_data: MultiModalKwargsOptionalItems + multi_modal_data: MultiModalKwargsItems multi_modal_placeholders: MultiModalPlaceholderDict @@ -239,7 +239,7 @@ class MultiModalProfiler(Generic[_I]): return DummyDecoderData( prompt_token_ids=prompt_token_ids, - multi_modal_data=mm_inputs["mm_kwargs"], + multi_modal_data=mm_inputs["mm_kwargs"].require_data(), multi_modal_placeholders=mm_inputs["mm_placeholders"], ) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 0f8aeceb39448..9b158267040af 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -19,6 +19,7 @@ from typing_extensions import deprecated import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection +from vllm.utils.jsontree import json_map_leaves from .audio import AudioMediaIO from .base import MediaIO @@ -383,6 +384,7 @@ def group_mm_kwargs_by_modality( *, device: torch.types.Device = None, pin_memory: bool = False, + merge_by_field_config: bool = False, ) -> Iterable[tuple[str, int, BatchedTensorInputs]]: """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same modality together into the same `MultiModalKwargs` instance. 
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 0f8aeceb39448..9b158267040af 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -19,6 +19,7 @@ from typing_extensions import deprecated
 
 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
+from vllm.utils.jsontree import json_map_leaves
 
 from .audio import AudioMediaIO
 from .base import MediaIO
@@ -383,6 +384,7 @@ def group_mm_kwargs_by_modality(
     *,
     device: torch.types.Device = None,
     pin_memory: bool = False,
+    merge_by_field_config: bool = False,
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the
     same modality together into the same `MultiModalKwargs` instance.
@@ -400,29 +402,31 @@ def group_mm_kwargs_by_modality(
     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)
 
-        # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
-        #     .get_data(pin_memory=pin_memory)
-
-        # if device is not None:
-        #     mm_kwargs_group = json_map_leaves(
-        #         lambda x: x.to(device=device),
-        #         mm_kwargs_group,
-        #     )
-
-        # TODO: Once V0 is removed, we can use the merging logic above
+        # TODO: Enable `merge_by_field_config` for all models
         # to avoid creating an extra batch dimension (except for fields
         # that are meant to be stacked anyway).
         # We will also need to update each model to remove `flatten_bn`.
-        mm_kwargs_group = MultiModalKwargs.as_kwargs(
-            MultiModalKwargs.batch(
-                [
-                    MultiModalKwargsItems.from_seq([item]).get_data()
-                    for item in items_lst
-                ],
-                pin_memory=pin_memory,
-            ),
-            device=device,
-        )
+        if merge_by_field_config:
+            mm_kwargs_group: BatchedTensorInputs = dict(
+                MultiModalKwargsItems.from_seq(items_lst).get_data(
+                    pin_memory=pin_memory))
+
+            if device is not None:
+                mm_kwargs_group = json_map_leaves(
+                    lambda x: x.to(device=device),
+                    mm_kwargs_group,
+                )
+        else:
+            mm_kwargs_group = MultiModalKwargs.as_kwargs(
+                MultiModalKwargs.batch(
+                    [
+                        MultiModalKwargsItems.from_seq([item]).get_data()
+                        for item in items_lst
+                    ],
+                    pin_memory=pin_memory,
+                ),
+                device=device,
+            )
 
         yield modality, len(items_lst), mm_kwargs_group
 
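Aside: both branches above rely on `itertools.groupby`, which only merges *consecutive* items with equal keys — interleaved modalities therefore yield multiple groups. A quick self-contained illustration (plain strings stand in for `MultiModalKwargsItem`s):

```python
from itertools import groupby

mm_items = ["image", "image", "audio", "image"]
groups = [(modality, len(list(items)))
          for modality, items in groupby(mm_items, key=lambda item: item)]
# Note: the trailing "image" is NOT merged into the first group.
assert groups == [("image", 2), ("audio", 1), ("image", 1)]
```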
diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py
index 804c443eb1841..7eb58b5f5cf84 100644
--- a/vllm/utils/jsontree.py
+++ b/vllm/utils/jsontree.py
@@ -4,7 +4,12 @@
 
 from collections.abc import Iterable
 from functools import reduce
-from typing import Callable, TypeVar, Union, cast, overload
+from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast, overload
+
+if TYPE_CHECKING:
+    import torch
+
+    from vllm.multimodal.inputs import BatchedTensorInputs
 
 _T = TypeVar("_T")
 _U = TypeVar("_U")
@@ -17,6 +22,19 @@ JSONTree = Union[
 ]
 """A nested JSON structure where the leaves need not be JSON-serializable."""
 
+_JSONTree = Union[
+    dict[str, "JSONTree[_T]"],
+    list["JSONTree[_T]"],
+    tuple["JSONTree[_T]", ...],
+    dict[str, _T],
+    list[_T],
+    tuple[_T, ...],
+    _T,
+]
+"""
+Same as `JSONTree` but with additional `Union` members to satisfy overloads.
+"""
+
 
 def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
     """Iterate through each leaf in a nested JSON structure."""
@@ -30,6 +48,14 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
             yield value
 
 
+@overload
+def json_map_leaves(
+    func: Callable[["torch.Tensor"], "torch.Tensor"],
+    value: "BatchedTensorInputs",
+) -> "BatchedTensorInputs":
+    ...
+
+
 @overload
 def json_map_leaves(
     func: Callable[[_T], _U],
@@ -64,11 +90,14 @@ def json_map_leaves(
 
 def json_map_leaves(
     func: Callable[[_T], _U],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
-) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
+    value: Union["BatchedTensorInputs", _JSONTree[_T]],
+) -> Union["BatchedTensorInputs", _JSONTree[_U]]:
     """Apply a function to each leaf in a nested JSON structure."""
     if isinstance(value, dict):
-        return {k: json_map_leaves(func, v) for k, v in value.items()}
+        return {
+            k: json_map_leaves(func, v)  # type: ignore[arg-type]
+            for k, v in value.items()
+        }
     elif isinstance(value, list):
         return [json_map_leaves(func, v) for v in value]
     elif isinstance(value, tuple):
@@ -125,7 +154,7 @@ def json_reduce_leaves(
 
 def json_reduce_leaves(
     func: Callable[..., Union[_T, _U]],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+    value: _JSONTree[_T],
     initial: _U = cast(_U, ...),  # noqa: B008
     /,
 ) -> Union[_T, _U]:
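Aside: the runtime behaviour of `json_map_leaves` is unchanged by the new overloads — dicts, lists, and tuples are traversed recursively, and everything else is treated as a leaf. A self-contained distillation of that traversal (`map_leaves` is a local name chosen to avoid clashing with the real function; annotations elided for brevity):

```python
def map_leaves(func, value):
    # Containers recurse; any other value is a leaf and gets transformed.
    if isinstance(value, dict):
        return {k: map_leaves(func, v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(map_leaves(func, v) for v in value)
    return func(value)


assert map_leaves(lambda x: x + 1, {"a": [1, (2, 3)]}) == {"a": [2, (3, 4)]}
```

The dedicated `BatchedTensorInputs` overload lets call sites such as `MultiModalKwargs.as_kwargs` pass a `dict[str, NestedTensors]` straight through, which is what allows the `cast` round-trips deleted in `inputs.py` above.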