diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
index 0f05f9b4efcd6..6fd8c2fb5c561 100644
--- a/vllm/model_executor/models/aya_vision.py
+++ b/vllm/model_executor/models/aya_vision.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Literal, Optional, Union, cast
+from typing import Annotated, Literal, Optional, Union

 import torch
 from torch import nn
@@ -347,12 +347,16 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

-    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
-                                  pixel_values: torch.Tensor,
-                                  **kwargs) -> torch.Tensor:
-        target_dtype = vision_tower.get_input_embeddings().weight.dtype
-        image_features = vision_tower(pixel_values.to(dtype=target_dtype),
-                                      **kwargs)
+    def _image_pixels_to_features(
+        self,
+        vision_tower: SiglipVisionModel,
+        pixel_values: torch.Tensor,
+        **kwargs,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+        target_dtype: torch.dtype = \
+            vision_tower.get_input_embeddings().weight.dtype
+        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
+            vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)

         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
@@ -360,10 +364,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                 strategy=self.config.vision_feature_select_strategy,
             )

-        return cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features, image_features),
-        )
+        return json_map_leaves(select_features, image_features)

     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 8d7feb965e76c..4d8ed95b6cc8f 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -4,7 +4,7 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
-                    Union, cast)
+                    Union)

 import torch
 import torch.nn as nn
@@ -623,7 +623,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
+        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
+            vision_tower(pixel_values)

         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
@@ -631,10 +632,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 strategy=self.config.vision_feature_select_strategy,
             )

-        return cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features, image_features),
-        )
+        return json_map_leaves(select_features, image_features)

     def _process_image_pixels(
         self,
diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py
index b2f020f3323e8..d81ac8c704e79 100644
--- a/vllm/model_executor/models/minimax_vl_01.py
+++ b/vllm/model_executor/models/minimax_vl_01.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping
-from typing import Annotated, Literal, Optional, Union, cast
+from typing import Annotated, Literal, Optional, Union

 import torch
 import torch.nn as nn
@@ -254,7 +254,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = tuple(vision_tower(p) for p in pixel_values)
+        image_features: tuple[torch.Tensor, ...] = \
+            tuple(vision_tower(p) for p in pixel_values)

         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
@@ -262,10 +263,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
                 strategy=self.config.vision_feature_select_strategy,
             )

-        return cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features, image_features),
-        )
+        return json_map_leaves(select_features, image_features)

     # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
     def pack_image_features(self, image_features: list[torch.Tensor],
diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py
index b75c858a64808..3660efdc079aa 100644
--- a/vllm/model_executor/models/tarsier.py
+++ b/vllm/model_executor/models/tarsier.py
@@ -4,7 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
-                    Union, cast)
+                    Union)

 import torch
 import torch.nn as nn
@@ -490,11 +490,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
         pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # From vLLM LLaVA, vision tower output handling
-        image_hidden_states = vision_tower(pixel_values)
-        if not isinstance(image_hidden_states, torch.Tensor):
-            raise TypeError(
-                f"image_hidden_states type: {type(image_hidden_states)}"
-                " is not supported")
+        image_hidden_states: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
+            vision_tower(pixel_values)

         def select_features_fn(leaf: torch.Tensor):
             return self._select_image_features(
@@ -502,11 +499,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
                 strategy=self.config.vision_feature_select_strategy,
             )

-        selected_features = cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features_fn, image_hidden_states),
-        )
-        return selected_features
+        return json_map_leaves(select_features_fn, image_hidden_states)

     def _add_tarsier_split_tokens(
         self, projected_image_features: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py
index 457afb7e2c6ff..804c443eb1841 100644
--- a/vllm/utils/jsontree.py
+++ b/vllm/utils/jsontree.py
@@ -4,7 +4,7 @@

 from collections.abc import Iterable
 from functools import reduce
-from typing import Callable, TypeVar, Union, overload
+from typing import Callable, TypeVar, Union, cast, overload

 _T = TypeVar("_T")
 _U = TypeVar("_U")
@@ -30,10 +30,42 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
         yield value


+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[_T, dict[str, _T]],
+) -> Union[_U, dict[str, _U]]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[_T, list[_T]],
+) -> Union[_U, list[_U]]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[_T, tuple[_T, ...]],
+) -> Union[_U, tuple[_U, ...]]:
+    ...
+
+
+@overload
 def json_map_leaves(
     func: Callable[[_T], _U],
     value: JSONTree[_T],
 ) -> JSONTree[_U]:
+    ...
+
+
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
     """Apply a function to each leaf in a nested JSON structure."""
     if isinstance(value, dict):
         return {k: json_map_leaves(func, v) for k, v in value.items()}
@@ -45,6 +77,33 @@ def json_map_leaves(
         return func(value)


+@overload
+def json_reduce_leaves(
+    func: Callable[[_T, _T], _T],
+    value: Union[_T, dict[str, _T]],
+    /,
+) -> _T:
+    ...
+
+
+@overload
+def json_reduce_leaves(
+    func: Callable[[_T, _T], _T],
+    value: Union[_T, list[_T]],
+    /,
+) -> _T:
+    ...
+
+
+@overload
+def json_reduce_leaves(
+    func: Callable[[_T, _T], _T],
+    value: Union[_T, tuple[_T, ...]],
+    /,
+) -> _T:
+    ...
+
+
 @overload
 def json_reduce_leaves(
     func: Callable[[_T, _T], _T],
@@ -65,10 +124,10 @@ def json_reduce_leaves(


 def json_reduce_leaves(
-    func: Callable[..., Union[_T, _U]],
-    value: JSONTree[_T],
-    initial: _U = ...,  # type: ignore[assignment]
-    /,
+    func: Callable[..., Union[_T, _U]],
+    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+    initial: _U = cast(_U, ...),  # noqa: B008
+    /,
 ) -> Union[_T, _U]:
     """
     Apply a function of two arguments cumulatively to each leaf in a
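A minimal sketch of how the new overloads are expected to behave at a call site, assuming `vllm.utils.jsontree` as patched above. The tensors, the slicing lambda, and the variable names (`features`, `selected`, `total_patches`) are made-up stand-ins for illustration, not code from this patch:

```python
import torch

from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves

# A "tree" whose leaves are tensors; think per-image feature maps.
features: tuple[torch.Tensor, ...] = (torch.randn(1, 5, 8),
                                      torch.randn(1, 7, 8))

# With the tuple overload, the result is inferred as
# Union[torch.Tensor, tuple[torch.Tensor, ...]] -- the exact return type
# of _image_pixels_to_features, so no typing.cast is needed at the call
# site anymore.
selected = json_map_leaves(lambda t: t[:, 1:, :], features)

# json_reduce_leaves folds the leaves pairwise, like functools.reduce;
# the list overload infers the result as a plain int here.
total_patches = json_reduce_leaves(lambda a, b: a + b,
                                   [t.shape[1] for t in features])
assert total_patches == 12
```

The design point, as far as the overloads show: the old single signature `(Callable[[_T], _U], JSONTree[_T]) -> JSONTree[_U]` erased the container type, so a caller passing a `tuple[torch.Tensor, ...]` got back the full `JSONTree` union and had to `cast` it down. The per-container overloads thread the container type through instead, at the cost of spelling the implementation's sentinel default as `cast(_U, ...)` (hence the `# noqa: B008`, which suppresses the "function call in default argument" lint in place of the previous `# type: ignore[assignment]`).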