From a70d0bd0a39bfb278b7bda9c82d99df8f628d779 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Tue, 19 Aug 2025 10:02:02 -0700 Subject: [PATCH] Migrate LlavaOnevisionMultiInputs to TensorSchema (#21844) Signed-off-by: Benji Beck --- vllm/model_executor/models/llava_onevision.py | 149 +++++++----------- 1 file changed, 56 insertions(+), 93 deletions(-) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index babd72a4b782e..42ab5e7c74d37 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Final, Literal, Optional, Protocol, TypedDict, Union +from typing import Annotated, Final, Literal, Optional, Protocol, Union import torch import torch.nn as nn @@ -11,7 +11,6 @@ from transformers import (BatchFeature, LlavaOnevisionConfig, LlavaOnevisionProcessor) from transformers.models.llava_onevision.modeling_llava_onevision import ( get_anyres_image_grid_shape, unpad_image) -from typing_extensions import NotRequired from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn @@ -23,6 +22,7 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, _MAX_FRAMES_PER_VIDEO = 16 -class LlavaOnevisionVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]] +class LlavaOnevisionVideoPixelInputs(TensorSchema): """ - Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)` + Dimensions: + - bn: Batch size * number of videos + - f: Number of frames + - c: Number of channels (3) + - h: Height + - w: Width - Note that `num_videos` may be different for each batch, and 'num_frames' - may be different for each video, in which case the data is passed as a - list instead of a batched tensor. + Note that `num_videos` may be different for each batch, and 'num_frames' + may be different for each video, in which case the data is passed as a + list instead of a batched tensor. """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" + + pixel_values_videos: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}), + ] -class LlavaOnevisionImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class LlavaOnevisionImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + Dimensions: + - bn: Batch size * number of images + - np: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" - image_sizes: NotRequired[torch.Tensor] + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "np", 3, "h", "w"), + ] + + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + + +class LlavaOnevisionImageEmbeddingInputs(TensorSchema): """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" - -class LlavaOnevisionImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. - """ + data: Annotated[ + torch.Tensor, + TensorShape("bn", "ifs", "hs"), + ] LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs, @@ -482,44 +500,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - expected_dims = (2, ) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - f"The expected shape of image sizes per image per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - - def _validate_image_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("num_patches", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values per image per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -540,11 +520,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return LlavaOnevisionImagePixelInputs( type="pixel_values", - pixel_values=self._validate_image_pixel_values( - flatten_bn(pixel_values)), - image_sizes=self._validate_image_sizes( - flatten_bn(image_sizes, concat=True)), - ) + pixel_values=flatten_bn(pixel_values), + image_sizes=flatten_bn(image_sizes, concat=True), + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size + }) if image_embeds is not None: if not isinstance(image_embeds, torch.Tensor): @@ -558,27 +539,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, raise AssertionError("This line should be unreachable.") - def _validate_video_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[2:]) - - if actual_dims != expected_dims: - expected_expr = ("num_frames", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values in each video frame " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]: @@ -600,7 +560,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return LlavaOnevisionVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=flatten_bn(pixel_values_videos), - ) + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size + }) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {}