[mypy] Further improve MM type annotations (#25654)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Cyrus Leung 2025-09-25 18:57:36 +08:00 committed by yewentao256
parent f17d37b006
commit 686cfd91e3
6 changed files with 90 additions and 48 deletions

View File

@@ -415,9 +415,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
                                        num_image_patches),
         )
-        # Use overrides if provided; fallback to data-dependent hashing.
-        mm_hashes = (mm_uuids if mm_uuids is not None else self._hash_mm_items(
-            mm_items, hf_processor_mm_kwargs, tokenization_kwargs))
+        mm_hashes = self._hash_mm_items(mm_items,
+                                        hf_processor_mm_kwargs,
+                                        tokenization_kwargs,
+                                        mm_uuids=mm_uuids)

         return MultiModalInputs(
             type="multimodal",

View File

@@ -14,7 +14,7 @@ import numpy as np
 from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated

 from vllm.utils import LazyLoader, full_groupby, is_list_of
-from vllm.utils.jsontree import JSONTree, json_map_leaves
+from vllm.utils.jsontree import json_map_leaves

 if TYPE_CHECKING:
     import torch
@@ -203,7 +203,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
     return a == b


-BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
+BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
 """
 A dictionary containing nested tensors which have been batched via
 [`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
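
Narrowing the alias from `Mapping` to `dict` is what lets the `json_map_leaves` overload added later in this commit return `BatchedTensorInputs` directly: mapping over a dict produces a dict, so the result satisfies the alias without a cast. A toy version of the idea, with illustrative names:

from typing import TypeAlias

Batched: TypeAlias = dict[str, int]  # stand-in for dict[str, NestedTensors]

def scale(inputs: Batched) -> Batched:
    # A dict comprehension is a `dict`, so it satisfies the dict-typed
    # alias as-is; overloads can now promise dict-in, dict-out.
    return {k: v * 2 for k, v in inputs.items()}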
@@ -377,6 +377,7 @@ class MultiModalBatchedField(BaseMultiModalField):
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.stack(batch)`
@@ -422,6 +423,7 @@ class MultiModalFlatField(BaseMultiModalField):
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.concat(batch)`
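
`is_list_of` verifies the element type at runtime, but mypy cannot narrow `batch` through it, hence the explicit `cast` before the tensor-only branches. The same pattern in isolation, using a plain `all(isinstance(...))` check in place of `vllm.utils.is_list_of`:

from typing import Union, cast

import torch

NestedTensors = Union[list["NestedTensors"], torch.Tensor]

def stack_tensors(batch: list[NestedTensors]) -> NestedTensors:
    if len(batch) > 0 and all(isinstance(x, torch.Tensor) for x in batch):
        # The runtime check above is invisible to mypy, so narrow explicitly.
        tensors = cast(list[torch.Tensor], batch)
        if len(tensors) == 1:
            return tensors[0].unsqueeze(0)  # same result as torch.stack
        return torch.stack(tensors)
    return batch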
@@ -764,6 +766,15 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
         return super().__getitem__(modality)  # type: ignore[return-value]

+    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
+        for modality, items in self.items():
+            for i, item in enumerate(items):
+                if item is None:
+                    raise RuntimeError(
+                        f"Found empty mm_items[{modality}][{i}]")
+
+        return self  # type: ignore[return-value]
+
     def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
         elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
         for modality, items in self.items():
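
`require_data` gives callers a typed guarantee that no item is `None`, converting `MultiModalKwargsItems[Optional[MultiModalKwargsItem]]` into the non-optional variant or raising immediately. A generic sketch of the same narrowing idiom; the `Items` class here is hypothetical:

from typing import Generic, Optional, TypeVar

T = TypeVar("T")

class Items(Generic[T]):
    def __init__(self, data: dict[str, list[T]]) -> None:
        self.data = data

    def require_data(self: "Items[Optional[T]]") -> "Items[T]":
        # Fail fast so a None never flows into downstream tensor code.
        for modality, items in self.data.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty items[{modality}][{i}]")
        return self  # type: ignore[return-value]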
@@ -897,15 +908,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         *,
         device: torch.types.Device,
     ) -> BatchedTensorInputs:
-        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
-
-        json_mapped = json_map_leaves(
+        return json_map_leaves(
             lambda x: x.to(device=device, non_blocking=True),
-            json_inputs,
+            batched_inputs,
         )
-
-        return cast(BatchedTensorInputs, json_mapped)

     def __getitem__(self, key: str):
         if key not in self:
             raise KeyError(f"Keyword argument {key!r} not found. "
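
With `BatchedTensorInputs` defined as a `dict` and the tensor-specific `json_map_leaves` overload (added in `jsontree` below), `as_kwargs` no longer needs the round-trip through `cast`. Usage is unchanged; a sketch assuming a CUDA device is available:

import torch

from vllm.multimodal.inputs import MultiModalKwargs

batched = {"pixel_values": torch.zeros(2, 3, 224, 224)}
moved = MultiModalKwargs.as_kwargs(batched, device="cuda:0")
# `moved` is already typed as BatchedTensorInputs; no cast required.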

View File

@@ -1585,7 +1585,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         *,
         mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> MultiModalHashes:
-        """Create MM hashes to be returned (only used in V1).
+        """Create MM hashes to be returned.

         Note: When overrides are provided via callers of `apply`,
@@ -2098,23 +2098,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         encoder_inputs: MultiModalInputs,
     ):
         tokenizer = self.info.get_tokenizer()
-        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
-        if isinstance(decoder_prompt, str):
+        decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data)
+        if isinstance(decoder_prompt_raw, str):
+            decoder_prompt = decoder_prompt_raw
             decoder_prompt_ids = encode_tokens(tokenizer,
-                                               decoder_prompt,
+                                               decoder_prompt_raw,
                                                add_special_tokens=False)
         else:
-            decoder_prompt_ids = decoder_prompt
-            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
+            decoder_prompt = decode_tokens(tokenizer, decoder_prompt_raw)
+            decoder_prompt_ids = decoder_prompt_raw

         mm_inputs = MultiModalEncDecInputs(
             encoder_prompt=encoder_inputs["prompt"],
             encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
             **encoder_inputs)
-        mm_inputs.update({
-            "prompt": decoder_prompt,
-            "prompt_token_ids": decoder_prompt_ids
-        })
+        mm_inputs["prompt"] = decoder_prompt
+        mm_inputs["prompt_token_ids"] = decoder_prompt_ids
         return mm_inputs

     def apply(
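
The rename to `decoder_prompt_raw` exists because reassigning one variable between `str` and `list[int]` defeats mypy's narrowing; each branch now derives both the text and the token IDs from the untouched raw value. The control flow in isolation, with toy stand-ins for `encode_tokens`/`decode_tokens`:

from typing import Union

def normalize_prompt(raw: Union[str, list[int]]) -> tuple[str, list[int]]:
    # Stand-ins for vLLM's encode_tokens/decode_tokens helpers.
    def encode(text: str) -> list[int]:
        return [ord(c) for c in text]

    def decode(ids: list[int]) -> str:
        return "".join(map(chr, ids))

    if isinstance(raw, str):
        return raw, encode(raw)
    return decode(raw), raw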

View File

@@ -13,7 +13,7 @@ import vllm.envs as envs
 from vllm.logger import init_logger

 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
-                     MultiModalInputs, MultiModalKwargsOptionalItems,
+                     MultiModalInputs, MultiModalKwargsItems,
                      MultiModalPlaceholderDict)
 from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                          EncDecMultiModalProcessor)
@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
     """Dummy data used for profiling."""

     prompt_token_ids: list[int]
-    multi_modal_data: MultiModalKwargsOptionalItems
+    multi_modal_data: MultiModalKwargsItems
     multi_modal_placeholders: MultiModalPlaceholderDict
@@ -239,7 +239,7 @@ class MultiModalProfiler(Generic[_I]):

         return DummyDecoderData(
             prompt_token_ids=prompt_token_ids,
-            multi_modal_data=mm_inputs["mm_kwargs"],
+            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
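
Since `DummyDecoderData.multi_modal_data` is now the non-optional `MultiModalKwargsItems`, the profiler validates completeness up front; an item that is `None` fails here rather than deep inside model code. A hypothetical failure case, assuming `MultiModalKwargsItems` can be built from a plain dict as a `UserDict` subclass:

from vllm.multimodal.inputs import MultiModalKwargsItems

items = MultiModalKwargsItems({"image": [None]})
items.require_data()  # RuntimeError: Found empty mm_items[image][0]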

View File

@@ -19,6 +19,7 @@ from typing_extensions import deprecated

 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
+from vllm.utils.jsontree import json_map_leaves

 from .audio import AudioMediaIO
 from .base import MediaIO
@@ -383,6 +384,7 @@ def group_mm_kwargs_by_modality(
     *,
     device: torch.types.Device = None,
     pin_memory: bool = False,
+    merge_by_field_config: bool = False,
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.
@@ -400,29 +402,31 @@ def group_mm_kwargs_by_modality(
     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)

-        # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
-        #     .get_data(pin_memory=pin_memory)
-
-        # if device is not None:
-        #     mm_kwargs_group = json_map_leaves(
-        #         lambda x: x.to(device=device),
-        #         mm_kwargs_group,
-        #     )
-
-        # TODO: Once V0 is removed, we can use the merging logic above
+        # TODO: Enable `merge_by_field_config` for all models
         # to avoid creating an extra batch dimension (except for fields
         # that are meant to be stacked anyway).
         # We will also need to update each model to remove `flatten_bn`.
-        mm_kwargs_group = MultiModalKwargs.as_kwargs(
-            MultiModalKwargs.batch(
-                [
-                    MultiModalKwargsItems.from_seq([item]).get_data()
-                    for item in items_lst
-                ],
-                pin_memory=pin_memory,
-            ),
-            device=device,
-        )
+        if merge_by_field_config:
+            mm_kwargs_group: BatchedTensorInputs = dict(
+                MultiModalKwargsItems.from_seq(items_lst).get_data(
+                    pin_memory=pin_memory))
+
+            if device is not None:
+                mm_kwargs_group = json_map_leaves(
+                    lambda x: x.to(device=device),
+                    mm_kwargs_group,
+                )
+        else:
+            mm_kwargs_group = MultiModalKwargs.as_kwargs(
+                MultiModalKwargs.batch(
+                    [
+                        MultiModalKwargsItems.from_seq([item]).get_data()
+                        for item in items_lst
+                    ],
+                    pin_memory=pin_memory,
+                ),
+                device=device,
+            )

         yield modality, len(items_lst), mm_kwargs_group
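
Models opt in per call site: with `merge_by_field_config=True`, items are merged according to their field configs (no extra leading batch dimension), while the default keeps the legacy per-item batching path. A usage sketch, assuming `items` is a list of `MultiModalKwargsItem` and a CUDA device exists:

from vllm.multimodal.utils import group_mm_kwargs_by_modality

for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
        items,
        device="cuda:0",
        pin_memory=True,
        merge_by_field_config=True,
):
    print(modality, num_items, sorted(mm_kwargs_group))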

View File

@@ -4,7 +4,12 @@

 from collections.abc import Iterable
 from functools import reduce
-from typing import Callable, TypeVar, Union, cast, overload
+from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast, overload
+
+if TYPE_CHECKING:
+    import torch
+
+    from vllm.multimodal.inputs import BatchedTensorInputs

 _T = TypeVar("_T")
 _U = TypeVar("_U")
@@ -17,6 +22,19 @@ JSONTree = Union[
 ]
 """A nested JSON structure where the leaves need not be JSON-serializable."""

+_JSONTree = Union[
+    dict[str, "JSONTree[_T]"],
+    list["JSONTree[_T]"],
+    tuple["JSONTree[_T]", ...],
+    dict[str, _T],
+    list[_T],
+    tuple[_T, ...],
+    _T,
+]
+"""
+Same as `JSONTree` but with additional `Union` members to satisfy overloads.
+"""
+

 def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
     """Iterate through each leaf in a nested JSON structure."""
@@ -30,6 +48,14 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
         yield value


+@overload
+def json_map_leaves(
+    func: Callable[["torch.Tensor"], "torch.Tensor"],
+    value: "BatchedTensorInputs",
+) -> "BatchedTensorInputs":
+    ...
+
+
 @overload
 def json_map_leaves(
     func: Callable[[_T], _U],
@@ -64,11 +90,14 @@ def json_map_leaves(
 def json_map_leaves(
     func: Callable[[_T], _U],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
-) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
+    value: Union["BatchedTensorInputs", _JSONTree[_T]],
+) -> Union["BatchedTensorInputs", _JSONTree[_U]]:
     """Apply a function to each leaf in a nested JSON structure."""
     if isinstance(value, dict):
-        return {k: json_map_leaves(func, v) for k, v in value.items()}
+        return {
+            k: json_map_leaves(func, v)  # type: ignore[arg-type]
+            for k, v in value.items()
+        }
     elif isinstance(value, list):
         return [json_map_leaves(func, v) for v in value]
     elif isinstance(value, tuple):
@@ -125,7 +154,7 @@ def json_reduce_leaves(
 def json_reduce_leaves(
     func: Callable[..., Union[_T, _U]],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+    value: _JSONTree[_T],
     initial: _U = cast(_U, ...),  # noqa: B008
     /,
 ) -> Union[_T, _U]:
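
The new first overload is what the rest of the commit leans on: a `torch.Tensor -> torch.Tensor` function applied to `BatchedTensorInputs` yields `BatchedTensorInputs` again, so callers like `MultiModalKwargs.as_kwargs` keep their return type without casting. A short demonstration:

import torch

from vllm.utils.jsontree import json_map_leaves

batched = {"pixel_values": [torch.zeros(3), torch.ones(3)]}
doubled = json_map_leaves(lambda t: t * 2, batched)
# Still a dict with the same nesting; mypy sees BatchedTensorInputs,
# not a generic JSONTree, so no cast is needed at the call site.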