[Model] Support nested structures for TensorSchema (#26212)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-10-04 16:20:32 +08:00, committed by GitHub
parent d3d649efec
commit 44ea85137a
5 changed files with 274 additions and 292 deletions
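
In short, a TensorSchema field can now be a nested sequence of tensors that is validated recursively against a single TensorShape, with per-level dynamic dimensions. A minimal sketch of the new capability, adapted from the HCXVisionVideoPixelInputs class added below (the class name here is a stand-in; the import path follows this diff):

from typing import Annotated

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class VideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - n: Number of videos
        - f: Number of frames (dynamic)
        - g: Number of grids (dynamic)
        - h: Height
        - w: Width
    """
    pixel_values_videos: Annotated[
        list[list[torch.Tensor]],
        TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]


# Videos may have different frame counts, and frames different grid counts.
x = torch.rand(4, 3, 32, 32)
y = torch.rand(2, 3, 32, 32)
VideoPixelInputs(pixel_values_videos=[[x, y, x], [y], [x, y]])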

View File

@ -6,37 +6,39 @@ import torch
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
from vllm.model_executor.models.hyperclovax_vision import (
HCXVisionVideoPixelInputs)
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
def test_tensor_schema_valid_tensor():
Phi3VImagePixelInputs(
data=torch.randn(16, 64, 3, 32, 32),
pixel_values=torch.randn(16, 64, 3, 32, 32),
image_sizes=torch.randint(0, 256, (16, 2)),
)
def test_tensor_schema_optional_fields():
Phi3VImagePixelInputs(
data=torch.randn(16, 64, 3, 32, 32),
pixel_values=torch.randn(16, 64, 3, 32, 32),
image_sizes=None,
)
Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), )
Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32))
def test_tensor_schema_constant_dim_failure():
with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
Phi3VImagePixelInputs(
data=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4
pixel_values=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4
image_sizes=torch.randint(0, 256, (16, 2)),
)
def test_tensor_schema_invalid_types_in_list():
with pytest.raises(ValueError, match="is not a torch.Tensor"):
with pytest.raises(TypeError, match="is not one of the expected types"):
Phi3VImagePixelInputs(
data=[
pixel_values=[
torch.randn(64, 3, 32, 32),
"not_a_tensor",
torch.randn(64, 3, 32, 32),
@ -48,27 +50,28 @@ def test_tensor_schema_invalid_types_in_list():
def test_tensor_schema_rank_mismatch():
with pytest.raises(ValueError, match="has rank 3 but expected 5"):
Phi3VImagePixelInputs(
data=torch.randn(16, 64, 3),
pixel_values=torch.randn(16, 64, 3),
image_sizes=torch.randint(0, 256, (16, 2)),
)
def test_tensor_schema_missing_required_field():
with pytest.raises(ValueError, match="Required field 'data' is missing"):
with pytest.raises(ValueError,
match="Required field 'pixel_values' is missing"):
Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), )
def test_tensor_schema_symbolic_dim_mismatch():
with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
Phi3VImagePixelInputs(
data=torch.randn(12, 64, 3, 32, 32),
pixel_values=torch.randn(12, 64, 3, 32, 32),
image_sizes=torch.randint(0, 256, (16, 2)),
)
def test_tensor_schema_list_tensor_valid():
Phi3VImagePixelInputs(
data=[torch.randn(64, 3, 32, 32) for _ in range(16)],
pixel_values=[torch.randn(64, 3, 32, 32) for _ in range(16)],
image_sizes=torch.randint(0, 256, (16, 2)),
)
@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid():
def test_tensor_schema_variable_patch_counts_valid():
# Each image has a different number of patches (p)
# Each tensor has shape (p, 3, 32, 32)
data = [
torch.randn(16, 3, 32, 32), # p = 16
torch.randn(32, 3, 32, 32), # p = 32
torch.randn(64, 3, 32, 32), # p = 64
]
image_sizes = torch.randint(0, 256, (3, 2)) # bn = 3
Phi3VImagePixelInputs(
data=data,
image_sizes=image_sizes,
pixel_values=[
torch.randn(16, 3, 32, 32), # p = 16
torch.randn(32, 3, 32, 32), # p = 32
torch.randn(64, 3, 32, 32), # p = 64
],
image_sizes=torch.randint(0, 256, (3, 2)), # bn = 3
)
def test_tensor_schema_tuple_tensor_valid():
Phi3VImagePixelInputs(
data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
pixel_values=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
image_sizes=torch.randint(0, 256, (16, 2)),
)
def test_tensor_schema_double_nested_tensors():
x = torch.rand(4, 3, 32, 32)
y = torch.rand(2, 3, 32, 32)
HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y]))
def test_tensor_schema_inconsistent_shapes_in_list():
with pytest.raises(ValueError, match="contains inconsistent shapes"):
Phi3VImagePixelInputs(
data=[torch.randn(64, 3, 32, 32),
torch.randn(64, 3, 16, 16)] +
[torch.randn(64, 3, 32, 32) for _ in range(14)],
pixel_values=[
torch.randn(64, 3, 32, 32),
torch.randn(64, 3, 16, 16),
*(torch.randn(64, 3, 32, 32) for _ in range(14)),
],
image_sizes=torch.randint(0, 256, (16, 2)),
)
def test_tensor_schema_empty_list():
with pytest.raises(ValueError, match="is an empty list"):
with pytest.raises(ValueError, match="is an empty sequence"):
Phi3VImagePixelInputs(
data=[],
pixel_values=[],
image_sizes=torch.randint(0, 256, (0, 2)),
)
@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check():
# This should NOT raise, because validation is turned off
# This would normally fail (dim[2] should be 3, not 4)
Phi3VImagePixelInputs(
data=torch.randn(16, 64, 4, 32, 32),
pixel_values=torch.randn(16, 64, 4, 32, 32),
image_sizes=torch.randint(0, 256, (16, 2)),
validate=False,
)
def test_tensor_schema_with_valid_resolve_binding_dims():
data = torch.randn(16, 64, 3, 336, 336) # h=336, w=336
pixel_values = torch.randn(16, 64, 3, 336, 336) # h=336, w=336
image_sizes = torch.randint(0, 256, (16, 2))
Phi3VImagePixelInputs(
data=data,
pixel_values=pixel_values,
image_sizes=image_sizes,
resolve_bindings={
"h": 336,
@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims():
def test_tensor_schema_with_invalid_resolve_binding_dims():
data = torch.randn(16, 64, 3, 36, 36) # h=36, w=36
pixel_values = torch.randn(16, 64, 3, 36, 36) # h=36, w=36
image_sizes = torch.randint(0, 256, (16, 2))
# Should raise because 'h' and 'w' don't match resolve bindings
with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
Phi3VImagePixelInputs(
data=data,
pixel_values=pixel_values,
image_sizes=image_sizes,
resolve_bindings={
"h": 336,

View File

@ -29,7 +29,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Callable, Literal, Optional, Union, override
from typing import Annotated, Any, Callable, Literal, Optional, Union
import numpy as np
import torch
@ -1170,7 +1170,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height, height)
height = min(height, override.height)
height = min(height, overrides.height)
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
video_items = []

View File

@ -2,27 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# copied from : https://github.com/huggingface/transformers
import ast
import sys
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from itertools import chain
from typing import Any, Literal, Optional, TypedDict, Union
from itertools import accumulate
from typing import Annotated, Any, Literal, Optional, Union
import numpy as np
import PIL
from einops import rearrange
from PIL import Image
if sys.version_info >= (3, 11):
import typing
Unpack = typing.Unpack
else:
import typing_extensions
Unpack = typing_extensions.Unpack
import torch
import torch.nn as nn
from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
@ -42,11 +31,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix)
from .vision import get_vision_encoder_info
EOT = "<|endofturn|>"
@ -69,28 +60,42 @@ def get_num_combined_frames(
return num_canvases + (leftover_frames > 0)
class HCXVisionMultimodalPixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values_images: list[torch.Tensor]
class HCXVisionImagePixelInputs(TensorSchema):
"""
Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
Dimensions:
- n: Number of images
- g: Number of grids
- c: Number of channels (3)
- h: Height
- w: Width
"""
image_sizes_images: list[tuple[Union[int, float]]]
"""
Shape: `[(height, width), ...]`
"""
vision_query_lengths_images: list[Union[int, float]]
pixel_values_videos: list[tuple[Union[int, float]]]
"""
Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
"""
vision_query_lengths_videos: list[Union[int, float]]
type: Literal["pixel_values"] = "pixel_values"
pixel_values_images: Annotated[
list[torch.Tensor],
TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})]
image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]
HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs]
HCXVisionImageInputs = HCXVisionImagePixelInputs
class HCXVisionVideoPixelInputs(TensorSchema):
"""
Dimensions:
- n: Number of videos
- f: Number of frames
- g: Number of grids
- c: Number of channels (3)
- h: Height
- w: Width
"""
type: Literal["pixel_values_videos"] = "pixel_values_videos"
pixel_values_videos: Annotated[
list[list[torch.Tensor]],
TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]
HCXVisionVideoInputs = HCXVisionVideoPixelInputs
class HCXVisionProcessingInfo(BaseProcessingInfo):
@ -191,27 +196,9 @@ class HCXVisionMultiModalProcessor(
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
def replace_multimodal_token(
token_ids: torch.Tensor,
target_token: int,
repeats: list[int],
):
output = list[int]()
_repeats_idx = 0
for token_id in token_ids:
if token_id == target_token:
output += [token_id.item()] * repeats[_repeats_idx]
_repeats_idx += 1
else:
output += [token_id.item()]
return torch.tensor(output, device=token_ids.device)
for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
if video_arr.dtype == np.uint8:
continue
mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
if video_arr.dtype != np.uint8:
mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
processed_outputs = self.info.ctx.call_hf_processor(
hf_processor=self.info.get_hf_processor(**mm_kwargs),
@ -223,20 +210,16 @@ class HCXVisionMultiModalProcessor(
) # text-only
if len(mm_data) > 0:
images = mm_data.get("images")
videos = mm_data.get("videos")
# batchify input as a single item
images = mm_data.get("images", None)
batched_images = None if images is None else [images]
# list of video in single conversation
videos = mm_data.get("videos", None)
batched_videos = None if videos is None else [videos]
_processed_outputs = self.info.ctx.call_hf_processor(
hf_processor=self.info.get_hf_processor(**mm_kwargs),
data=dict(
text=None,
images=batched_images,
videos=batched_videos,
images=None if images is None else [images],
videos=None if videos is None else [videos],
),
) # mm-only
@ -246,51 +229,43 @@ class HCXVisionMultiModalProcessor(
_processed_outputs[k] = v[0]
if images:
tokenizer = self.info.get_tokenizer()
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
processed_outputs["input_ids"] = torch.stack([
replace_multimodal_token(
token_ids=_input_ids,
target_token=image_token_id,
repeats=_processed_outputs[
"vision_query_lengths_images"],
) for _input_ids in processed_outputs["input_ids"]
],
dim=0)
_processed_outputs["image_sizes_images"] = torch.tensor(
_processed_outputs["image_sizes_images"])
_processed_outputs[
"vision_query_lengths_images"] = torch.tensor(
_processed_outputs["vision_query_lengths_images"])
if videos:
_num_per_videos = [
get_num_combined_frames(len(video)) for video in videos
_idx_per_video = [
0, *accumulate(
get_num_combined_frames(len(video))
for video in videos)
]
_processed_outputs["pixel_values_videos"] = [
_processed_outputs["pixel_values_videos"]
[sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
for _i in range(len(videos))
[_idx_per_video[i]:_idx_per_video[i + 1]]
for i in range(len(videos))
]
_processed_outputs["vision_query_lengths_videos"] = [
_processed_outputs["vision_query_lengths_videos"]
[sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
for _i in range(len(videos))
torch.tensor(
_processed_outputs["vision_query_lengths_videos"]
[_idx_per_video[i]:_idx_per_video[i + 1]])
for i in range(len(videos))
]
tokenizer = self.info.get_tokenizer()
video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
processed_outputs["input_ids"] = torch.stack([
replace_multimodal_token(
token_ids=_input_ids,
target_token=video_token_id,
repeats=[
sum(lens) for lens in
_processed_outputs["vision_query_lengths_videos"]
],
) for _input_ids in processed_outputs["input_ids"]
],
dim=0)
processed_outputs.update(_processed_outputs)
return processed_outputs
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
@ -311,11 +286,11 @@ class HCXVisionMultiModalProcessor(
out_item = out_mm_kwargs[modality][item_idx]
if modality == "image":
lens = out_item["vision_query_lengths_images"].data
lens = out_item["vision_query_lengths_images"].data.tolist()
num_tokens = self.info.get_num_image_tokens(
vision_query_length=lens)
elif modality == "video":
lens = out_item["vision_query_lengths_videos"].data
lens = out_item["vision_query_lengths_videos"].data.tolist()
num_tokens = self.info.get_num_video_tokens(
vision_query_length=lens)
else:
@ -343,26 +318,11 @@ class HCXVisionMultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
# image
pixel_values_images=MultiModalFieldConfig.batched("image"),
image_sizes_images=MultiModalFieldConfig.batched("image"),
vision_query_lengths_images=MultiModalFieldConfig.batched("image"),
num_queries_vis_abstractors_images=MultiModalFieldConfig.batched(
"image"),
num_queries_vis_abstractors_slow_images=MultiModalFieldConfig.
batched("image"),
first_last_frames_slows_images=MultiModalFieldConfig.batched(
"image"),
# video
pixel_values_videos=MultiModalFieldConfig.batched("video"),
image_sizes_videos=MultiModalFieldConfig.batched("video"),
vision_query_lengths_videos=MultiModalFieldConfig.batched("video"),
num_queries_vis_abstractors_videos=MultiModalFieldConfig.batched(
"video"),
num_queries_vis_abstractors_slow_videos=MultiModalFieldConfig.
batched("video"),
first_last_frames_slows_videos=MultiModalFieldConfig.batched(
"video"),
)
@ -617,6 +577,7 @@ class HCXVisionCAbstractor(nn.Module):
info=_build_hcxvision_hf_info,
dummy_inputs=HCXVisionDummyInputsBuilder)
class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@ -692,55 +653,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Only image or video modality is supported")
def _parse_and_validate_image_input(
self,
**kwargs: object,
) -> Optional[HCXVisionImageInputs]:
pixel_values_images = kwargs.pop("pixel_values_images", None)
if pixel_values_images is None:
return None
image_sizes_images = kwargs.pop("image_sizes_images")
return HCXVisionImagePixelInputs(
pixel_values_images=pixel_values_images,
image_sizes_images=image_sizes_images,
)
def _parse_and_validate_video_input(
self,
**kwargs: object,
) -> Optional[HCXVisionVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
if pixel_values_videos is None:
return None
return HCXVisionVideoPixelInputs(
pixel_values_videos=pixel_values_videos, )
def _process_image_input(
self,
image_input: HCXVisionImageInputs,
) -> tuple[torch.Tensor, ...]:
return self.forward_images(
pixel_values_images=image_input["pixel_values_images"],
image_sizes_images=image_input["image_sizes_images"],
)
def _process_video_input(
self,
video_input: HCXVisionVideoInputs,
) -> tuple[torch.Tensor, ...]:
return self.forward_videos(
pixel_values_videos=video_input["pixel_values_videos"], )
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if (input_key == "pixel_values_images"
and "images" not in modalities):
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if (input_key == "pixel_values_videos"
and "videos" not in modalities):
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self,
**kwargs: Unpack[HCXVisionMultimodalInputs],
**kwargs: object,
) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
multimodal_embeddings = list()
if kwargs.get("pixel_values_images") is not None:
for _pixel_values_images, _image_sizes_images in zip(
kwargs["pixel_values_images"],
kwargs["image_sizes_images"]):
_pixel_values_images = _pixel_values_images.unsqueeze(dim=0)
_image_sizes_images = _image_sizes_images.unsqueeze(dim=0)
_len_pixel_values_images = [
len(pixel_value) for pixel_value in _pixel_values_images
]
if isinstance(_image_sizes_images, torch.Tensor):
_image_sizes_images = _image_sizes_images.detach().cpu(
).tolist()
_multimodal_embeddings_images = self.forward_images(
pixel_values_images=_pixel_values_images,
image_sizes_images=_image_sizes_images,
len_pixel_values_images=_len_pixel_values_images,
)
_multimodal_embeddings_images = torch.cat(
_multimodal_embeddings_images, dim=0)
multimodal_embeddings.append(_multimodal_embeddings_images)
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings
if kwargs.get("pixel_values_videos") is not None:
for _pixel_values_videos, _vision_query_lengths_videos in zip(
kwargs["pixel_values_videos"],
kwargs["vision_query_lengths_videos"]):
_len_pixel_values_videos = [
len(_vision_query_lengths)
for _vision_query_lengths in _vision_query_lengths_videos
]
_c, _w, _h = _pixel_values_videos.shape[-3:]
_pixel_values_videos = _pixel_values_videos.reshape(
sum(_len_pixel_values_videos), -1, _c, _w,
_h).unsqueeze(dim=0)
_multimodal_embeddings_videos = self.forward_videos(
pixel_values_videos=_pixel_values_videos,
len_pixel_values_videos=_len_pixel_values_videos,
)
_multimodal_embeddings_videos = torch.cat(
_multimodal_embeddings_videos, dim=0)
multimodal_embeddings.append(_multimodal_embeddings_videos)
return multimodal_embeddings
def forward(
@ -762,28 +762,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward_images(
self,
pixel_values_images: list[list[torch.FloatTensor]],
image_sizes_images: list[list[tuple[int, int]]],
len_pixel_values_images: list[int],
) -> list[list[torch.Tensor]]:
if sum(len_pixel_values_images) == 0:
return None
concat_pixel_values_images = torch.cat(list(
chain(*pixel_values_images)),
dim=0)
pixel_values_images: list[torch.Tensor],
image_sizes_images: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
image_forward_outs = self.vision_model(
concat_pixel_values_images)[:, visual_token_idx:]
pixel_values_image_flat)[:, visual_token_idx:]
image_forward_outs = image_forward_outs.to(
dtype=self.mm_projector.dtype)
image_forward_outs = self.mm_projector(image_forward_outs) # b (h w) d
split_sizes = [
pixel_value.shape[0] for pixel_value in chain(*pixel_values_images)
]
split_sizes = [len(item) for item in pixel_values_images]
image_forward_outs = torch.split(image_forward_outs,
split_sizes,
dim=0)
@ -791,10 +783,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# newline for anyres postprocessing
image_features = anyres_postprocessing(
image_forward_outs=image_forward_outs,
image_sizes=[
image_size for image_sizes in image_sizes_images
for image_size in image_sizes
],
image_sizes=image_sizes_images.tolist(),
num_queries_vis_abstractor=self.config.
num_queries_vis_abstractor_image,
unpad=self.config.unpad,
@ -803,26 +792,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_newline=self.image_newline,
possible_resolutions=self.config.possible_resolutions,
)
return image_features
return tuple(image_features)
def forward_videos(
self,
pixel_values_videos: list[list[torch.FloatTensor]],
len_pixel_values_videos: list[int],
) -> list[torch.Tensor]:
len_video_grids = sum(len_pixel_values_videos)
if len_video_grids == 0:
return None
# Run Vision Model
concat_pixel_values_videos = torch.cat(list(
chain(*pixel_values_videos)),
dim=0)
pixel_values_videos: list[list[torch.Tensor]],
) -> tuple[torch.Tensor, ...]:
pixel_values_videos_flat = flatten_bn(
[frame for frames in pixel_values_videos for frame in frames],
concat=True,
)
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
video_forward_outs = self.vision_model(
concat_pixel_values_videos)[:, visual_token_idx:]
pixel_values_videos_flat)[:, visual_token_idx:]
video_forward_outs = video_forward_outs.to(
dtype=self.mm_projector.dtype)
@ -905,7 +889,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) == 0, f"target_features is not empty!! {target_features}"
assert len(video_groups) == len(video_features)
return video_features
feats_per_video = [len(video) for video in pixel_values_videos]
idxs_per_video = [0, *accumulate(feats_per_video)]
return tuple(
torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]])
for i in range(len(feats_per_video)))
def _prepare_multimodal_kwargs(self, **kwargs: object):
output = defaultdict(list)
@ -1111,15 +1099,15 @@ def reshape_and_unpad_image_features(
def anyres_postprocessing(
image_forward_outs: list[torch.FloatTensor],
image_forward_outs: list[torch.Tensor],
image_sizes: list[list[int]],
possible_resolutions: list[tuple[int, int]],
patch_size: int,
grid_size: int,
image_newline: torch.FloatTensor,
image_newline: torch.Tensor,
num_queries_vis_abstractor: int = -1,
unpad: bool = False,
) -> list[torch.FloatTensor]:
) -> list[torch.Tensor]:
height = width = grid_size // patch_size
if num_queries_vis_abstractor > 0:
@ -1147,26 +1135,5 @@ def anyres_postprocessing(
(image_feature, image_newline[None].to(image_feature.device)),
dim=0)
new_image_features.append(image_feature)
image_features = new_image_features
return image_features
def resize_image(
image: Union[np.ndarray, PIL.Image.Image],
max_side: int = 378,
) -> np.ndarray:
image_arr = image
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
width, height = image.size
cur_max_size = max(width, height)
if cur_max_size <= max_side:
return image_arr
scale = max_side / cur_max_size
width = int(width * scale)
height = int(height * scale)
image = image.resize((width, height), Image.LANCZOS)
image_arr = np.array(image)
return image_arr
return new_image_features
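
A recurring refactor in this file replaces repeated sum(_num_per_videos[:_i]) slicing with prefix-sum boundaries built once via itertools.accumulate. A standalone sketch of that pattern (toy data, not the model's real inputs):

from itertools import accumulate

counts = [2, 1, 3]                   # e.g. combined-frame groups per video
bounds = [0, *accumulate(counts)]    # [0, 2, 3, 6]
items = list(range(6))
groups = [items[bounds[i]:bounds[i + 1]] for i in range(len(counts))]
assert groups == [[0, 1], [2], [3, 4, 5]]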

View File

@ -109,7 +109,7 @@ class Phi3VImagePixelInputs(TensorSchema):
type: Literal["pixel_values", "image_embeds"] = "pixel_values"
# Supports either a stacked tensor or a list of (p, 3, h, w) tensors
data: Annotated[
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
), # 'p' may vary across items
@ -594,7 +594,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
if pixel_values is not None:
return Phi3VImagePixelInputs(
type="pixel_values",
data=flatten_bn(pixel_values),
pixel_values=flatten_bn(pixel_values),
image_sizes=flatten_bn(image_sizes, concat=True),
resolve_bindings={
"h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
@ -628,7 +628,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
)
assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens(image_input["data"],
image_embeds = self.vision_embed_tokens(image_input["pixel_values"],
image_input["image_sizes"])
return image_embeds
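
For phi3v the schema field was renamed from data to pixel_values, so construction and lookups now use the new key. A hedged sketch with toy shapes (336 matches the CLIP ViT-L/14-336 image size used by the bindings in this file):

import torch

from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs

image_input = Phi3VImagePixelInputs(
    type="pixel_values",
    # 'p' (patches per image) is a dynamic dim, so it may differ per image.
    pixel_values=[torch.randn(5, 3, 336, 336), torch.randn(7, 3, 336, 336)],
    image_sizes=torch.randint(0, 256, (2, 2)),
    resolve_bindings={"h": 336, "w": 336},
)
pixel_values = image_input["pixel_values"]   # previously image_input["data"]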

View File

@ -94,34 +94,63 @@ class TensorSchema:
return False
return True
def _validate_nested_tensors(
self,
value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
field_name: str,
expected_shape: tuple[Union[int, str], ...],
dynamic_dims: set[str],
def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
if not idxs:
return ""
return str(list(idxs))
def _validate_field(
self,
value: object,
field_name: str,
expected_shape: tuple[Union[int, str], ...],
dynamic_dims: set[str],
leading_idxs: tuple[int, ...] = (),
) -> tuple[int, ...]:
"""Validate a list/tuple of tensors and return the actual shape."""
"""Validate a field and return the actual shape."""
if isinstance(value, (int, float)):
return () # Scalar
if isinstance(value, torch.Tensor):
return value.shape
if not isinstance(value, (list, tuple)):
raise TypeError(
f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
f"one of the expected types: int, float, Tensor, list, tuple. "
f"Got: {type(value)}")
if len(value) == 0:
raise ValueError(f"{field_name}{self._fmt_indexer(leading_idxs)} "
f"is an empty sequence")
# Ensure all tensors in the list have the same
# shape, besides dynamic dimensions
first = value[0]
for i, v in enumerate(value):
if not isinstance(v, torch.Tensor):
raise ValueError(f"{field_name}[{i}] is not a "
f"torch.Tensor")
if not self._match_shape_with_dynamic(
v.shape,
first.shape,
shape = self._validate_field(
v,
field_name,
expected_shape[1:],
dynamic_dims,
leading_idxs=leading_idxs + (i, ),
)
if i == 0:
first_shape = shape
elif not self._match_shape_with_dynamic(
shape,
first_shape,
expected_shape,
dynamic_dims,
):
raise ValueError(f"{field_name} contains inconsistent "
f"shapes: {first.shape} vs {v.shape} "
f"at index {i}")
raise ValueError(
f"{field_name}{self._fmt_indexer(leading_idxs)} "
f"contains inconsistent shapes: {first_shape} "
f"(index 0) vs {shape} (index {i})")
# Treat the list as a stacked tensor:
# shape = (len(list), *tensor.shape)
return (len(value), ) + first.shape
return (len(value), ) + first_shape
def _validate_tensor_shape_expected(
self,
@ -187,36 +216,12 @@ class TensorSchema:
for arg in args:
if isinstance(arg, TensorShape):
expected_shape = arg.resolve(**self._resolve_bindings)
if isinstance(value, (list, tuple)):
# list/tuple of Tensors → shape = (len(value), ...)
if value and isinstance(value[0], torch.Tensor):
actual_shape = self._validate_nested_tensors(
value, field_name, expected_shape,
arg.dynamic_dims)
elif value:
# list/tuple of scalars → shape = (len(value),)
actual_shape = (len(value), )
else:
raise ValueError(
f"{field_name} is an empty list")
# Tensor → shape = tensor.shape
elif isinstance(value, torch.Tensor):
actual_shape = value.shape
# Otherwise, it's an unsupported type
else:
type_names = []
for arg in args:
if hasattr(arg, "__name__"):
type_names.append(str(arg.__name__))
else:
type_names.append(str(arg))
expected_types = ", ".join(type_names)
raise ValueError(
f"{field_name} is not one of the expected "
f"types: {expected_types}")
actual_shape = self._validate_field(
value,
field_name,
expected_shape,
arg.dynamic_dims,
)
self._validate_tensor_shape_expected(
actual_shape, expected_shape, field_name,
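
The net effect of the recursion above: each nested list/tuple level contributes one leading dim, and only the innermost tensors carry the trailing dims, so a list[list[Tensor]] is compared against the full TensorShape as a single stacked shape. An illustrative sketch of just the shape-inference part (it omits the dynamic-dim consistency checks and the indexed error messages the real method produces):

import torch


def infer_stacked_shape(value):
    """Nested list/tuple levels become leading dims; tensors end the recursion."""
    if isinstance(value, torch.Tensor):
        return tuple(value.shape)
    # Non-empty sequence: one leading dim of size len(value), inner shape
    # taken from the first element (the real validator also compares siblings).
    return (len(value), *infer_stacked_shape(value[0]))


videos = [
    [torch.rand(4, 3, 32, 32), torch.rand(4, 3, 32, 32)],  # 2 frames
    [torch.rand(4, 3, 32, 32)],                            # 1 frame (dynamic "f")
]
print(infer_stacked_shape(videos))  # (2, 2, 4, 3, 32, 32) -> (n, f, g, 3, h, w)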