[mypy] Further improve MM type annotations (#25654)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Cyrus Leung · 2025-09-25 18:57:36 +08:00 · committed by yewentao256
parent f17d37b006
commit 686cfd91e3
6 changed files with 90 additions and 48 deletions

View File

@@ -415,9 +415,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
                                        num_image_patches),
         )
         # Use overrides if provided; fallback to data-dependent hashing.
-        mm_hashes = (mm_uuids if mm_uuids is not None else self._hash_mm_items(
-            mm_items, hf_processor_mm_kwargs, tokenization_kwargs))
+        mm_hashes = self._hash_mm_items(mm_items,
+                                        hf_processor_mm_kwargs,
+                                        tokenization_kwargs,
+                                        mm_uuids=mm_uuids)

         return MultiModalInputs(
             type="multimodal",

View File

@@ -14,7 +14,7 @@ import numpy as np
 from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated

 from vllm.utils import LazyLoader, full_groupby, is_list_of
-from vllm.utils.jsontree import JSONTree, json_map_leaves
+from vllm.utils.jsontree import json_map_leaves

 if TYPE_CHECKING:
     import torch
@@ -203,7 +203,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
     return a == b

-BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
+BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
 """
 A dictionary containing nested tensors which have been batched via
 [`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
 """
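
For context, `Mapping` is read-only under mypy while `dict` is not; the new grouping path later in this commit constructs these objects with `dict(...)` and the `json_map_leaves` overload returns the same alias, both of which are simpler to express with a concrete `dict`. A self-contained illustration of the read-only distinction (toy aliases, not vLLM's):

    from typing import Mapping, TypeAlias

    ReadOnly: TypeAlias = Mapping[str, int]
    Mutable: TypeAlias = dict[str, int]

    def add_to_readonly(m: ReadOnly) -> None:
        m["extra"] = 1  # mypy: unsupported target for indexed assignment

    def add_to_mutable(m: Mutable) -> None:
        m["extra"] = 1  # fine: dict supports __setitem__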
@@ -377,6 +377,7 @@ class MultiModalBatchedField(BaseMultiModalField):
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.stack(batch)`
@@ -422,6 +423,7 @@ class MultiModalFlatField(BaseMultiModalField):
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.concat(batch)`
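
Both `cast` insertions follow the same pattern: the `is_list_of(...)` runtime check guarantees the element type, and the explicit cast makes that guarantee visible to mypy inside the branch. A standalone sketch of the idiom, with a plain `isinstance` check standing in for vLLM's `is_list_of`:

    from typing import cast

    def total_length(batch: list[object]) -> int:
        # Runtime check that mypy cannot use for narrowing on its own.
        if len(batch) > 0 and all(isinstance(x, str) for x in batch):
            strs = cast(list[str], batch)  # safe: just verified above
            return sum(len(s) for s in strs)
        return 0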
@@ -764,6 +766,15 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
         return super().__getitem__(modality)  # type: ignore[return-value]

+    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
+        for modality, items in self.items():
+            for i, item in enumerate(items):
+                if item is None:
+                    raise RuntimeError(
+                        f"Found empty mm_items[{modality}][{i}]")
+
+        return self  # type: ignore[return-value]
+
     def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
         elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
         for modality, items in self.items():
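
The new `require_data` converts a `MultiModalKwargsItems` that may hold `None` entries into one whose item type is non-optional, failing fast otherwise. The same narrowing trick in a self-contained form (hypothetical `Items` class, not vLLM's):

    from typing import Generic, Optional, TypeVar

    _T = TypeVar("_T")
    _V = TypeVar("_V")

    class Items(Generic[_T]):
        def __init__(self, data: dict[str, list[_T]]) -> None:
            self.data = data

        def require_data(self: "Items[Optional[_V]]") -> "Items[_V]":
            # Raise on any missing entry so callers can assume
            # fully materialized items afterwards.
            for modality, items in self.data.items():
                for i, item in enumerate(items):
                    if item is None:
                        raise RuntimeError(
                            f"Found empty items[{modality}][{i}]")
            return self  # type: ignore[return-value]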
@@ -897,15 +908,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         *,
         device: torch.types.Device,
     ) -> BatchedTensorInputs:
-        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
-
-        json_mapped = json_map_leaves(
+        return json_map_leaves(
             lambda x: x.to(device=device, non_blocking=True),
-            json_inputs,
+            batched_inputs,
         )

-        return cast(BatchedTensorInputs, json_mapped)
-
     def __getitem__(self, key: str):
         if key not in self:
             raise KeyError(f"Keyword argument {key!r} not found. "

View File

@@ -1585,7 +1585,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         *,
         mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> MultiModalHashes:
-        """Create MM hashes to be returned (only used in V1).
+        """Create MM hashes to be returned.

         Note: When overrides are provided via callers of `apply`,
@@ -2098,23 +2098,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         encoder_inputs: MultiModalInputs,
     ):
         tokenizer = self.info.get_tokenizer()
-        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
-        if isinstance(decoder_prompt, str):
+        decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data)
+        if isinstance(decoder_prompt_raw, str):
+            decoder_prompt = decoder_prompt_raw
             decoder_prompt_ids = encode_tokens(tokenizer,
-                                               decoder_prompt,
+                                               decoder_prompt_raw,
                                                add_special_tokens=False)
         else:
-            decoder_prompt_ids = decoder_prompt
-            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
+            decoder_prompt = decode_tokens(tokenizer, decoder_prompt_raw)
+            decoder_prompt_ids = decoder_prompt_raw

         mm_inputs = MultiModalEncDecInputs(
             encoder_prompt=encoder_inputs["prompt"],
             encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
             **encoder_inputs)
-        mm_inputs.update({
-            "prompt": decoder_prompt,
-            "prompt_token_ids": decoder_prompt_ids
-        })
+        mm_inputs["prompt"] = decoder_prompt
+        mm_inputs["prompt_token_ids"] = decoder_prompt_ids
         return mm_inputs

     def apply(
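
The reshuffle gives the raw union value and the narrowed results distinct names, which keeps mypy's narrowing straightforward; the `dict.update` call is also replaced with direct per-key assignments, which mypy can check individually against the TypedDict. A minimal illustration of the naming part, with toy stand-ins for `encode_tokens`/`decode_tokens`:

    from typing import Union

    def split_prompt(raw: Union[str, list[int]]) -> tuple[str, list[int]]:
        # Keep the union value under one name (`raw`) and give each
        # narrowed result its own variable with a single stable type.
        if isinstance(raw, str):
            text = raw
            token_ids = [ord(c) for c in raw]  # stand-in for encode_tokens
        else:
            text = "".join(chr(t) for t in raw)  # stand-in for decode_tokens
            token_ids = raw
        return text, token_ids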

View File

@@ -13,7 +13,7 @@ import vllm.envs as envs
 from vllm.logger import init_logger

 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
-                     MultiModalInputs, MultiModalKwargsOptionalItems,
+                     MultiModalInputs, MultiModalKwargsItems,
                      MultiModalPlaceholderDict)
 from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                          EncDecMultiModalProcessor)
@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
     """Dummy data used for profiling."""

     prompt_token_ids: list[int]
-    multi_modal_data: MultiModalKwargsOptionalItems
+    multi_modal_data: MultiModalKwargsItems
     multi_modal_placeholders: MultiModalPlaceholderDict
@@ -239,7 +239,7 @@ class MultiModalProfiler(Generic[_I]):
         return DummyDecoderData(
             prompt_token_ids=prompt_token_ids,
-            multi_modal_data=mm_inputs["mm_kwargs"],
+            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
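
Tightening the `NamedTuple` field from the optional-items alias to `MultiModalKwargsItems` moves validation to the construction site, which is exactly what the `.require_data()` call provides. The general shape of this pattern:

    from typing import NamedTuple, Optional

    class Record(NamedTuple):
        # Non-optional field: callers must prove presence up front.
        token_ids: list[int]

    def build(token_ids: Optional[list[int]]) -> Record:
        if token_ids is None:
            raise RuntimeError("token_ids missing")
        return Record(token_ids=token_ids)  # mypy: narrowed to list[int]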

View File

@@ -19,6 +19,7 @@ from typing_extensions import deprecated

 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
+from vllm.utils.jsontree import json_map_leaves

 from .audio import AudioMediaIO
 from .base import MediaIO
@@ -383,6 +384,7 @@ def group_mm_kwargs_by_modality(
     *,
     device: torch.types.Device = None,
     pin_memory: bool = False,
+    merge_by_field_config: bool = False,
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.
@@ -400,29 +402,31 @@ def group_mm_kwargs_by_modality(
     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)

-        # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
-        #     .get_data(pin_memory=pin_memory)
-
-        # if device is not None:
-        #     mm_kwargs_group = json_map_leaves(
-        #         lambda x: x.to(device=device),
-        #         mm_kwargs_group,
-        #     )
-
-        # TODO: Once V0 is removed, we can use the merging logic above
+        # TODO: Enable `merge_by_field_config` for all models
         # to avoid creating an extra batch dimension (except for fields
         # that are meant to be stacked anyway).
        # We will also need to update each model to remove `flatten_bn`.
-        mm_kwargs_group = MultiModalKwargs.as_kwargs(
-            MultiModalKwargs.batch(
-                [
-                    MultiModalKwargsItems.from_seq([item]).get_data()
-                    for item in items_lst
-                ],
-                pin_memory=pin_memory,
-            ),
-            device=device,
-        )
+        if merge_by_field_config:
+            mm_kwargs_group: BatchedTensorInputs = dict(
+                MultiModalKwargsItems.from_seq(items_lst).get_data(
+                    pin_memory=pin_memory))
+
+            if device is not None:
+                mm_kwargs_group = json_map_leaves(
+                    lambda x: x.to(device=device),
+                    mm_kwargs_group,
+                )
+        else:
+            mm_kwargs_group = MultiModalKwargs.as_kwargs(
+                MultiModalKwargs.batch(
+                    [
+                        MultiModalKwargsItems.from_seq([item]).get_data()
+                        for item in items_lst
+                    ],
+                    pin_memory=pin_memory,
+                ),
+                device=device,
+            )

         yield modality, len(items_lst), mm_kwargs_group
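
Note that `group_mm_kwargs_by_modality` relies on `itertools.groupby`, which only merges consecutive runs, so items of one modality must already be adjacent in `mm_kwargs`. A runnable sketch of just that grouping contract (toy item type, not vLLM's):

    from dataclasses import dataclass
    from itertools import groupby

    @dataclass
    class Item:
        modality: str
        payload: int

    items = [Item("image", 1), Item("image", 2), Item("audio", 3)]

    # Consecutive runs only: a later "image" item would start a new group.
    for modality, group in groupby(items, key=lambda it: it.modality):
        group_lst = list(group)
        print(modality, len(group_lst))  # image 2, then audio 1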

View File

@@ -4,7 +4,12 @@
 from collections.abc import Iterable
 from functools import reduce
-from typing import Callable, TypeVar, Union, cast, overload
+from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast, overload
+
+if TYPE_CHECKING:
+    import torch
+
+    from vllm.multimodal.inputs import BatchedTensorInputs

 _T = TypeVar("_T")
 _U = TypeVar("_U")
@@ -17,6 +22,19 @@ JSONTree = Union[
 ]
 """A nested JSON structure where the leaves need not be JSON-serializable."""

+_JSONTree = Union[
+    dict[str, "JSONTree[_T]"],
+    list["JSONTree[_T]"],
+    tuple["JSONTree[_T]", ...],
+    dict[str, _T],
+    list[_T],
+    tuple[_T, ...],
+    _T,
+]
+"""
+Same as `JSONTree` but with additional `Union` members to satisfy overloads.
+"""
+

 def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
     """Iterate through each leaf in a nested JSON structure."""
@@ -30,6 +48,14 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
         yield value

+@overload
+def json_map_leaves(
+    func: Callable[["torch.Tensor"], "torch.Tensor"],
+    value: "BatchedTensorInputs",
+) -> "BatchedTensorInputs":
+    ...
+
+
 @overload
 def json_map_leaves(
     func: Callable[[_T], _U],
@@ -64,11 +90,14 @@ def json_map_leaves(
 def json_map_leaves(
     func: Callable[[_T], _U],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
-) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
+    value: Union["BatchedTensorInputs", _JSONTree[_T]],
+) -> Union["BatchedTensorInputs", _JSONTree[_U]]:
     """Apply a function to each leaf in a nested JSON structure."""
     if isinstance(value, dict):
-        return {k: json_map_leaves(func, v) for k, v in value.items()}
+        return {
+            k: json_map_leaves(func, v)  # type: ignore[arg-type]
+            for k, v in value.items()
+        }
     elif isinstance(value, list):
         return [json_map_leaves(func, v) for v in value]
     elif isinstance(value, tuple):
@@ -125,7 +154,7 @@ def json_reduce_leaves(
 def json_reduce_leaves(
     func: Callable[..., Union[_T, _U]],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+    value: _JSONTree[_T],
     initial: _U = cast(_U, ...),  # noqa: B008
     /,
 ) -> Union[_T, _U]:
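
The overload arrangement is the load-bearing piece of this file: a concrete dict-alias overload listed before the generic `_JSONTree` ones lets callers pass a `BatchedTensorInputs` through `json_map_leaves` and get the same alias back, with the `type: ignore[arg-type]` absorbing the recursion's imprecision internally. A self-contained miniature of the same trick, with a `float` dict standing in for `BatchedTensorInputs`:

    from typing import Callable, TypeVar, Union, overload

    _T = TypeVar("_T")
    _U = TypeVar("_U")

    Tree = Union[dict[str, "Tree"], list["Tree"], int]
    FloatDict = dict[str, float]  # stand-in for BatchedTensorInputs

    @overload
    def map_leaves(func: Callable[[float], float],
                   value: FloatDict) -> FloatDict: ...

    @overload
    def map_leaves(func: Callable[[int], int], value: Tree) -> Tree: ...

    def map_leaves(func, value):
        # Structural recursion; anything that is not a container is a leaf.
        if isinstance(value, dict):
            return {k: map_leaves(func, v) for k, v in value.items()}
        if isinstance(value, list):
            return [map_leaves(func, v) for v in value]
        return func(value)

    # The first overload matches, so the alias type survives the mapping.
    doubled: FloatDict = map_leaves(lambda x: x * 2.0, {"a": 1.0})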