mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 05:55:01 +08:00
[mypy] Further improve MM type annotations (#25654)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
f17d37b006
commit
686cfd91e3
@ -415,9 +415,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
|
||||
num_image_patches),
|
||||
)
|
||||
|
||||
# Use overrides if provided; fallback to data-dependent hashing.
|
||||
mm_hashes = (mm_uuids if mm_uuids is not None else self._hash_mm_items(
|
||||
mm_items, hf_processor_mm_kwargs, tokenization_kwargs))
|
||||
mm_hashes = self._hash_mm_items(mm_items,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_uuids=mm_uuids)
|
||||
|
||||
return MultiModalInputs(
|
||||
type="multimodal",
|
||||
|
||||
@ -14,7 +14,7 @@ import numpy as np
|
||||
from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
|
||||
|
||||
from vllm.utils import LazyLoader, full_groupby, is_list_of
|
||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
@ -203,7 +203,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
|
||||
return a == b
|
||||
|
||||
|
||||
BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
|
||||
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
|
||||
"""
|
||||
A dictionary containing nested tensors which have been batched via
|
||||
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
|
||||
@ -377,6 +377,7 @@ class MultiModalBatchedField(BaseMultiModalField):
|
||||
pin_memory: bool,
|
||||
) -> NestedTensors:
|
||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||
batch = cast(list[torch.Tensor], batch)
|
||||
if len(batch) == 1:
|
||||
# An optimization when `batch` contains only one tensor:
|
||||
# - produce exactly same result as `torch.stack(batch)`
|
||||
@ -422,6 +423,7 @@ class MultiModalFlatField(BaseMultiModalField):
|
||||
pin_memory: bool,
|
||||
) -> NestedTensors:
|
||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||
batch = cast(list[torch.Tensor], batch)
|
||||
if len(batch) == 1:
|
||||
# An optimization when `batch` contains only one tensor:
|
||||
# - produce exactly same result as `torch.concat(batch)`
|
||||
@ -764,6 +766,15 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
|
||||
|
||||
return super().__getitem__(modality) # type: ignore[return-value]
|
||||
|
||||
def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
|
||||
for modality, items in self.items():
|
||||
for i, item in enumerate(items):
|
||||
if item is None:
|
||||
raise RuntimeError(
|
||||
f"Found empty mm_items[{modality}][{i}]")
|
||||
|
||||
return self # type: ignore[return-value]
|
||||
|
||||
def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
|
||||
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
||||
for modality, items in self.items():
|
||||
@ -897,15 +908,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
*,
|
||||
device: torch.types.Device,
|
||||
) -> BatchedTensorInputs:
|
||||
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
||||
|
||||
json_mapped = json_map_leaves(
|
||||
return json_map_leaves(
|
||||
lambda x: x.to(device=device, non_blocking=True),
|
||||
json_inputs,
|
||||
batched_inputs,
|
||||
)
|
||||
|
||||
return cast(BatchedTensorInputs, json_mapped)
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
if key not in self:
|
||||
raise KeyError(f"Keyword argument {key!r} not found. "
|
||||
|
||||
@ -1585,7 +1585,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
*,
|
||||
mm_uuids: Optional[MultiModalUUIDDict] = None,
|
||||
) -> MultiModalHashes:
|
||||
"""Create MM hashes to be returned (only used in V1).
|
||||
"""Create MM hashes to be returned.
|
||||
|
||||
|
||||
Note: When overrides are provided via callers of `apply`,
|
||||
@ -2098,23 +2098,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
encoder_inputs: MultiModalInputs,
|
||||
):
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
|
||||
if isinstance(decoder_prompt, str):
|
||||
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data)
|
||||
if isinstance(decoder_prompt_raw, str):
|
||||
decoder_prompt = decoder_prompt_raw
|
||||
decoder_prompt_ids = encode_tokens(tokenizer,
|
||||
decoder_prompt,
|
||||
decoder_prompt_raw,
|
||||
add_special_tokens=False)
|
||||
else:
|
||||
decoder_prompt_ids = decoder_prompt
|
||||
decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
|
||||
decoder_prompt = decode_tokens(tokenizer, decoder_prompt_raw)
|
||||
decoder_prompt_ids = decoder_prompt_raw
|
||||
|
||||
mm_inputs = MultiModalEncDecInputs(
|
||||
encoder_prompt=encoder_inputs["prompt"],
|
||||
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
|
||||
**encoder_inputs)
|
||||
mm_inputs.update({
|
||||
"prompt": decoder_prompt,
|
||||
"prompt_token_ids": decoder_prompt_ids
|
||||
})
|
||||
mm_inputs["prompt"] = decoder_prompt
|
||||
mm_inputs["prompt_token_ids"] = decoder_prompt_ids
|
||||
return mm_inputs
|
||||
|
||||
def apply(
|
||||
|
||||
@ -13,7 +13,7 @@ import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
MultiModalInputs, MultiModalKwargsOptionalItems,
|
||||
MultiModalInputs, MultiModalKwargsItems,
|
||||
MultiModalPlaceholderDict)
|
||||
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
|
||||
EncDecMultiModalProcessor)
|
||||
@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
|
||||
"""Dummy data used for profiling."""
|
||||
|
||||
prompt_token_ids: list[int]
|
||||
multi_modal_data: MultiModalKwargsOptionalItems
|
||||
multi_modal_data: MultiModalKwargsItems
|
||||
multi_modal_placeholders: MultiModalPlaceholderDict
|
||||
|
||||
|
||||
@ -239,7 +239,7 @@ class MultiModalProfiler(Generic[_I]):
|
||||
|
||||
return DummyDecoderData(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=mm_inputs["mm_kwargs"],
|
||||
multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
)
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
|
||||
from .audio import AudioMediaIO
|
||||
from .base import MediaIO
|
||||
@ -383,6 +384,7 @@ def group_mm_kwargs_by_modality(
|
||||
*,
|
||||
device: torch.types.Device = None,
|
||||
pin_memory: bool = False,
|
||||
merge_by_field_config: bool = False,
|
||||
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
|
||||
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
|
||||
modality together into the same `MultiModalKwargs` instance.
|
||||
@ -400,19 +402,21 @@ def group_mm_kwargs_by_modality(
|
||||
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
|
||||
items_lst = list(items)
|
||||
|
||||
# mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
|
||||
# .get_data(pin_memory=pin_memory)
|
||||
|
||||
# if device is not None:
|
||||
# mm_kwargs_group = json_map_leaves(
|
||||
# lambda x: x.to(device=device),
|
||||
# mm_kwargs_group,
|
||||
# )
|
||||
|
||||
# TODO: Once V0 is removed, we can use the merging logic above
|
||||
# TODO: Enable `merge_by_field_config` for all models
|
||||
# to avoid creating an extra batch dimension (except for fields
|
||||
# that are meant to be stacked anyway).
|
||||
# We will also need to update each model to remove `flatten_bn`.
|
||||
if merge_by_field_config:
|
||||
mm_kwargs_group: BatchedTensorInputs = dict(
|
||||
MultiModalKwargsItems.from_seq(items_lst).get_data(
|
||||
pin_memory=pin_memory))
|
||||
|
||||
if device is not None:
|
||||
mm_kwargs_group = json_map_leaves(
|
||||
lambda x: x.to(device=device),
|
||||
mm_kwargs_group,
|
||||
)
|
||||
else:
|
||||
mm_kwargs_group = MultiModalKwargs.as_kwargs(
|
||||
MultiModalKwargs.batch(
|
||||
[
|
||||
|
||||
@ -4,7 +4,12 @@
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import reduce
|
||||
from typing import Callable, TypeVar, Union, cast, overload
|
||||
from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast, overload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import BatchedTensorInputs
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_U = TypeVar("_U")
|
||||
@ -17,6 +22,19 @@ JSONTree = Union[
|
||||
]
|
||||
"""A nested JSON structure where the leaves need not be JSON-serializable."""
|
||||
|
||||
_JSONTree = Union[
|
||||
dict[str, "JSONTree[_T]"],
|
||||
list["JSONTree[_T]"],
|
||||
tuple["JSONTree[_T]", ...],
|
||||
dict[str, _T],
|
||||
list[_T],
|
||||
tuple[_T, ...],
|
||||
_T,
|
||||
]
|
||||
"""
|
||||
Same as `JSONTree` but with additional `Union` members to satisfy overloads.
|
||||
"""
|
||||
|
||||
|
||||
def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
|
||||
"""Iterate through each leaf in a nested JSON structure."""
|
||||
@ -30,6 +48,14 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
|
||||
yield value
|
||||
|
||||
|
||||
@overload
|
||||
def json_map_leaves(
|
||||
func: Callable[["torch.Tensor"], "torch.Tensor"],
|
||||
value: "BatchedTensorInputs",
|
||||
) -> "BatchedTensorInputs":
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def json_map_leaves(
|
||||
func: Callable[[_T], _U],
|
||||
@ -64,11 +90,14 @@ def json_map_leaves(
|
||||
|
||||
def json_map_leaves(
|
||||
func: Callable[[_T], _U],
|
||||
value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
|
||||
) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
|
||||
value: Union["BatchedTensorInputs", _JSONTree[_T]],
|
||||
) -> Union["BatchedTensorInputs", _JSONTree[_U]]:
|
||||
"""Apply a function to each leaf in a nested JSON structure."""
|
||||
if isinstance(value, dict):
|
||||
return {k: json_map_leaves(func, v) for k, v in value.items()}
|
||||
return {
|
||||
k: json_map_leaves(func, v) # type: ignore[arg-type]
|
||||
for k, v in value.items()
|
||||
}
|
||||
elif isinstance(value, list):
|
||||
return [json_map_leaves(func, v) for v in value]
|
||||
elif isinstance(value, tuple):
|
||||
@ -125,7 +154,7 @@ def json_reduce_leaves(
|
||||
|
||||
def json_reduce_leaves(
|
||||
func: Callable[..., Union[_T, _U]],
|
||||
value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
|
||||
value: _JSONTree[_T],
|
||||
initial: _U = cast(_U, ...), # noqa: B008
|
||||
/,
|
||||
) -> Union[_T, _U]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user