commit 686cfd91e3 (parent f17d37b006)

[mypy] Further improve MM type annotations (#25654)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

@@ -415,9 +415,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
                                        num_image_patches),
         )

         # Use overrides if provided; fallback to data-dependent hashing.
-        mm_hashes = (mm_uuids if mm_uuids is not None else self._hash_mm_items(
-            mm_items, hf_processor_mm_kwargs, tokenization_kwargs))
+        mm_hashes = self._hash_mm_items(mm_items,
+                                        hf_processor_mm_kwargs,
+                                        tokenization_kwargs,
+                                        mm_uuids=mm_uuids)

         return MultiModalInputs(
             type="multimodal",

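Note: the UUID override check now lives inside `_hash_mm_items` rather than branching at the call site, so `mm_hashes` is a single well-typed expression. A minimal sketch of the per-item override semantics (hypothetical helper, not vLLM's implementation):

import hashlib
from typing import Optional


def hash_items(items: list[bytes],
               uuid_overrides: Optional[list[Optional[str]]] = None) -> list[str]:
    """Use a caller-supplied UUID when given; otherwise hash the data."""
    hashes = []
    for i, data in enumerate(items):
        override = uuid_overrides[i] if uuid_overrides else None
        if override is not None:
            hashes.append(override)  # trust the caller-provided ID
        else:
            hashes.append(hashlib.sha256(data).hexdigest())
    return hashes


print(hash_items([b"img0", b"img1"], uuid_overrides=[None, "my-uuid"]))
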
@@ -14,7 +14,7 @@ import numpy as np
 from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated

 from vllm.utils import LazyLoader, full_groupby, is_list_of
-from vllm.utils.jsontree import JSONTree, json_map_leaves
+from vllm.utils.jsontree import json_map_leaves

 if TYPE_CHECKING:
     import torch

@@ -203,7 +203,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
     return a == b


-BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
+BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
 """
 A dictionary containing nested tensors which have been batched via
 [`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].

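Note: `Mapping` is a read-only protocol with no `__setitem__`, so code that builds or rebinds a `BatchedTensorInputs` value (as `group_mm_kwargs_by_modality` does later in this commit) only type-checks against the concrete `dict`. A toy illustration with stand-in value types:

from collections.abc import Mapping
from typing import TypeAlias

ReadOnlyBatch: TypeAlias = Mapping[str, int]
MutableBatch: TypeAlias = dict[str, int]


def update_readonly(batch: ReadOnlyBatch) -> None:
    batch["k"] = 1  # rejected by mypy: Mapping is not mutable


def update_mutable(batch: MutableBatch) -> None:
    batch["k"] = 1  # OK
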
@@ -377,6 +377,7 @@ class MultiModalBatchedField(BaseMultiModalField):
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.stack(batch)`

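Note: `is_list_of` verifies the element type at runtime, but mypy cannot see through the helper, so the added `cast` narrows `batch` before it is stacked. The same pattern in a standalone sketch, with `all(isinstance(...))` standing in for `is_list_of`:

from typing import Union, cast

import torch

NestedTensors = Union[list["NestedTensors"], torch.Tensor]


def stack_batch(batch: list[NestedTensors]) -> NestedTensors:
    if len(batch) > 0 and all(isinstance(x, torch.Tensor) for x in batch):
        # The runtime check does not narrow the type for mypy,
        # so assert the element type explicitly before stacking.
        tensors = cast(list[torch.Tensor], batch)
        return torch.stack(tensors)
    return batch


print(stack_batch([torch.zeros(2), torch.ones(2)]))
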
@@ -422,6 +423,7 @@ class MultiModalFlatField(BaseMultiModalField):
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.concat(batch)`

@@ -764,6 +766,15 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):

         return super().__getitem__(modality)  # type: ignore[return-value]

+    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
+        for modality, items in self.items():
+            for i, item in enumerate(items):
+                if item is None:
+                    raise RuntimeError(
+                        f"Found empty mm_items[{modality}][{i}]")
+
+        return self  # type: ignore[return-value]
+
     def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
         elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
         for modality, items in self.items():

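Note: the new `require_data` fails fast on missing items while narrowing the container's item type from optional to concrete, which is why the `return self` needs a `type: ignore`. The same narrowing trick on a toy container (hypothetical names):

from typing import Generic, Optional, TypeVar, cast

_T = TypeVar("_T")


class Items(Generic[_T]):
    """Toy stand-in for MultiModalKwargsItems."""

    def __init__(self, items: list[Optional[_T]]) -> None:
        self.items = items

    def require_data(self) -> list[_T]:
        # A runtime check that every slot is filled lets us tell
        # mypy that the Optional has been discharged.
        for i, item in enumerate(self.items):
            if item is None:
                raise RuntimeError(f"Found empty item[{i}]")
        return cast(list[_T], self.items)


print(Items([1, 2, 3]).require_data())  # [1, 2, 3]
# Items([1, None]).require_data()       # raises RuntimeError
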
@@ -897,15 +908,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         *,
         device: torch.types.Device,
     ) -> BatchedTensorInputs:
-        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
-
-        json_mapped = json_map_leaves(
+        return json_map_leaves(
             lambda x: x.to(device=device, non_blocking=True),
-            json_inputs,
+            batched_inputs,
         )

-        return cast(BatchedTensorInputs, json_mapped)
-
     def __getitem__(self, key: str):
         if key not in self:
             raise KeyError(f"Keyword argument {key!r} not found. "

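Note: both `cast`s can go because `json_map_leaves` gains a `BatchedTensorInputs`-preserving overload later in this commit. What the mapping itself does, as a self-contained sketch:

from typing import Union

import torch

Nested = Union[list["Nested"], torch.Tensor]


def map_leaves(fn, value):
    # Recurse through dicts and lists, applying fn at tensor leaves.
    if isinstance(value, dict):
        return {k: map_leaves(fn, v) for k, v in value.items()}
    if isinstance(value, list):
        return [map_leaves(fn, v) for v in value]
    return fn(value)


batch: dict[str, Nested] = {
    "pixel_values": [torch.zeros(3, 2), torch.zeros(3, 4)],
    "image_grid_thw": torch.ones(2, 3),
}
moved = map_leaves(lambda t: t.to(device="cpu", non_blocking=True), batch)
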
@@ -1585,7 +1585,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         *,
         mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> MultiModalHashes:
-        """Create MM hashes to be returned (only used in V1).
+        """Create MM hashes to be returned.


         Note: When overrides are provided via callers of `apply`,

@@ -2098,23 +2098,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         encoder_inputs: MultiModalInputs,
     ):
         tokenizer = self.info.get_tokenizer()
-        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
-        if isinstance(decoder_prompt, str):
+        decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data)
+        if isinstance(decoder_prompt_raw, str):
+            decoder_prompt = decoder_prompt_raw
             decoder_prompt_ids = encode_tokens(tokenizer,
-                                               decoder_prompt,
+                                               decoder_prompt_raw,
                                                add_special_tokens=False)
         else:
-            decoder_prompt_ids = decoder_prompt
-            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
+            decoder_prompt = decode_tokens(tokenizer, decoder_prompt_raw)
+            decoder_prompt_ids = decoder_prompt_raw

         mm_inputs = MultiModalEncDecInputs(
             encoder_prompt=encoder_inputs["prompt"],
             encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
             **encoder_inputs)
-        mm_inputs.update({
-            "prompt": decoder_prompt,
-            "prompt_token_ids": decoder_prompt_ids
-        })
+        mm_inputs["prompt"] = decoder_prompt
+        mm_inputs["prompt_token_ids"] = decoder_prompt_ids
         return mm_inputs

     def apply(

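Note: `create_decoder_prompt` returns `Union[str, list[int]]`; reusing one name for the raw value, the text prompt, and the token ids defeats mypy's narrowing, so each role now gets its own variable. The pattern, distilled (toy stand-ins for encode_tokens/decode_tokens):

from typing import Union


def split_prompt(raw: Union[str, list[int]]) -> tuple[str, list[int]]:
    """Each name keeps a single type, so mypy narrows cleanly."""
    if isinstance(raw, str):
        prompt = raw
        prompt_ids = [ord(c) for c in raw]  # stand-in for encode_tokens
    else:
        prompt = "".join(chr(t) for t in raw)  # stand-in for decode_tokens
        prompt_ids = raw
    return prompt, prompt_ids


assert split_prompt("ab") == ("ab", [97, 98])
assert split_prompt([97, 98]) == ("ab", [97, 98])
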
@@ -13,7 +13,7 @@ import vllm.envs as envs
 from vllm.logger import init_logger

 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
-                     MultiModalInputs, MultiModalKwargsOptionalItems,
+                     MultiModalInputs, MultiModalKwargsItems,
                      MultiModalPlaceholderDict)
 from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                          EncDecMultiModalProcessor)

@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
     """Dummy data used for profiling."""

     prompt_token_ids: list[int]
-    multi_modal_data: MultiModalKwargsOptionalItems
+    multi_modal_data: MultiModalKwargsItems
     multi_modal_placeholders: MultiModalPlaceholderDict


@@ -239,7 +239,7 @@ class MultiModalProfiler(Generic[_I]):

         return DummyDecoderData(
             prompt_token_ids=prompt_token_ids,
-            multi_modal_data=mm_inputs["mm_kwargs"],
+            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )

@@ -19,6 +19,7 @@ from typing_extensions import deprecated

 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
+from vllm.utils.jsontree import json_map_leaves

 from .audio import AudioMediaIO
 from .base import MediaIO

@@ -383,6 +384,7 @@ def group_mm_kwargs_by_modality(
     *,
     device: torch.types.Device = None,
     pin_memory: bool = False,
+    merge_by_field_config: bool = False,
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.

@@ -400,29 +402,31 @@ def group_mm_kwargs_by_modality(
     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)

-        # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
-        #     .get_data(pin_memory=pin_memory)
-
-        # if device is not None:
-        #     mm_kwargs_group = json_map_leaves(
-        #         lambda x: x.to(device=device),
-        #         mm_kwargs_group,
-        #     )
-
-        # TODO: Once V0 is removed, we can use the merging logic above
+        # TODO: Enable `merge_by_field_config` for all models
         # to avoid creating an extra batch dimension (except for fields
         # that are meant to be stacked anyway).
         # We will also need to update each model to remove `flatten_bn`.
-        mm_kwargs_group = MultiModalKwargs.as_kwargs(
-            MultiModalKwargs.batch(
-                [
-                    MultiModalKwargsItems.from_seq([item]).get_data()
-                    for item in items_lst
-                ],
-                pin_memory=pin_memory,
-            ),
-            device=device,
-        )
+        if merge_by_field_config:
+            mm_kwargs_group: BatchedTensorInputs = dict(
+                MultiModalKwargsItems.from_seq(items_lst).get_data(
+                    pin_memory=pin_memory))
+
+            if device is not None:
+                mm_kwargs_group = json_map_leaves(
+                    lambda x: x.to(device=device),
+                    mm_kwargs_group,
+                )
+        else:
+            mm_kwargs_group = MultiModalKwargs.as_kwargs(
+                MultiModalKwargs.batch(
+                    [
+                        MultiModalKwargsItems.from_seq([item]).get_data()
+                        for item in items_lst
+                    ],
+                    pin_memory=pin_memory,
+                ),
+                device=device,
+            )

         yield modality, len(items_lst), mm_kwargs_group

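Note: the previously commented-out merging path is now live behind the new `merge_by_field_config` flag, with the legacy per-item batching kept as the default. In both branches, grouping is over consecutive items of one modality, since `itertools.groupby` only merges adjacent runs:

from itertools import groupby

# Toy (modality, payload) items standing in for MultiModalKwargsItem.
items = [("image", 1), ("image", 2), ("audio", 3), ("image", 4)]

for modality, group in groupby(items, key=lambda item: item[0]):
    group_lst = list(group)
    print(modality, len(group_lst))
# image 2 / audio 1 / image 1 -- the trailing image starts a new group.
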
@@ -4,7 +4,12 @@

 from collections.abc import Iterable
 from functools import reduce
-from typing import Callable, TypeVar, Union, cast, overload
+from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast, overload
+
+if TYPE_CHECKING:
+    import torch
+
+    from vllm.multimodal.inputs import BatchedTensorInputs

 _T = TypeVar("_T")
 _U = TypeVar("_U")

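Note: `vllm.multimodal.inputs` itself imports from `vllm.utils.jsontree` (see the earlier hunk), so importing `BatchedTensorInputs` here at runtime would be circular; guarding the import under `TYPE_CHECKING` and quoting the annotations avoids that. The pattern in isolation:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for annotations only; never executed at runtime, so the
    # circular (or heavy) dependency is avoided.
    import torch


def first_dim(t: "torch.Tensor") -> int:
    # Quoted annotation: "torch" is not defined at runtime here.
    return t.shape[0]
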
@@ -17,6 +22,19 @@ JSONTree = Union[
 ]
 """A nested JSON structure where the leaves need not be JSON-serializable."""

+_JSONTree = Union[
+    dict[str, "JSONTree[_T]"],
+    list["JSONTree[_T]"],
+    tuple["JSONTree[_T]", ...],
+    dict[str, _T],
+    list[_T],
+    tuple[_T, ...],
+    _T,
+]
+"""
+Same as `JSONTree` but with additional `Union` members to satisfy overloads.
+"""
+

 def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
     """Iterate through each leaf in a nested JSON structure."""

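Note (my reading of the new alias): when a parameter is annotated only with the recursive `JSONTree[_T]`, mypy may bind `_T` to the whole container rather than its leaves, because `_T` is itself a union member; the extra non-recursive `dict[str, _T]` / `list[_T]` / `tuple[_T, ...]` shapes let the overloads pin `_T` to the leaf type. Sketch:

from typing import TypeVar, Union

_T = TypeVar("_T")

JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"],
                 tuple["JSONTree[_T]", ...], _T]
_JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"],
                  tuple["JSONTree[_T]", ...], dict[str, _T], list[_T],
                  tuple[_T, ...], _T]


def double_leaves(value: _JSONTree[int]) -> None:
    pass


double_leaves({"a": [1, 2], "b": (3,)})  # _T binds to int at the leaves
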
@@ -30,6 +48,14 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
         yield value


+@overload
+def json_map_leaves(
+    func: Callable[["torch.Tensor"], "torch.Tensor"],
+    value: "BatchedTensorInputs",
+) -> "BatchedTensorInputs":
+    ...
+
+
 @overload
 def json_map_leaves(
     func: Callable[[_T], _U],

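Note: this overload is what lets `MultiModalKwargs.as_kwargs` (earlier in this commit) drop its casts: passing `BatchedTensorInputs` returns `BatchedTensorInputs`, no `cast` required. A consumer sketch using the names this diff introduces:

import torch

from vllm.multimodal.inputs import BatchedTensorInputs
from vllm.utils.jsontree import json_map_leaves


def to_device(batched_inputs: BatchedTensorInputs,
              device: torch.types.Device) -> BatchedTensorInputs:
    # No cast needed: the torch.Tensor overload preserves the alias.
    return json_map_leaves(
        lambda x: x.to(device=device, non_blocking=True),
        batched_inputs,
    )
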
@@ -64,11 +90,14 @@ def json_map_leaves(

 def json_map_leaves(
     func: Callable[[_T], _U],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
-) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
+    value: Union["BatchedTensorInputs", _JSONTree[_T]],
+) -> Union["BatchedTensorInputs", _JSONTree[_U]]:
     """Apply a function to each leaf in a nested JSON structure."""
     if isinstance(value, dict):
-        return {k: json_map_leaves(func, v) for k, v in value.items()}
+        return {
+            k: json_map_leaves(func, v)  # type: ignore[arg-type]
+            for k, v in value.items()
+        }
     elif isinstance(value, list):
         return [json_map_leaves(func, v) for v in value]
     elif isinstance(value, tuple):

@@ -125,7 +154,7 @@ def json_reduce_leaves(

 def json_reduce_leaves(
     func: Callable[..., Union[_T, _U]],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+    value: _JSONTree[_T],
     initial: _U = cast(_U, ...),  # noqa: B008
     /,
 ) -> Union[_T, _U]: