[Misc] Improve type annotations for jsontree (#25577)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung, 2025-09-24 22:49:58 +08:00 (committed by GitHub)
parent 8938774c79
commit 9313be5017
5 changed files with 88 additions and 39 deletions

vllm/model_executor/models/aya_vision.py

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Literal, Optional, Union, cast
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -347,12 +347,16 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
-    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
-                                  pixel_values: torch.Tensor,
-                                  **kwargs) -> torch.Tensor:
-        target_dtype = vision_tower.get_input_embeddings().weight.dtype
-        image_features = vision_tower(pixel_values.to(dtype=target_dtype),
-                                      **kwargs)
+    def _image_pixels_to_features(
+        self,
+        vision_tower: SiglipVisionModel,
+        pixel_values: torch.Tensor,
+        **kwargs,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+        target_dtype: torch.dtype = \
+            vision_tower.get_input_embeddings().weight.dtype
+        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
+            vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)
 
         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
@@ -360,10 +364,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                 strategy=self.config.vision_feature_select_strategy,
             )
 
-        return cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features, image_features),
-        )
+        return json_map_leaves(select_features, image_features)
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
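
The call-site change above repeats in all four model files in this diff: json_map_leaves was previously typed as returning the recursive JSONTree[_U] union, so every caller had to cast the result back down to Union[torch.Tensor, tuple[torch.Tensor, ...]]. The container-aware overloads added in the jsontree module (last file of this diff) let the type checker infer the narrower type directly. The following standalone sketch shows the idea, with float standing in for torch.Tensor; map_leaves and its overloads are illustrative stand-ins, not the vLLM API:

from typing import Callable, TypeVar, Union, overload

_T = TypeVar("_T")
_U = TypeVar("_U")


@overload
def map_leaves(
    func: Callable[[_T], _U],
    value: Union[_T, tuple[_T, ...]],
) -> Union[_U, tuple[_U, ...]]:
    ...


@overload
def map_leaves(
    func: Callable[[_T], _U],
    value: Union[_T, dict[str, _T]],
) -> Union[_U, dict[str, _U]]:
    ...


def map_leaves(func, value):
    # Recurse into containers, apply func at the leaves.
    if isinstance(value, tuple):
        return tuple(map_leaves(func, v) for v in value)
    if isinstance(value, dict):
        return {k: map_leaves(func, v) for k, v in value.items()}
    return func(value)


# feats mirrors the Union-typed image_features in the model code; the
# tuple-aware overload lets the checker infer
# Union[float, tuple[float, ...]] here, with no cast() at the call site.
feats: Union[float, tuple[float, ...]] = (1.0, 2.0)
doubled = map_leaves(lambda x: x * 2, feats)
assert doubled == (2.0, 4.0)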

vllm/model_executor/models/llava.py

@@ -4,7 +4,7 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
-                    Union, cast)
+                    Union)
 
 import torch
 import torch.nn as nn
@@ -623,7 +623,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
+        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
+            vision_tower(pixel_values)
 
         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
@@ -631,10 +632,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 strategy=self.config.vision_feature_select_strategy,
             )
 
-        return cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features, image_features),
-        )
+        return json_map_leaves(select_features, image_features)
 
     def _process_image_pixels(
         self,

vllm/model_executor/models/minimax_vl_01.py

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable, Mapping
-from typing import Annotated, Literal, Optional, Union, cast
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -254,7 +254,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = tuple(vision_tower(p) for p in pixel_values)
+        image_features: tuple[torch.Tensor, ...] = \
+            tuple(vision_tower(p) for p in pixel_values)
 
         def select_features(leaf: torch.Tensor):
             return self._select_image_features(
@@ -262,10 +263,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
                 strategy=self.config.vision_feature_select_strategy,
             )
 
-        return cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features, image_features),
-        )
+        return json_map_leaves(select_features, image_features)
 
     # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
     def pack_image_features(self, image_features: list[torch.Tensor],

vllm/model_executor/models/tarsier.py

@@ -4,7 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
-                    Union, cast)
+                    Union)
 
 import torch
 import torch.nn as nn
@@ -490,11 +490,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
         pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # From vLLM LLaVA, vision tower output handling
-        image_hidden_states = vision_tower(pixel_values)
-        if not isinstance(image_hidden_states, torch.Tensor):
-            raise TypeError(
-                f"image_hidden_states type: {type(image_hidden_states)}"
-                " is not supported")
+        image_hidden_states: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
+            vision_tower(pixel_values)
 
         def select_features_fn(leaf: torch.Tensor):
             return self._select_image_features(
@@ -502,11 +499,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
                 strategy=self.config.vision_feature_select_strategy,
             )
 
-        selected_features = cast(
-            Union[torch.Tensor, tuple[torch.Tensor, ...]],
-            json_map_leaves(select_features_fn, image_hidden_states),
-        )
-        return selected_features
+        return json_map_leaves(select_features_fn, image_hidden_states)
 
     def _add_tarsier_split_tokens(
             self, projected_image_features: torch.Tensor) -> torch.Tensor:
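
One behavioral nuance in the Tarsier hunk above: the deleted isinstance guard raised TypeError whenever the vision tower returned anything other than a single tensor, while the new code follows the LLaVA path and maps over tuple outputs as well. A short sketch of the now-supported case, assuming the jsontree helpers are importable as vllm.utils.jsontree (an import path inferred from this diff, not confirmed by it):

import torch

from vllm.utils.jsontree import json_map_leaves

# A tower returning multiple feature levels yields a tuple of tensors;
# the old guard raised TypeError here, the new code maps per leaf.
hidden = (torch.rand(1, 5, 8), torch.rand(1, 5, 8))
selected = json_map_leaves(lambda t: t[:, 1:], hidden)  # drop CLS-like token
assert isinstance(selected, tuple) and selected[0].shape == (1, 4, 8)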

vllm/utils/jsontree.py

@@ -4,7 +4,7 @@
 from collections.abc import Iterable
 from functools import reduce
-from typing import Callable, TypeVar, Union, overload
+from typing import Callable, TypeVar, Union, cast, overload
 
 _T = TypeVar("_T")
 _U = TypeVar("_U")
@@ -30,10 +30,42 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
         yield value
 
 
+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[_T, dict[str, _T]],
+) -> Union[_U, dict[str, _U]]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[_T, list[_T]],
+) -> Union[_U, list[_U]]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: Union[_T, tuple[_T, ...]],
+) -> Union[_U, tuple[_U, ...]]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[_T], _U],
+    value: JSONTree[_T],
+) -> JSONTree[_U]:
+    ...
+
+
 def json_map_leaves(
     func: Callable[[_T], _U],
-    value: JSONTree[_T],
-) -> JSONTree[_U]:
+    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
     """Apply a function to each leaf in a nested JSON structure."""
     if isinstance(value, dict):
         return {k: json_map_leaves(func, v) for k, v in value.items()}
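
Taken together, these overloads make the return type of json_map_leaves track the container type of its input instead of collapsing to the recursive JSONTree alias. A usage sketch, again assuming the vllm.utils.jsontree import path; the inferred types in the comments are approximately what mypy reports:

from vllm.utils.jsontree import json_map_leaves

d = json_map_leaves(str, {"a": 1})  # roughly Union[str, dict[str, str]]
l = json_map_leaves(str, [1, 2])    # roughly Union[str, list[str]]
t = json_map_leaves(str, (1, 2))    # roughly Union[str, tuple[str, ...]]
assert d == {"a": "1"} and l == ["1", "2"] and t == ("1", "2")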
@@ -45,6 +77,33 @@ def json_map_leaves(
         return func(value)
 
 
+@overload
+def json_reduce_leaves(
+    func: Callable[[_T, _T], _T],
+    value: Union[_T, dict[str, _T]],
+    /,
+) -> _T:
+    ...
+
+
+@overload
+def json_reduce_leaves(
+    func: Callable[[_T, _T], _T],
+    value: Union[_T, list[_T]],
+    /,
+) -> _T:
+    ...
+
+
+@overload
+def json_reduce_leaves(
+    func: Callable[[_T, _T], _T],
+    value: Union[_T, tuple[_T, ...]],
+    /,
+) -> _T:
+    ...
+
+
 @overload
 def json_reduce_leaves(
     func: Callable[[_T, _T], _T],
@@ -65,10 +124,10 @@
 def json_reduce_leaves(
-    func: Callable[..., Union[_T, _U]],
-    value: JSONTree[_T],
-    initial: _U = ...,  # type: ignore[assignment]
-    /,
+    func: Callable[..., Union[_T, _U]],
+    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+    initial: _U = cast(_U, ...),  # noqa: B008
+    /,
 ) -> Union[_T, _U]:
     """
     Apply a function of two arguments cumulatively to each leaf in a
     nested JSON structure.
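
For completeness, a usage sketch of json_reduce_leaves under the same assumed import path. The slash in the signature makes all parameters positional-only, and initial behaves like the optional seed of functools.reduce, which the implementation appears to delegate to given the reduce import above:

from vllm.utils.jsontree import json_reduce_leaves

# Reduces over the leaves in iteration order: ((1 + 2) + 3) == 6.
total = json_reduce_leaves(lambda x, y: x + y, {"a": [1, 2], "b": (3,)})
assert total == 6

# Seeding with an initial value keeps empty trees well-defined.
assert json_reduce_leaves(lambda x, y: x + y, [], 0) == 0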