[Misc] Improve type annotations for jsontree (#25577)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-09-24 22:49:58 +08:00 committed by GitHub
parent 8938774c79
commit 9313be5017
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 88 additions and 39 deletions

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, Optional, Union, cast from typing import Annotated, Literal, Optional, Union
import torch import torch
from torch import nn from torch import nn
@ -347,12 +347,16 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def _image_pixels_to_features(self, vision_tower: SiglipVisionModel, def _image_pixels_to_features(
pixel_values: torch.Tensor, self,
**kwargs) -> torch.Tensor: vision_tower: SiglipVisionModel,
target_dtype = vision_tower.get_input_embeddings().weight.dtype pixel_values: torch.Tensor,
image_features = vision_tower(pixel_values.to(dtype=target_dtype), **kwargs,
**kwargs) ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
target_dtype: torch.dtype = \
vision_tower.get_input_embeddings().weight.dtype
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)
def select_features(leaf: torch.Tensor): def select_features(leaf: torch.Tensor):
return self._select_image_features( return self._select_image_features(
@ -360,10 +364,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
strategy=self.config.vision_feature_select_strategy, strategy=self.config.vision_feature_select_strategy,
) )
return cast( return json_map_leaves(select_features, image_features)
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
)
def _select_image_features(self, image_features: torch.Tensor, *, def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor: strategy: str) -> torch.Tensor:

View File

@ -4,7 +4,7 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union, cast) Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -623,7 +623,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since # NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower # this is already done inside the vision tower
image_features = vision_tower(pixel_values) image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
vision_tower(pixel_values)
def select_features(leaf: torch.Tensor): def select_features(leaf: torch.Tensor):
return self._select_image_features( return self._select_image_features(
@ -631,10 +632,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
strategy=self.config.vision_feature_select_strategy, strategy=self.config.vision_feature_select_strategy,
) )
return cast( return json_map_leaves(select_features, image_features)
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
)
def _process_image_pixels( def _process_image_pixels(
self, self,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import Annotated, Literal, Optional, Union, cast from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -254,7 +254,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since # NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower # this is already done inside the vision tower
image_features = tuple(vision_tower(p) for p in pixel_values) image_features: tuple[torch.Tensor, ...] = \
tuple(vision_tower(p) for p in pixel_values)
def select_features(leaf: torch.Tensor): def select_features(leaf: torch.Tensor):
return self._select_image_features( return self._select_image_features(
@ -262,10 +263,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
strategy=self.config.vision_feature_select_strategy, strategy=self.config.vision_feature_select_strategy,
) )
return cast( return json_map_leaves(select_features, image_features)
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
)
# adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631 # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
def pack_image_features(self, image_features: list[torch.Tensor], def pack_image_features(self, image_features: list[torch.Tensor],

View File

@ -4,7 +4,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union, cast) Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -490,11 +490,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values: Union[torch.Tensor, list[torch.Tensor]], pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# From vLLM LLaVA, vision tower output handling # From vLLM LLaVA, vision tower output handling
image_hidden_states = vision_tower(pixel_values) image_hidden_states: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
if not isinstance(image_hidden_states, torch.Tensor): vision_tower(pixel_values)
raise TypeError(
f"image_hidden_states type: {type(image_hidden_states)}"
" is not supported")
def select_features_fn(leaf: torch.Tensor): def select_features_fn(leaf: torch.Tensor):
return self._select_image_features( return self._select_image_features(
@ -502,11 +499,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
strategy=self.config.vision_feature_select_strategy, strategy=self.config.vision_feature_select_strategy,
) )
selected_features = cast( return json_map_leaves(select_features_fn, image_hidden_states)
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features_fn, image_hidden_states),
)
return selected_features
def _add_tarsier_split_tokens( def _add_tarsier_split_tokens(
self, projected_image_features: torch.Tensor) -> torch.Tensor: self, projected_image_features: torch.Tensor) -> torch.Tensor:

View File

@ -4,7 +4,7 @@
from collections.abc import Iterable from collections.abc import Iterable
from functools import reduce from functools import reduce
from typing import Callable, TypeVar, Union, overload from typing import Callable, TypeVar, Union, cast, overload
_T = TypeVar("_T") _T = TypeVar("_T")
_U = TypeVar("_U") _U = TypeVar("_U")
@ -30,10 +30,42 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
yield value yield value
@overload
def json_map_leaves(
func: Callable[[_T], _U],
value: Union[_T, dict[str, _T]],
) -> Union[_U, dict[str, _U]]:
...
@overload
def json_map_leaves(
func: Callable[[_T], _U],
value: Union[_T, list[_T]],
) -> Union[_U, list[_U]]:
...
@overload
def json_map_leaves(
func: Callable[[_T], _U],
value: Union[_T, tuple[_T, ...]],
) -> Union[_U, tuple[_U, ...]]:
...
@overload
def json_map_leaves( def json_map_leaves(
func: Callable[[_T], _U], func: Callable[[_T], _U],
value: JSONTree[_T], value: JSONTree[_T],
) -> JSONTree[_U]: ) -> JSONTree[_U]:
...
def json_map_leaves(
func: Callable[[_T], _U],
value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
"""Apply a function to each leaf in a nested JSON structure.""" """Apply a function to each leaf in a nested JSON structure."""
if isinstance(value, dict): if isinstance(value, dict):
return {k: json_map_leaves(func, v) for k, v in value.items()} return {k: json_map_leaves(func, v) for k, v in value.items()}
@ -45,6 +77,33 @@ def json_map_leaves(
return func(value) return func(value)
@overload
def json_reduce_leaves(
func: Callable[[_T, _T], _T],
value: Union[_T, dict[str, _T]],
/,
) -> _T:
...
@overload
def json_reduce_leaves(
func: Callable[[_T, _T], _T],
value: Union[_T, list[_T]],
/,
) -> _T:
...
@overload
def json_reduce_leaves(
func: Callable[[_T, _T], _T],
value: Union[_T, tuple[_T, ...]],
/,
) -> _T:
...
@overload @overload
def json_reduce_leaves( def json_reduce_leaves(
func: Callable[[_T, _T], _T], func: Callable[[_T, _T], _T],
@ -65,10 +124,10 @@ def json_reduce_leaves(
def json_reduce_leaves( def json_reduce_leaves(
func: Callable[..., Union[_T, _U]], func: Callable[..., Union[_T, _U]],
value: JSONTree[_T], value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
initial: _U = ..., # type: ignore[assignment] initial: _U = cast(_U, ...), # noqa: B008
/, /,
) -> Union[_T, _U]: ) -> Union[_T, _U]:
""" """
Apply a function of two arguments cumulatively to each leaf in a Apply a function of two arguments cumulatively to each leaf in a