[Model] Support nested structures for TensorSchema (#26212)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-10-04 16:20:32 +08:00 committed by GitHub
parent d3d649efec
commit 44ea85137a
5 changed files with 274 additions and 292 deletions
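Usage sketch (not part of the diff): the nested-structure support added here lets a schema field declared as list[list[Tensor]] validate against a single TensorShape. The HCXVisionVideoPixelInputs schema and the new test in this commit exercise exactly this, with the frame ("f") and grid ("g") dims marked dynamic:

import torch
from vllm.model_executor.models.hyperclovax_vision import HCXVisionVideoPixelInputs

# Two per-frame grid tensors with different grid counts (g=4 vs g=2);
# "f" and "g" are dynamic dims, so they may differ across videos and frames.
x = torch.rand(4, 3, 32, 32)
y = torch.rand(2, 3, 32, 32)

# A doubly nested sequence of tensors now validates against
# TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"}).
HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y]))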

View File

@@ -6,37 +6,39 @@ import torch
 from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
 from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
+from vllm.model_executor.models.hyperclovax_vision import (
+    HCXVisionVideoPixelInputs)
 from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs


 def test_tensor_schema_valid_tensor():
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 3, 32, 32),
+        pixel_values=torch.randn(16, 64, 3, 32, 32),
         image_sizes=torch.randint(0, 256, (16, 2)),
     )


 def test_tensor_schema_optional_fields():
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 3, 32, 32),
+        pixel_values=torch.randn(16, 64, 3, 32, 32),
         image_sizes=None,
     )
-    Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), )
+    Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32))


 def test_tensor_schema_constant_dim_failure():
     with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
         Phi3VImagePixelInputs(
-            data=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
+            pixel_values=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
             image_sizes=torch.randint(0, 256, (16, 2)),
         )


 def test_tensor_schema_invalid_types_in_list():
-    with pytest.raises(ValueError, match="is not a torch.Tensor"):
+    with pytest.raises(TypeError, match="is not one of the expected types"):
         Phi3VImagePixelInputs(
-            data=[
+            pixel_values=[
                 torch.randn(64, 3, 32, 32),
                 "not_a_tensor",
                 torch.randn(64, 3, 32, 32),
@@ -48,27 +50,28 @@ def test_tensor_schema_invalid_types_in_list():

 def test_tensor_schema_rank_mismatch():
     with pytest.raises(ValueError, match="has rank 3 but expected 5"):
         Phi3VImagePixelInputs(
-            data=torch.randn(16, 64, 3),
+            pixel_values=torch.randn(16, 64, 3),
             image_sizes=torch.randint(0, 256, (16, 2)),
         )


 def test_tensor_schema_missing_required_field():
-    with pytest.raises(ValueError, match="Required field 'data' is missing"):
+    with pytest.raises(ValueError,
+                       match="Required field 'pixel_values' is missing"):
         Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), )


 def test_tensor_schema_symbolic_dim_mismatch():
     with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
         Phi3VImagePixelInputs(
-            data=torch.randn(12, 64, 3, 32, 32),
+            pixel_values=torch.randn(12, 64, 3, 32, 32),
             image_sizes=torch.randint(0, 256, (16, 2)),
         )


 def test_tensor_schema_list_tensor_valid():
     Phi3VImagePixelInputs(
-        data=[torch.randn(64, 3, 32, 32) for _ in range(16)],
+        pixel_values=[torch.randn(64, 3, 32, 32) for _ in range(16)],
         image_sizes=torch.randint(0, 256, (16, 2)),
     )
@@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid():

 def test_tensor_schema_variable_patch_counts_valid():
     # Each image has a different number of patches (p)
     # Each tensor has shape (p, 3, 32, 32)
-    data = [
-        torch.randn(16, 3, 32, 32),  # p = 16
-        torch.randn(32, 3, 32, 32),  # p = 32
-        torch.randn(64, 3, 32, 32),  # p = 64
-    ]
-    image_sizes = torch.randint(0, 256, (3, 2))  # bn = 3
     Phi3VImagePixelInputs(
-        data=data,
-        image_sizes=image_sizes,
+        pixel_values=[
+            torch.randn(16, 3, 32, 32),  # p = 16
+            torch.randn(32, 3, 32, 32),  # p = 32
+            torch.randn(64, 3, 32, 32),  # p = 64
+        ],
+        image_sizes=torch.randint(0, 256, (3, 2)),  # bn = 3
     )


 def test_tensor_schema_tuple_tensor_valid():
     Phi3VImagePixelInputs(
-        data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
+        pixel_values=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
         image_sizes=torch.randint(0, 256, (16, 2)),
     )


+def test_tensor_schema_double_nested_tensors():
+    x = torch.rand(4, 3, 32, 32)
+    y = torch.rand(2, 3, 32, 32)
+
+    HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y]))
+
+
 def test_tensor_schema_inconsistent_shapes_in_list():
     with pytest.raises(ValueError, match="contains inconsistent shapes"):
         Phi3VImagePixelInputs(
-            data=[torch.randn(64, 3, 32, 32),
-                  torch.randn(64, 3, 16, 16)] +
-            [torch.randn(64, 3, 32, 32) for _ in range(14)],
+            pixel_values=[
+                torch.randn(64, 3, 32, 32),
+                torch.randn(64, 3, 16, 16),
+                *(torch.randn(64, 3, 32, 32) for _ in range(14)),
+            ],
             image_sizes=torch.randint(0, 256, (16, 2)),
         )


 def test_tensor_schema_empty_list():
-    with pytest.raises(ValueError, match="is an empty list"):
+    with pytest.raises(ValueError, match="is an empty sequence"):
         Phi3VImagePixelInputs(
-            data=[],
+            pixel_values=[],
             image_sizes=torch.randint(0, 256, (0, 2)),
         )
@@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check():
     # This should NOT raise, because validation is turned off
     # This would normally fail (dim[2] should be 3, not 4)
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 4, 32, 32),
+        pixel_values=torch.randn(16, 64, 4, 32, 32),
         image_sizes=torch.randint(0, 256, (16, 2)),
         validate=False,
     )


 def test_tensor_schema_with_valid_resolve_binding_dims():
-    data = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
+    pixel_values = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
     image_sizes = torch.randint(0, 256, (16, 2))

     Phi3VImagePixelInputs(
-        data=data,
+        pixel_values=pixel_values,
         image_sizes=image_sizes,
         resolve_bindings={
             "h": 336,
@@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims():

 def test_tensor_schema_with_invalid_resolve_binding_dims():
-    data = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
+    pixel_values = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
     image_sizes = torch.randint(0, 256, (16, 2))

     # Should raise because 'h' and 'w' don't match resolve bindings
     with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
         Phi3VImagePixelInputs(
-            data=data,
+            pixel_values=pixel_values,
             image_sizes=image_sizes,
             resolve_bindings={
                 "h": 336,

View File

@@ -29,7 +29,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Annotated, Any, Callable, Literal, Optional, Union, override
+from typing import Annotated, Any, Callable, Literal, Optional, Union

 import numpy as np
 import torch
@@ -1170,7 +1170,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
                 "video.height override (%d) exceeds model's "
                 "maximum height (%d), will be ignored",
                 overrides.height, height)
-            height = min(height, override.height)
+            height = min(height, overrides.height)

         video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         video_items = []

View File

@@ -2,27 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # copied from : https://github.com/huggingface/transformers
 import ast
-import sys
 from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from itertools import chain
-from typing import Any, Literal, Optional, TypedDict, Union
+from itertools import accumulate
+from typing import Annotated, Any, Literal, Optional, Union

 import numpy as np
-import PIL
-from einops import rearrange
-from PIL import Image
-
-if sys.version_info >= (3, 11):
-    import typing
-    Unpack = typing.Unpack
-else:
-    import typing_extensions
-    Unpack = typing_extensions.Unpack
-
 import torch
 import torch.nn as nn
+from einops import rearrange
 from timm.layers import LayerNorm, LayerNorm2d
 from timm.models.regnet import RegStage
 from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
@@ -42,11 +31,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    maybe_prefix)
 from .vision import get_vision_encoder_info

 EOT = "<|endofturn|>"
@@ -69,28 +60,42 @@ def get_num_combined_frames(
     return num_canvases + (leftover_frames > 0)


-class HCXVisionMultimodalPixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values_images: list[torch.Tensor]
+class HCXVisionImagePixelInputs(TensorSchema):
     """
-    Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
-
-    Note that `height` or `width` may be different per batch and image,
-    in which case the data is passed as a list instead of a batched tensor.
+    Dimensions:
+        - n: Number of images
+        - g: Number of grids
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
     """
-    image_sizes_images: list[tuple[Union[int, float]]]
-    """
-    Shape: `[(height, width), ...]`
-    """
-    vision_query_lengths_images: list[Union[int, float]]
-    pixel_values_videos: list[tuple[Union[int, float]]]
-    """
-    Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
-    """
-    vision_query_lengths_videos: list[Union[int, float]]
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values_images: Annotated[
+        list[torch.Tensor],
+        TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})]
+    image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]


-HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs]
+HCXVisionImageInputs = HCXVisionImagePixelInputs
+
+
+class HCXVisionVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - n: Number of videos
+        - f: Number of frames
+        - g: Number of grids
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+    """
+    type: Literal["pixel_values_videos"] = "pixel_values_videos"
+    pixel_values_videos: Annotated[
+        list[list[torch.Tensor]],
+        TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]
+
+
+HCXVisionVideoInputs = HCXVisionVideoPixelInputs


 class HCXVisionProcessingInfo(BaseProcessingInfo):
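For illustration only (not part of the diff), a hypothetical construction of the image schema above; the concrete sizes are made up. Because "g" is a dynamic dim, each image may contribute a different number of grids, while image_sizes_images must be an (n, 2) tensor:

import torch
from vllm.model_executor.models.hyperclovax_vision import HCXVisionImagePixelInputs

HCXVisionImagePixelInputs(
    pixel_values_images=[
        torch.rand(5, 3, 336, 336),  # image 0: g = 5 grids
        torch.rand(2, 3, 336, 336),  # image 1: g = 2 grids
    ],
    image_sizes_images=torch.tensor([[672, 1008], [336, 672]]),  # (n, 2)
)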
@@ -191,27 +196,9 @@ class HCXVisionMultiModalProcessor(
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-
-        def replace_multimodal_token(
-            token_ids: torch.Tensor,
-            target_token: int,
-            repeats: list[int],
-        ):
-            output = list[int]()
-            _repeats_idx = 0
-            for token_id in token_ids:
-                if token_id == target_token:
-                    output += [token_id.item()] * repeats[_repeats_idx]
-                    _repeats_idx += 1
-                else:
-                    output += [token_id.item()]
-            return torch.tensor(output, device=token_ids.device)
-
         for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
-            if video_arr.dtype == np.uint8:
-                continue
-            mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
+            if video_arr.dtype != np.uint8:
+                mm_data["videos"][video_idx] = video_arr.astype(np.uint8)

         processed_outputs = self.info.ctx.call_hf_processor(
             hf_processor=self.info.get_hf_processor(**mm_kwargs),
@@ -223,20 +210,16 @@ class HCXVisionMultiModalProcessor(
         )  # text-only

         if len(mm_data) > 0:
+            images = mm_data.get("images")
+            videos = mm_data.get("videos")
+
             # batchify input as a single item
-            images = mm_data.get("images", None)
-            batched_images = None if images is None else [images]
-
-            # list of video in single conversation
-            videos = mm_data.get("videos", None)
-            batched_videos = None if videos is None else [videos]
-
             _processed_outputs = self.info.ctx.call_hf_processor(
                 hf_processor=self.info.get_hf_processor(**mm_kwargs),
                 data=dict(
                     text=None,
-                    images=batched_images,
-                    videos=batched_videos,
+                    images=None if images is None else [images],
+                    videos=None if videos is None else [videos],
                 ),
             )  # mm-only
@@ -246,51 +229,43 @@ class HCXVisionMultiModalProcessor(
                     _processed_outputs[k] = v[0]

             if images:
-                tokenizer = self.info.get_tokenizer()
-                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
-                processed_outputs["input_ids"] = torch.stack([
-                    replace_multimodal_token(
-                        token_ids=_input_ids,
-                        target_token=image_token_id,
-                        repeats=_processed_outputs[
-                            "vision_query_lengths_images"],
-                    ) for _input_ids in processed_outputs["input_ids"]
-                ],
-                    dim=0)
+                _processed_outputs["image_sizes_images"] = torch.tensor(
+                    _processed_outputs["image_sizes_images"])
+                _processed_outputs[
+                    "vision_query_lengths_images"] = torch.tensor(
+                        _processed_outputs["vision_query_lengths_images"])

             if videos:
-                _num_per_videos = [
-                    get_num_combined_frames(len(video)) for video in videos
+                _idx_per_video = [
+                    0, *accumulate(
+                        get_num_combined_frames(len(video))
+                        for video in videos)
                 ]
                 _processed_outputs["pixel_values_videos"] = [
                     _processed_outputs["pixel_values_videos"]
-                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(len(videos))
+                    [_idx_per_video[i]:_idx_per_video[i + 1]]
+                    for i in range(len(videos))
                 ]
                 _processed_outputs["vision_query_lengths_videos"] = [
-                    _processed_outputs["vision_query_lengths_videos"]
-                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(len(videos))
+                    torch.tensor(
+                        _processed_outputs["vision_query_lengths_videos"]
+                        [_idx_per_video[i]:_idx_per_video[i + 1]])
+                    for i in range(len(videos))
                 ]
-                tokenizer = self.info.get_tokenizer()
-                video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
-                processed_outputs["input_ids"] = torch.stack([
-                    replace_multimodal_token(
-                        token_ids=_input_ids,
-                        target_token=video_token_id,
-                        repeats=[
-                            sum(lens) for lens in
-                            _processed_outputs["vision_query_lengths_videos"]
-                        ],
-                    ) for _input_ids in processed_outputs["input_ids"]
-                ],
-                    dim=0)

             processed_outputs.update(_processed_outputs)

         return processed_outputs

+    def _hf_processor_applies_updates(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        tokenization_kwargs: Mapping[str, object],
+    ) -> bool:
+        return False
+
     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
@@ -311,11 +286,11 @@ class HCXVisionMultiModalProcessor(
             out_item = out_mm_kwargs[modality][item_idx]

             if modality == "image":
-                lens = out_item["vision_query_lengths_images"].data
+                lens = out_item["vision_query_lengths_images"].data.tolist()
                 num_tokens = self.info.get_num_image_tokens(
                     vision_query_length=lens)
             elif modality == "video":
-                lens = out_item["vision_query_lengths_videos"].data
+                lens = out_item["vision_query_lengths_videos"].data.tolist()
                 num_tokens = self.info.get_num_video_tokens(
                     vision_query_length=lens)
             else:
@@ -343,26 +318,11 @@ class HCXVisionMultiModalProcessor(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            # image
             pixel_values_images=MultiModalFieldConfig.batched("image"),
             image_sizes_images=MultiModalFieldConfig.batched("image"),
             vision_query_lengths_images=MultiModalFieldConfig.batched("image"),
-            num_queries_vis_abstractors_images=MultiModalFieldConfig.batched(
-                "image"),
-            num_queries_vis_abstractors_slow_images=MultiModalFieldConfig.
-            batched("image"),
-            first_last_frames_slows_images=MultiModalFieldConfig.batched(
-                "image"),
-            # video
             pixel_values_videos=MultiModalFieldConfig.batched("video"),
-            image_sizes_videos=MultiModalFieldConfig.batched("video"),
             vision_query_lengths_videos=MultiModalFieldConfig.batched("video"),
-            num_queries_vis_abstractors_videos=MultiModalFieldConfig.batched(
-                "video"),
-            num_queries_vis_abstractors_slow_videos=MultiModalFieldConfig.
-            batched("video"),
-            first_last_frames_slows_videos=MultiModalFieldConfig.batched(
-                "video"),
         )
@@ -617,6 +577,7 @@ class HCXVisionCAbstractor(nn.Module):
                                         info=_build_hcxvision_hf_info,
                                         dummy_inputs=HCXVisionDummyInputsBuilder)
 class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -692,55 +653,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             raise ValueError("Only image or video modality is supported")

+    def _parse_and_validate_image_input(
+        self,
+        **kwargs: object,
+    ) -> Optional[HCXVisionImageInputs]:
+        pixel_values_images = kwargs.pop("pixel_values_images", None)
+        if pixel_values_images is None:
+            return None
+
+        image_sizes_images = kwargs.pop("image_sizes_images")
+
+        return HCXVisionImagePixelInputs(
+            pixel_values_images=pixel_values_images,
+            image_sizes_images=image_sizes_images,
+        )
+
+    def _parse_and_validate_video_input(
+        self,
+        **kwargs: object,
+    ) -> Optional[HCXVisionVideoInputs]:
+        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
+        if pixel_values_videos is None:
+            return None
+
+        return HCXVisionVideoPixelInputs(
+            pixel_values_videos=pixel_values_videos, )
+
+    def _process_image_input(
+        self,
+        image_input: HCXVisionImageInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        return self.forward_images(
+            pixel_values_images=image_input["pixel_values_images"],
+            image_sizes_images=image_input["image_sizes_images"],
+        )
+
+    def _process_video_input(
+        self,
+        video_input: HCXVisionVideoInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        return self.forward_videos(
+            pixel_values_videos=video_input["pixel_values_videos"], )
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = {}
+
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if (input_key == "pixel_values_images"
+                    and "images" not in modalities):
+                modalities["images"] = self._parse_and_validate_image_input(
+                    **kwargs)
+            if (input_key == "pixel_values_videos"
+                    and "videos" not in modalities):
+                modalities["videos"] = self._parse_and_validate_video_input(
+                    **kwargs)
+
+        return modalities
+
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

     def get_multimodal_embeddings(
         self,
-        **kwargs: Unpack[HCXVisionMultimodalInputs],
+        **kwargs: object,
     ) -> MultiModalEmbeddings:
+        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not modalities:
+            return []

-        multimodal_embeddings = list()
-        if kwargs.get("pixel_values_images") is not None:
-            for _pixel_values_images, _image_sizes_images in zip(
-                    kwargs["pixel_values_images"],
-                    kwargs["image_sizes_images"]):
-                _pixel_values_images = _pixel_values_images.unsqueeze(dim=0)
-                _image_sizes_images = _image_sizes_images.unsqueeze(dim=0)
-                _len_pixel_values_images = [
-                    len(pixel_value) for pixel_value in _pixel_values_images
-                ]
-                if isinstance(_image_sizes_images, torch.Tensor):
-                    _image_sizes_images = _image_sizes_images.detach().cpu(
-                    ).tolist()
-                _multimodal_embeddings_images = self.forward_images(
-                    pixel_values_images=_pixel_values_images,
-                    image_sizes_images=_image_sizes_images,
-                    len_pixel_values_images=_len_pixel_values_images,
-                )
-                _multimodal_embeddings_images = torch.cat(
-                    _multimodal_embeddings_images, dim=0)
-                multimodal_embeddings.append(_multimodal_embeddings_images)
-        if kwargs.get("pixel_values_videos") is not None:
-            for _pixel_values_videos, _vision_query_lengths_videos in zip(
-                    kwargs["pixel_values_videos"],
-                    kwargs["vision_query_lengths_videos"]):
-                _len_pixel_values_videos = [
-                    len(_vision_query_lengths)
-                    for _vision_query_lengths in _vision_query_lengths_videos
-                ]
-                _c, _w, _h = _pixel_values_videos.shape[-3:]
-                _pixel_values_videos = _pixel_values_videos.reshape(
-                    sum(_len_pixel_values_videos), -1, _c, _w,
-                    _h).unsqueeze(dim=0)
-                _multimodal_embeddings_videos = self.forward_videos(
-                    pixel_values_videos=_pixel_values_videos,
-                    len_pixel_values_videos=_len_pixel_values_videos,
-                )
-                _multimodal_embeddings_videos = torch.cat(
-                    _multimodal_embeddings_videos, dim=0)
-                multimodal_embeddings.append(_multimodal_embeddings_videos)
+        # The result multimodal_embeddings is tuple of tensors, with each
+        # tensor correspoending to a multimodal data item (image or video).
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
+        for modality in modalities:
+            if modality == "images":
+                image_input = modalities["images"]
+                vision_embeddings = self._process_image_input(image_input)
+                multimodal_embeddings += vision_embeddings
+            if modality == "videos":
+                video_input = modalities["videos"]
+                video_embeddings = self._process_video_input(video_input)
+                multimodal_embeddings += video_embeddings

         return multimodal_embeddings

     def forward(
@@ -762,28 +762,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

     def forward_images(
         self,
-        pixel_values_images: list[list[torch.FloatTensor]],
-        image_sizes_images: list[list[tuple[int, int]]],
-        len_pixel_values_images: list[int],
-    ) -> list[list[torch.Tensor]]:
-        if sum(len_pixel_values_images) == 0:
-            return None
-
-        concat_pixel_values_images = torch.cat(list(
-            chain(*pixel_values_images)),
-                                               dim=0)
+        pixel_values_images: list[torch.Tensor],
+        image_sizes_images: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)

         visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
         image_forward_outs = self.vision_model(
-            concat_pixel_values_images)[:, visual_token_idx:]
+            pixel_values_image_flat)[:, visual_token_idx:]
         image_forward_outs = image_forward_outs.to(
             dtype=self.mm_projector.dtype)
         image_forward_outs = self.mm_projector(image_forward_outs)  # b (h w) d

-        split_sizes = [
-            pixel_value.shape[0] for pixel_value in chain(*pixel_values_images)
-        ]
+        split_sizes = [len(item) for item in pixel_values_images]
         image_forward_outs = torch.split(image_forward_outs,
                                          split_sizes,
                                          dim=0)
@@ -791,10 +783,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # newline for anyres postprocessing
         image_features = anyres_postprocessing(
             image_forward_outs=image_forward_outs,
-            image_sizes=[
-                image_size for image_sizes in image_sizes_images
-                for image_size in image_sizes
-            ],
+            image_sizes=image_sizes_images.tolist(),
             num_queries_vis_abstractor=self.config.
             num_queries_vis_abstractor_image,
             unpad=self.config.unpad,
@@ -803,26 +792,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             image_newline=self.image_newline,
             possible_resolutions=self.config.possible_resolutions,
         )
-        return image_features
+
+        return tuple(image_features)

     def forward_videos(
         self,
-        pixel_values_videos: list[list[torch.FloatTensor]],
-        len_pixel_values_videos: list[int],
-    ) -> list[torch.Tensor]:
-
-        len_video_grids = sum(len_pixel_values_videos)
-        if len_video_grids == 0:
-            return None
-
-        # Run Vision Model
-        concat_pixel_values_videos = torch.cat(list(
-            chain(*pixel_values_videos)),
-                                               dim=0)
+        pixel_values_videos: list[list[torch.Tensor]],
+    ) -> tuple[torch.Tensor, ...]:
+        pixel_values_videos_flat = flatten_bn(
+            [frame for frames in pixel_values_videos for frame in frames],
+            concat=True,
+        )

         visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
         video_forward_outs = self.vision_model(
-            concat_pixel_values_videos)[:, visual_token_idx:]
+            pixel_values_videos_flat)[:, visual_token_idx:]
         video_forward_outs = video_forward_outs.to(
             dtype=self.mm_projector.dtype)
@@ -905,7 +889,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         ) == 0, f"target_features is not empty!! {target_features}"
         assert len(video_groups) == len(video_features)

-        return video_features
+        feats_per_video = [len(video) for video in pixel_values_videos]
+        idxs_per_video = [0, *accumulate(feats_per_video)]
+        return tuple(
+            torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]])
+            for i in range(len(feats_per_video)))

     def _prepare_multimodal_kwargs(self, **kwargs: object):
         output = defaultdict(list)
@@ -1111,15 +1099,15 @@ def reshape_and_unpad_image_features(


 def anyres_postprocessing(
-    image_forward_outs: list[torch.FloatTensor],
+    image_forward_outs: list[torch.Tensor],
     image_sizes: list[list[int]],
     possible_resolutions: list[tuple[int, int]],
     patch_size: int,
     grid_size: int,
-    image_newline: torch.FloatTensor,
+    image_newline: torch.Tensor,
     num_queries_vis_abstractor: int = -1,
     unpad: bool = False,
-) -> list[torch.FloatTensor]:
+) -> list[torch.Tensor]:
     height = width = grid_size // patch_size

     if num_queries_vis_abstractor > 0:
@@ -1147,26 +1135,5 @@ def anyres_postprocessing(
                 (image_feature, image_newline[None].to(image_feature.device)),
                 dim=0)
         new_image_features.append(image_feature)
-    image_features = new_image_features
-    return image_features
-
-
-def resize_image(
-    image: Union[np.ndarray, PIL.Image.Image],
-    max_side: int = 378,
-) -> np.ndarray:
-    image_arr = image
-    if isinstance(image, np.ndarray):
-        image = Image.fromarray(image)
-    width, height = image.size
-    cur_max_size = max(width, height)
-    if cur_max_size <= max_side:
-        return image_arr
-    scale = max_side / cur_max_size
-    width = int(width * scale)
-    height = int(height * scale)
-    image = image.resize((width, height), Image.LANCZOS)
-    image_arr = np.array(image)
-    return image_arr
+    return new_image_features

View File

@@ -109,7 +109,7 @@ class Phi3VImagePixelInputs(TensorSchema):
     type: Literal["pixel_values", "image_embeds"] = "pixel_values"

     # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
-    data: Annotated[
+    pixel_values: Annotated[
         Union[torch.Tensor, list[torch.Tensor]],
         TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
                     ),  # 'p' may vary across items
@@ -594,7 +594,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         if pixel_values is not None:
             return Phi3VImagePixelInputs(
                 type="pixel_values",
-                data=flatten_bn(pixel_values),
+                pixel_values=flatten_bn(pixel_values),
                 image_sizes=flatten_bn(image_sizes, concat=True),
                 resolve_bindings={
                     "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
@@ -628,7 +628,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
             )

         assert self.vision_embed_tokens is not None
-        image_embeds = self.vision_embed_tokens(image_input["data"],
+        image_embeds = self.vision_embed_tokens(image_input["pixel_values"],
                                                 image_input["image_sizes"])

         return image_embeds

View File

@@ -94,34 +94,63 @@ class TensorSchema:
                 return False
         return True

-    def _validate_nested_tensors(
+    def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
+        if not idxs:
+            return ""
+
+        return str(list(idxs))
+
+    def _validate_field(
         self,
-        value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
+        value: object,
         field_name: str,
         expected_shape: tuple[Union[int, str], ...],
         dynamic_dims: set[str],
+        leading_idxs: tuple[int, ...] = (),
     ) -> tuple[int, ...]:
-        """Validate a list/tuple of tensors and return the actual shape."""
+        """Validate a field and return the actual shape."""
+        if isinstance(value, (int, float)):
+            return ()  # Scalar
+        if isinstance(value, torch.Tensor):
+            return value.shape
+        if not isinstance(value, (list, tuple)):
+            raise TypeError(
+                f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
+                f"one of the expected types: int, float, Tensor, list, tuple. "
+                f"Got: {type(value)}")
+        if len(value) == 0:
+            raise ValueError(f"{field_name}{self._fmt_indexer(leading_idxs)} "
+                             f"is an empty sequence")
+
         # Ensure all tensors in the list have the same
         # shape, besides dynamic dimensions
-        first = value[0]
         for i, v in enumerate(value):
-            if not isinstance(v, torch.Tensor):
-                raise ValueError(f"{field_name}[{i}] is not a "
-                                 f"torch.Tensor")
-            if not self._match_shape_with_dynamic(
-                    v.shape,
-                    first.shape,
+            shape = self._validate_field(
+                v,
+                field_name,
+                expected_shape[1:],
+                dynamic_dims,
+                leading_idxs=leading_idxs + (i, ),
+            )
+
+            if i == 0:
+                first_shape = shape
+            elif not self._match_shape_with_dynamic(
+                    shape,
+                    first_shape,
                     expected_shape,
                     dynamic_dims,
             ):
-                raise ValueError(f"{field_name} contains inconsistent "
-                                 f"shapes: {first.shape} vs {v.shape} "
-                                 f"at index {i}")
+                raise ValueError(
+                    f"{field_name}{self._fmt_indexer(leading_idxs)} "
+                    f"contains inconsistent shapes: {first_shape} "
+                    f"(index 0) vs {shape} (index {i})")

         # Treat the list as a stacked tensor:
         # shape = (len(list), *tensor.shape)
-        return (len(value), ) + first.shape
+        return (len(value), ) + first_shape

     def _validate_tensor_shape_expected(
         self,
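As a side note on the recursion above, here is a minimal standalone sketch (not vLLM code) of the shape inference it performs: each level of nesting contributes one leading dimension and the innermost tensor supplies the rest, which is what lets a list[list[Tensor]] match TensorShape("n", "f", "g", 3, "h", "w"). The real method additionally checks that sibling shapes agree outside of the declared dynamic_dims.

import torch

def inferred_shape(value):
    # Mirrors the recursion in _validate_field: a tensor returns its own
    # shape; a non-empty list/tuple returns (len(value), *shape_of_first).
    if isinstance(value, torch.Tensor):
        return tuple(value.shape)
    return (len(value), ) + inferred_shape(value[0])

videos = [[torch.rand(4, 3, 32, 32), torch.rand(2, 3, 32, 32)],
          [torch.rand(4, 3, 32, 32)]]
print(inferred_shape(videos))  # (2, 2, 4, 3, 32, 32) -> n, f, g, 3, h, w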
@@ -187,36 +216,12 @@ class TensorSchema:
             for arg in args:
                 if isinstance(arg, TensorShape):
                     expected_shape = arg.resolve(**self._resolve_bindings)

-                    if isinstance(value, (list, tuple)):
-                        # list/tuple of Tensors → shape = (len(value), ...)
-                        if value and isinstance(value[0], torch.Tensor):
-                            actual_shape = self._validate_nested_tensors(
-                                value, field_name, expected_shape,
-                                arg.dynamic_dims)
-                        elif value:
-                            # list/tuple of scalars → shape = (len(value),)
-                            actual_shape = (len(value), )
-                        else:
-                            raise ValueError(
-                                f"{field_name} is an empty list")
-
-                    # Tensor → shape = tensor.shape
-                    elif isinstance(value, torch.Tensor):
-                        actual_shape = value.shape
-
-                    # Otherwise, it's an unsupported type
-                    else:
-                        type_names = []
-                        for arg in args:
-                            if hasattr(arg, "__name__"):
-                                type_names.append(str(arg.__name__))
-                            else:
-                                type_names.append(str(arg))
-                        expected_types = ", ".join(type_names)
-                        raise ValueError(
-                            f"{field_name} is not one of the expected "
-                            f"types: {expected_types}")
+                    actual_shape = self._validate_field(
+                        value,
+                        field_name,
+                        expected_shape,
+                        arg.dynamic_dims,
+                    )

                     self._validate_tensor_shape_expected(
                         actual_shape, expected_shape, field_name,