diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 3637f037751c..a0e98ca3f815 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -9,7 +9,7 @@ # -------------------------------------------------------- from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, TypeVar, Union +from typing import Annotated, Any, Literal, Optional, TypeVar, Union import numpy.typing as npt import torch @@ -37,6 +37,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -51,54 +52,60 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) -class InternVLImagePixelInputs(TypedDict): +class InternVLImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + """ type: Literal["pixel_values"] - pixel_values_flat: torch.Tensor + pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] + + +class InternVLImageEmbeddingInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + Dimensions: + - n: Number of images + - f: Total image feature size + - h: Hidden size (must match the hidden size of language model backbone) """ - - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" - - -class InternVLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - A tensor of shape `(num_images, total_image_feature_size, hidden_size)` - or a list of tensors of shape `(total_image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. - """ + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("n", "f", "h")] InternVLImageInputs = Union[InternVLImagePixelInputs, InternVLImageEmbeddingInputs] -class InternVLVideoPixelInputs(TypedDict): +class InternVLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - bvf: Batch size * number of videos * num_frames + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each video frame + - w: Width of each video frame + """ type: Literal["pixel_values_videos"] - pixel_values_flat: torch.Tensor + pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] + + +class InternVLVideoEmbeddingInputs(TensorSchema): """ - Shape: - `(batch_size * num_video * num_frames, num_channels, height, width)` + Dimensions: + - n: Number of videos + - f: Total video feature size + - h: Hidden size (must match the hidden size of language model backbone) """ - - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" - - -class InternVLVideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - A tensor of shape `(num_videos, total_video_feature_size, hidden_size)` - or a list of tensors of shape `(total_video_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. - """ + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("n", "f", "h")] InternVLVideoInputs = Union[InternVLVideoPixelInputs, @@ -1151,26 +1158,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. " - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[InternVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) @@ -1205,12 +1192,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True) + expected_h = expected_w = self.config.vision_config.image_size + resolve_bindings = {"h": expected_h, "w": expected_w} return InternVLImagePixelInputs( type="pixel_values", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat), + pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, + resolve_bindings=resolve_bindings, ) raise AssertionError("This line should be unreachable.") @@ -1225,11 +1214,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return None if video_embeds is not None: - if not isinstance(video_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") - - return InternVLImageEmbeddingInputs( + return InternVLVideoEmbeddingInputs( type="video_embeds", data=flatten_bn(video_embeds), ) @@ -1250,12 +1235,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True) video_num_patches = flatten_bn(video_num_patches, concat=True) + expected_h = expected_w = self.config.vision_config.image_size + resolve_bindings = {"h": expected_h, "w": expected_w} return InternVLVideoPixelInputs( type="pixel_values_videos", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat_video), + pixel_values_flat=pixel_values_flat_video, num_patches=video_num_patches, + resolve_bindings=resolve_bindings, ) raise AssertionError("This line should be unreachable.")