diff --git a/tests/utils_/test_tensor_schema.py b/tests/utils_/test_tensor_schema.py index 6aa781c1564de..102d58ec452b4 100644 --- a/tests/utils_/test_tensor_schema.py +++ b/tests/utils_/test_tensor_schema.py @@ -6,37 +6,39 @@ import torch from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs +from vllm.model_executor.models.hyperclovax_vision import ( + HCXVisionVideoPixelInputs) from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs def test_tensor_schema_valid_tensor(): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 3, 32, 32), + pixel_values=torch.randn(16, 64, 3, 32, 32), image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_optional_fields(): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 3, 32, 32), + pixel_values=torch.randn(16, 64, 3, 32, 32), image_sizes=None, ) - Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), ) + Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32)) def test_tensor_schema_constant_dim_failure(): with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4 + pixel_values=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4 image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_invalid_types_in_list(): - with pytest.raises(ValueError, match="is not a torch.Tensor"): + with pytest.raises(TypeError, match="is not one of the expected types"): Phi3VImagePixelInputs( - data=[ + pixel_values=[ torch.randn(64, 3, 32, 32), "not_a_tensor", torch.randn(64, 3, 32, 32), @@ -48,27 +50,28 @@ def test_tensor_schema_invalid_types_in_list(): def test_tensor_schema_rank_mismatch(): with pytest.raises(ValueError, match="has rank 3 but expected 5"): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 3), + pixel_values=torch.randn(16, 64, 3), image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_missing_required_field(): - with pytest.raises(ValueError, match="Required field 'data' is missing"): + with pytest.raises(ValueError, + match="Required field 'pixel_values' is missing"): Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_symbolic_dim_mismatch(): with pytest.raises(ValueError, match="expected 'bn'=12, got 16"): Phi3VImagePixelInputs( - data=torch.randn(12, 64, 3, 32, 32), + pixel_values=torch.randn(12, 64, 3, 32, 32), image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_list_tensor_valid(): Phi3VImagePixelInputs( - data=[torch.randn(64, 3, 32, 32) for _ in range(16)], + pixel_values=[torch.randn(64, 3, 32, 32) for _ in range(16)], image_sizes=torch.randint(0, 256, (16, 2)), ) @@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid(): def test_tensor_schema_variable_patch_counts_valid(): # Each image has a different number of patches (p) # Each tensor has shape (p, 3, 32, 32) - data = [ - torch.randn(16, 3, 32, 32), # p = 16 - torch.randn(32, 3, 32, 32), # p = 32 - torch.randn(64, 3, 32, 32), # p = 64 - ] - image_sizes = torch.randint(0, 256, (3, 2)) # bn = 3 Phi3VImagePixelInputs( - data=data, - image_sizes=image_sizes, + pixel_values=[ + torch.randn(16, 3, 32, 32), # p = 16 + torch.randn(32, 3, 32, 32), # p = 32 + torch.randn(64, 3, 32, 32), # p = 64 + ], + image_sizes=torch.randint(0, 256, (3, 2)), # bn = 3 ) def test_tensor_schema_tuple_tensor_valid(): Phi3VImagePixelInputs( - data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)), + 
pixel_values=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)), image_sizes=torch.randint(0, 256, (16, 2)), ) +def test_tensor_schema_double_nested_tensors(): + x = torch.rand(4, 3, 32, 32) + y = torch.rand(2, 3, 32, 32) + + HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y])) + + def test_tensor_schema_inconsistent_shapes_in_list(): with pytest.raises(ValueError, match="contains inconsistent shapes"): Phi3VImagePixelInputs( - data=[torch.randn(64, 3, 32, 32), - torch.randn(64, 3, 16, 16)] + - [torch.randn(64, 3, 32, 32) for _ in range(14)], + pixel_values=[ + torch.randn(64, 3, 32, 32), + torch.randn(64, 3, 16, 16), + *(torch.randn(64, 3, 32, 32) for _ in range(14)), + ], image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_empty_list(): - with pytest.raises(ValueError, match="is an empty list"): + with pytest.raises(ValueError, match="is an empty sequence"): Phi3VImagePixelInputs( - data=[], + pixel_values=[], image_sizes=torch.randint(0, 256, (0, 2)), ) @@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check(): # This should NOT raise, because validation is turned off # This would normally fail (dim[2] should be 3, not 4) Phi3VImagePixelInputs( - data=torch.randn(16, 64, 4, 32, 32), + pixel_values=torch.randn(16, 64, 4, 32, 32), image_sizes=torch.randint(0, 256, (16, 2)), validate=False, ) def test_tensor_schema_with_valid_resolve_binding_dims(): - data = torch.randn(16, 64, 3, 336, 336) # h=336, w=336 + pixel_values = torch.randn(16, 64, 3, 336, 336) # h=336, w=336 image_sizes = torch.randint(0, 256, (16, 2)) Phi3VImagePixelInputs( - data=data, + pixel_values=pixel_values, image_sizes=image_sizes, resolve_bindings={ "h": 336, @@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims(): def test_tensor_schema_with_invalid_resolve_binding_dims(): - data = torch.randn(16, 64, 3, 36, 36) # h=36, w=36 + pixel_values = torch.randn(16, 64, 3, 36, 36) # h=36, w=36 image_sizes = torch.randint(0, 256, (16, 2)) # Should raise because 'h' and 'w' don't match resolve bindings with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"): Phi3VImagePixelInputs( - data=data, + pixel_values=pixel_values, image_sizes=image_sizes, resolve_bindings={ "h": 336, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index c253631eb8b40..36e2e29951847 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -29,7 +29,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Callable, Literal, Optional, Union, override +from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np import torch @@ -1170,7 +1170,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): "video.height override (%d) exceeds model's " "maximum height (%d), will be ignored", overrides.height, height) - height = min(height, override.height) + height = min(height, overrides.height) video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) video_items = [] diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 10d3bc8464ba0..4d1ab3aad3b4d 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -2,27 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # copied from : 
https://github.com/huggingface/transformers import ast -import sys from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial -from itertools import chain -from typing import Any, Literal, Optional, TypedDict, Union +from itertools import accumulate +from typing import Annotated, Any, Literal, Optional, Union import numpy as np -import PIL -from einops import rearrange -from PIL import Image - -if sys.version_info >= (3, 11): - import typing - Unpack = typing.Unpack -else: - import typing_extensions - Unpack = typing_extensions.Unpack - import torch import torch.nn as nn +from einops import rearrange from timm.layers import LayerNorm, LayerNorm2d from timm.models.regnet import RegStage from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig @@ -42,11 +31,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix) from .vision import get_vision_encoder_info EOT = "<|endofturn|>" @@ -69,28 +60,42 @@ def get_num_combined_frames( return num_canvases + (leftover_frames > 0) -class HCXVisionMultimodalPixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values_images: list[torch.Tensor] +class HCXVisionImagePixelInputs(TensorSchema): """ - Shape: `[(num_grids, num_channels, height, width), ...]` if anyres - - Note that `height` or `width` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. 
+ Dimensions: + - n: Number of images + - g: Number of grids + - c: Number of channels (3) + - h: Height + - w: Width """ - image_sizes_images: list[tuple[Union[int, float]]] - """ - Shape: `[(height, width), ...]` - """ - vision_query_lengths_images: list[Union[int, float]] - pixel_values_videos: list[tuple[Union[int, float]]] - """ - Shape: `[(num_grids, num_channels, height, width), ...]` if anyres - """ - vision_query_lengths_videos: list[Union[int, float]] + type: Literal["pixel_values"] = "pixel_values" + pixel_values_images: Annotated[ + list[torch.Tensor], + TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})] + image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)] -HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs] +HCXVisionImageInputs = HCXVisionImagePixelInputs + + +class HCXVisionVideoPixelInputs(TensorSchema): + """ + Dimensions: + - n: Number of videos + - f: Number of frames + - g: Number of grids + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" + pixel_values_videos: Annotated[ + list[list[torch.Tensor]], + TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})] + + +HCXVisionVideoInputs = HCXVisionVideoPixelInputs class HCXVisionProcessingInfo(BaseProcessingInfo): @@ -191,27 +196,9 @@ class HCXVisionMultiModalProcessor( mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - - def replace_multimodal_token( - token_ids: torch.Tensor, - target_token: int, - repeats: list[int], - ): - output = list[int]() - _repeats_idx = 0 - for token_id in token_ids: - if token_id == target_token: - output += [token_id.item()] * repeats[_repeats_idx] - _repeats_idx += 1 - else: - output += [token_id.item()] - - return torch.tensor(output, device=token_ids.device) - for video_idx, video_arr in enumerate(mm_data.get("videos", [])): - if video_arr.dtype == np.uint8: - continue - mm_data["videos"][video_idx] = video_arr.astype(np.uint8) + if video_arr.dtype != np.uint8: + mm_data["videos"][video_idx] = video_arr.astype(np.uint8) processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), @@ -223,20 +210,16 @@ class HCXVisionMultiModalProcessor( ) # text-only if len(mm_data) > 0: + images = mm_data.get("images") + videos = mm_data.get("videos") + # batchify input as a single item - images = mm_data.get("images", None) - batched_images = None if images is None else [images] - - # list of video in single conversation - videos = mm_data.get("videos", None) - batched_videos = None if videos is None else [videos] - _processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), data=dict( text=None, - images=batched_images, - videos=batched_videos, + images=None if images is None else [images], + videos=None if videos is None else [videos], ), ) # mm-only @@ -246,51 +229,43 @@ class HCXVisionMultiModalProcessor( _processed_outputs[k] = v[0] if images: - tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - processed_outputs["input_ids"] = torch.stack([ - replace_multimodal_token( - token_ids=_input_ids, - target_token=image_token_id, - repeats=_processed_outputs[ - "vision_query_lengths_images"], - ) for _input_ids in processed_outputs["input_ids"] - ], - dim=0) + _processed_outputs["image_sizes_images"] = torch.tensor( + _processed_outputs["image_sizes_images"]) + _processed_outputs[ + 
"vision_query_lengths_images"] = torch.tensor( + _processed_outputs["vision_query_lengths_images"]) if videos: - _num_per_videos = [ - get_num_combined_frames(len(video)) for video in videos + _idx_per_video = [ + 0, *accumulate( + get_num_combined_frames(len(video)) + for video in videos) ] _processed_outputs["pixel_values_videos"] = [ _processed_outputs["pixel_values_videos"] - [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(len(videos)) + [_idx_per_video[i]:_idx_per_video[i + 1]] + for i in range(len(videos)) ] _processed_outputs["vision_query_lengths_videos"] = [ - _processed_outputs["vision_query_lengths_videos"] - [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(len(videos)) + torch.tensor( + _processed_outputs["vision_query_lengths_videos"] + [_idx_per_video[i]:_idx_per_video[i + 1]]) + for i in range(len(videos)) ] - tokenizer = self.info.get_tokenizer() - video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN) - processed_outputs["input_ids"] = torch.stack([ - replace_multimodal_token( - token_ids=_input_ids, - target_token=video_token_id, - repeats=[ - sum(lens) for lens in - _processed_outputs["vision_query_lengths_videos"] - ], - ) for _input_ids in processed_outputs["input_ids"] - ], - dim=0) - processed_outputs.update(_processed_outputs) return processed_outputs + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -311,11 +286,11 @@ class HCXVisionMultiModalProcessor( out_item = out_mm_kwargs[modality][item_idx] if modality == "image": - lens = out_item["vision_query_lengths_images"].data + lens = out_item["vision_query_lengths_images"].data.tolist() num_tokens = self.info.get_num_image_tokens( vision_query_length=lens) elif modality == "video": - lens = out_item["vision_query_lengths_videos"].data + lens = out_item["vision_query_lengths_videos"].data.tolist() num_tokens = self.info.get_num_video_tokens( vision_query_length=lens) else: @@ -343,26 +318,11 @@ class HCXVisionMultiModalProcessor( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - # image pixel_values_images=MultiModalFieldConfig.batched("image"), image_sizes_images=MultiModalFieldConfig.batched("image"), vision_query_lengths_images=MultiModalFieldConfig.batched("image"), - num_queries_vis_abstractors_images=MultiModalFieldConfig.batched( - "image"), - num_queries_vis_abstractors_slow_images=MultiModalFieldConfig. - batched("image"), - first_last_frames_slows_images=MultiModalFieldConfig.batched( - "image"), - # video pixel_values_videos=MultiModalFieldConfig.batched("video"), - image_sizes_videos=MultiModalFieldConfig.batched("video"), vision_query_lengths_videos=MultiModalFieldConfig.batched("video"), - num_queries_vis_abstractors_videos=MultiModalFieldConfig.batched( - "video"), - num_queries_vis_abstractors_slow_videos=MultiModalFieldConfig. 
- batched("video"), - first_last_frames_slows_videos=MultiModalFieldConfig.batched( - "video"), ) @@ -617,6 +577,7 @@ class HCXVisionCAbstractor(nn.Module): info=_build_hcxvision_hf_info, dummy_inputs=HCXVisionDummyInputsBuilder) class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -692,55 +653,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError("Only image or video modality is supported") + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> Optional[HCXVisionImageInputs]: + pixel_values_images = kwargs.pop("pixel_values_images", None) + + if pixel_values_images is None: + return None + + image_sizes_images = kwargs.pop("image_sizes_images") + + return HCXVisionImagePixelInputs( + pixel_values_images=pixel_values_images, + image_sizes_images=image_sizes_images, + ) + + def _parse_and_validate_video_input( + self, + **kwargs: object, + ) -> Optional[HCXVisionVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + + if pixel_values_videos is None: + return None + + return HCXVisionVideoPixelInputs( + pixel_values_videos=pixel_values_videos, ) + + def _process_image_input( + self, + image_input: HCXVisionImageInputs, + ) -> tuple[torch.Tensor, ...]: + return self.forward_images( + pixel_values_images=image_input["pixel_values_images"], + image_sizes_images=image_input["image_sizes_images"], + ) + + def _process_video_input( + self, + video_input: HCXVisionVideoInputs, + ) -> tuple[torch.Tensor, ...]: + return self.forward_videos( + pixel_values_videos=video_input["pixel_values_videos"], ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
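+        # (Python dicts preserve insertion order, so iterating over
+        # `modalities` later yields the items in the order their kwargs
+        # were seen.)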
+        for input_key in kwargs:
+            if (input_key == "pixel_values_images"
+                    and "images" not in modalities):
+                modalities["images"] = self._parse_and_validate_image_input(
+                    **kwargs)
+            if (input_key == "pixel_values_videos"
+                    and "videos" not in modalities):
+                modalities["videos"] = self._parse_and_validate_video_input(
+                    **kwargs)
+
+        return modalities
+
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
     def get_multimodal_embeddings(
         self,
-        **kwargs: Unpack[HCXVisionMultimodalInputs],
+        **kwargs: object,
     ) -> MultiModalEmbeddings:
+        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not modalities:
+            return []
 
-        multimodal_embeddings = list()
-        if kwargs.get("pixel_values_images") is not None:
-            for _pixel_values_images, _image_sizes_images in zip(
-                    kwargs["pixel_values_images"],
-                    kwargs["image_sizes_images"]):
-                _pixel_values_images = _pixel_values_images.unsqueeze(dim=0)
-                _image_sizes_images = _image_sizes_images.unsqueeze(dim=0)
-                _len_pixel_values_images = [
-                    len(pixel_value) for pixel_value in _pixel_values_images
-                ]
-                if isinstance(_image_sizes_images, torch.Tensor):
-                    _image_sizes_images = _image_sizes_images.detach().cpu(
-                    ).tolist()
-                _multimodal_embeddings_images = self.forward_images(
-                    pixel_values_images=_pixel_values_images,
-                    image_sizes_images=_image_sizes_images,
-                    len_pixel_values_images=_len_pixel_values_images,
-                )
-                _multimodal_embeddings_images = torch.cat(
-                    _multimodal_embeddings_images, dim=0)
-                multimodal_embeddings.append(_multimodal_embeddings_images)
+        # The resulting multimodal_embeddings is a tuple of tensors, with
+        # each tensor corresponding to a multimodal data item (image or
+        # video).
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
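+        # Otherwise the returned embeddings would not line up with the
+        # corresponding multimodal placeholders in the prompt.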
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings - if kwargs.get("pixel_values_videos") is not None: - for _pixel_values_videos, _vision_query_lengths_videos in zip( - kwargs["pixel_values_videos"], - kwargs["vision_query_lengths_videos"]): - _len_pixel_values_videos = [ - len(_vision_query_lengths) - for _vision_query_lengths in _vision_query_lengths_videos - ] - _c, _w, _h = _pixel_values_videos.shape[-3:] - _pixel_values_videos = _pixel_values_videos.reshape( - sum(_len_pixel_values_videos), -1, _c, _w, - _h).unsqueeze(dim=0) - _multimodal_embeddings_videos = self.forward_videos( - pixel_values_videos=_pixel_values_videos, - len_pixel_values_videos=_len_pixel_values_videos, - ) - _multimodal_embeddings_videos = torch.cat( - _multimodal_embeddings_videos, dim=0) - multimodal_embeddings.append(_multimodal_embeddings_videos) return multimodal_embeddings def forward( @@ -762,28 +762,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def forward_images( self, - pixel_values_images: list[list[torch.FloatTensor]], - image_sizes_images: list[list[tuple[int, int]]], - len_pixel_values_images: list[int], - ) -> list[list[torch.Tensor]]: - if sum(len_pixel_values_images) == 0: - return None - - concat_pixel_values_images = torch.cat(list( - chain(*pixel_values_images)), - dim=0) + pixel_values_images: list[torch.Tensor], + image_sizes_images: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True) visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 image_forward_outs = self.vision_model( - concat_pixel_values_images)[:, visual_token_idx:] + pixel_values_image_flat)[:, visual_token_idx:] image_forward_outs = image_forward_outs.to( dtype=self.mm_projector.dtype) image_forward_outs = self.mm_projector(image_forward_outs) # b (h w) d - split_sizes = [ - pixel_value.shape[0] for pixel_value in chain(*pixel_values_images) - ] + split_sizes = [len(item) for item in pixel_values_images] image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0) @@ -791,10 +783,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # newline for anyres postprocessing image_features = anyres_postprocessing( image_forward_outs=image_forward_outs, - image_sizes=[ - image_size for image_sizes in image_sizes_images - for image_size in image_sizes - ], + image_sizes=image_sizes_images.tolist(), num_queries_vis_abstractor=self.config. 
num_queries_vis_abstractor_image, unpad=self.config.unpad, @@ -803,26 +792,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): image_newline=self.image_newline, possible_resolutions=self.config.possible_resolutions, ) - return image_features + + return tuple(image_features) def forward_videos( self, - pixel_values_videos: list[list[torch.FloatTensor]], - len_pixel_values_videos: list[int], - ) -> list[torch.Tensor]: - - len_video_grids = sum(len_pixel_values_videos) - if len_video_grids == 0: - return None - - # Run Vision Model - concat_pixel_values_videos = torch.cat(list( - chain(*pixel_values_videos)), - dim=0) + pixel_values_videos: list[list[torch.Tensor]], + ) -> tuple[torch.Tensor, ...]: + pixel_values_videos_flat = flatten_bn( + [frame for frames in pixel_values_videos for frame in frames], + concat=True, + ) visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 video_forward_outs = self.vision_model( - concat_pixel_values_videos)[:, visual_token_idx:] + pixel_values_videos_flat)[:, visual_token_idx:] video_forward_outs = video_forward_outs.to( dtype=self.mm_projector.dtype) @@ -905,7 +889,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ) == 0, f"target_features is not empty!! {target_features}" assert len(video_groups) == len(video_features) - return video_features + feats_per_video = [len(video) for video in pixel_values_videos] + idxs_per_video = [0, *accumulate(feats_per_video)] + return tuple( + torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]]) + for i in range(len(feats_per_video))) def _prepare_multimodal_kwargs(self, **kwargs: object): output = defaultdict(list) @@ -1111,15 +1099,15 @@ def reshape_and_unpad_image_features( def anyres_postprocessing( - image_forward_outs: list[torch.FloatTensor], + image_forward_outs: list[torch.Tensor], image_sizes: list[list[int]], possible_resolutions: list[tuple[int, int]], patch_size: int, grid_size: int, - image_newline: torch.FloatTensor, + image_newline: torch.Tensor, num_queries_vis_abstractor: int = -1, unpad: bool = False, -) -> list[torch.FloatTensor]: +) -> list[torch.Tensor]: height = width = grid_size // patch_size if num_queries_vis_abstractor > 0: @@ -1147,26 +1135,5 @@ def anyres_postprocessing( (image_feature, image_newline[None].to(image_feature.device)), dim=0) new_image_features.append(image_feature) - image_features = new_image_features - return image_features - -def resize_image( - image: Union[np.ndarray, PIL.Image.Image], - max_side: int = 378, -) -> np.ndarray: - image_arr = image - if isinstance(image, np.ndarray): - image = Image.fromarray(image) - - width, height = image.size - cur_max_size = max(width, height) - if cur_max_size <= max_side: - return image_arr - - scale = max_side / cur_max_size - width = int(width * scale) - height = int(height * scale) - image = image.resize((width, height), Image.LANCZOS) - image_arr = np.array(image) - return image_arr + return new_image_features diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index f5720e726c48e..2415f3696f001 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -109,7 +109,7 @@ class Phi3VImagePixelInputs(TensorSchema): type: Literal["pixel_values", "image_embeds"] = "pixel_values" # Supports either a stacked tensor or a list of (p, 3, h, w) tensors - data: Annotated[ + pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} 
), # 'p' may vary across items @@ -594,7 +594,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, if pixel_values is not None: return Phi3VImagePixelInputs( type="pixel_values", - data=flatten_bn(pixel_values), + pixel_values=flatten_bn(pixel_values), image_sizes=flatten_bn(image_sizes, concat=True), resolve_bindings={ "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, @@ -628,7 +628,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ) assert self.vision_embed_tokens is not None - image_embeds = self.vision_embed_tokens(image_input["data"], + image_embeds = self.vision_embed_tokens(image_input["pixel_values"], image_input["image_sizes"]) return image_embeds diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index d75dbcd5401b2..44688467b8998 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -94,34 +94,63 @@ class TensorSchema: return False return True - def _validate_nested_tensors( - self, - value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], - field_name: str, - expected_shape: tuple[Union[int, str], ...], - dynamic_dims: set[str], + def _fmt_indexer(self, idxs: tuple[int, ...]) -> str: + if not idxs: + return "" + + return str(list(idxs)) + + def _validate_field( + self, + value: object, + field_name: str, + expected_shape: tuple[Union[int, str], ...], + dynamic_dims: set[str], + leading_idxs: tuple[int, ...] = (), ) -> tuple[int, ...]: - """Validate a list/tuple of tensors and return the actual shape.""" + """Validate a field and return the actual shape.""" + if isinstance(value, (int, float)): + return () # Scalar + if isinstance(value, torch.Tensor): + return value.shape + + if not isinstance(value, (list, tuple)): + raise TypeError( + f"{field_name}{self._fmt_indexer(leading_idxs)} is not " + f"one of the expected types: int, float, Tensor, list, tuple. " + f"Got: {type(value)}") + + if len(value) == 0: + raise ValueError(f"{field_name}{self._fmt_indexer(leading_idxs)} " + f"is an empty sequence") + # Ensure all tensors in the list have the same # shape, besides dynamic dimensions - first = value[0] for i, v in enumerate(value): - if not isinstance(v, torch.Tensor): - raise ValueError(f"{field_name}[{i}] is not a " - f"torch.Tensor") - if not self._match_shape_with_dynamic( - v.shape, - first.shape, + shape = self._validate_field( + v, + field_name, + expected_shape[1:], + dynamic_dims, + leading_idxs=leading_idxs + (i, ), + ) + + if i == 0: + first_shape = shape + elif not self._match_shape_with_dynamic( + shape, + first_shape, expected_shape, dynamic_dims, ): - raise ValueError(f"{field_name} contains inconsistent " - f"shapes: {first.shape} vs {v.shape} " - f"at index {i}") + raise ValueError( + f"{field_name}{self._fmt_indexer(leading_idxs)} " + f"contains inconsistent shapes: {first_shape} " + f"(index 0) vs {shape} (index {i})") # Treat the list as a stacked tensor: # shape = (len(list), *tensor.shape) - return (len(value), ) + first.shape + return (len(value), ) + first_shape def _validate_tensor_shape_expected( self, @@ -187,36 +216,12 @@ class TensorSchema: for arg in args: if isinstance(arg, TensorShape): expected_shape = arg.resolve(**self._resolve_bindings) - if isinstance(value, (list, tuple)): - # list/tuple of Tensors → shape = (len(value), ...) 
- if value and isinstance(value[0], torch.Tensor): - actual_shape = self._validate_nested_tensors( - value, field_name, expected_shape, - arg.dynamic_dims) - elif value: - # list/tuple of scalars → shape = (len(value),) - actual_shape = (len(value), ) - else: - raise ValueError( - f"{field_name} is an empty list") - - # Tensor → shape = tensor.shape - elif isinstance(value, torch.Tensor): - actual_shape = value.shape - - # Otherwise, it's an unsupported type - else: - type_names = [] - for arg in args: - if hasattr(arg, "__name__"): - type_names.append(str(arg.__name__)) - else: - type_names.append(str(arg)) - - expected_types = ", ".join(type_names) - raise ValueError( - f"{field_name} is not one of the expected " - f"types: {expected_types}") + actual_shape = self._validate_field( + value, + field_name, + expected_shape, + arg.dynamic_dims, + ) self._validate_tensor_shape_expected( actual_shape, expected_shape, field_name,
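
Reviewer notes (illustrative sketches, not part of the patch):

1. The `accumulate`-based slicing in the HyperCLOVA-X processor replaces the quadratic `sum(_num_per_videos[:_i])` pattern with a single prefix-sum pass. A standalone sketch of the boundary arithmetic, using made-up per-video canvas counts:

```python
from itertools import accumulate

# Hypothetical counts, as get_num_combined_frames() might return
# for three videos.
num_per_video = [2, 1, 3]
idx_per_video = [0, *accumulate(num_per_video)]  # [0, 2, 3, 6]

flat = ["v0c0", "v0c1", "v1c0", "v2c0", "v2c1", "v2c2"]
per_video = [
    flat[idx_per_video[i]:idx_per_video[i + 1]]
    for i in range(len(num_per_video))
]
assert per_video == [["v0c0", "v0c1"], ["v1c0"], ["v2c0", "v2c1", "v2c2"]]
```

2. The recursive `_validate_field` treats each nesting level of a list/tuple as one leading dimension of the declared `TensorShape` and skips the consistency check for dimensions named in `dynamic_dims`. A minimal sketch of the shape-resolution logic (simplified; the real implementation also handles scalars, empty sequences, type errors, and nested error indexing):

```python
import torch


def resolve_shape(value, expected, dynamic_dims):
    """Simplified model of TensorSchema._validate_field."""
    if isinstance(value, torch.Tensor):
        return tuple(value.shape)
    # Each nesting level consumes one leading entry of `expected`.
    shapes = [resolve_shape(v, expected[1:], dynamic_dims) for v in value]
    first = shapes[0]
    for i, shape in enumerate(shapes[1:], start=1):
        for dim, (a, b) in zip(expected[1:], zip(shape, first)):
            # Dimensions bound to a symbol in dynamic_dims may differ
            # between items; everything else must match the first item.
            if dim not in dynamic_dims and a != b:
                raise ValueError(f"inconsistent shapes at index {i}")
    # The sequence is treated as a stacked tensor: (len, *first_shape).
    return (len(value),) + first


# Mirrors test_tensor_schema_double_nested_tensors: ragged frame counts
# ("f") and ragged grid counts ("g") both pass validation.
x, y = torch.rand(4, 3, 32, 32), torch.rand(2, 3, 32, 32)
shape = resolve_shape(([x, y, x], [y], [x, y]),
                      ("n", "f", "g", 3, "h", "w"), {"f", "g"})
assert shape == (3, 3, 4, 3, 32, 32)
```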