From a554991748584b00e3bbd2ab192cbcac3f630263 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Sun, 10 Aug 2025 19:29:16 -0700
Subject: [PATCH] Migrate LlavaNextVideoPixelInputs to TensorSchema (#21843)

Signed-off-by: Benji Beck
---
 .../model_executor/models/llava_next_video.py | 57 +++++++------------
 1 file changed, 22 insertions(+), 35 deletions(-)

diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index a96df0b6f572e..abc519edadcca 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -3,7 +3,7 @@
 
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -25,6 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llava import init_vision_tower_for_llava
@@ -35,17 +36,25 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
 from .vision import get_vision_encoder_info
 
 
-class LlavaNextVideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Shape: `(batch_size, num_frames, num_channels, height, width)`
+class LlavaNextVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bs: Batch size
+        - nv: Number of videos
+        - nf: Number of frames
+        - nc: Number of channels (3)
+        - h: Height of each frame
+        - w: Width of each frame
 
     Note that `num_frames` may be different for each batch, in which case
     the data is passed as a list instead of a batched tensor.
 
     Note that it only supports one video input for one batch.
     """
+    type: Literal["pixel_values_videos"] = "pixel_values_videos"
+
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("bs", "nv", "nf", 3, "h", "w")]
 
 
 class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
@@ -320,27 +329,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.model.make_empty_intermediate_tensors)
 
-    def _validate_video_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[2:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_frames", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values in each video frame "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_video_input(
         self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
         """
@@ -355,14 +343,13 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values_videos is None:
             return None
 
-        if not isinstance(pixel_values_videos, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel_values_videos. "
-                             f"Got type: {type(pixel_values_videos)}")
-
-        return LlavaNextVideoPixelInputs(
-            type="pixel_values_videos",
-            data=pixel_values_videos,
-        )
+        expected_h = expected_w = self.config.vision_config.image_size
+        return LlavaNextVideoPixelInputs(type="pixel_values_videos",
+                                         data=pixel_values_videos,
+                                         resolve_bindings={
+                                             "h": expected_h,
+                                             "w": expected_w,
+                                         })
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
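Reviewer note (illustration only, not part of the patch): the sketch below shows roughly how the migrated schema is exercised, mirroring the new body of _parse_and_validate_video_input. It assumes vLLM with this patch applied is importable, that TensorSchema checks `data` against TensorShape("bs", "nv", "nf", 3, "h", "w") when the inputs object is constructed, and that a mismatch raises an exception; the 336x336 frame size and the try/except are example choices, not something this diff asserts.

    import torch

    from vllm.model_executor.models.llava_next_video import (
        LlavaNextVideoPixelInputs)

    # One video of 8 frames, batched as (bs, nv, nf, c, h, w).
    ok = torch.zeros(1, 1, 8, 3, 336, 336)
    inputs = LlavaNextVideoPixelInputs(
        type="pixel_values_videos",
        data=ok,
        # resolve_bindings pins the symbolic "h"/"w" dims, as the model
        # code now does with vision_config.image_size.
        resolve_bindings={"h": 336, "w": 336},
    )

    # Frames with the wrong spatial size should now be rejected by the
    # schema rather than by the removed _validate_video_pixel_values
    # helper (assumed to raise; exact exception type not checked here).
    bad = torch.zeros(1, 1, 8, 3, 224, 224)
    try:
        LlavaNextVideoPixelInputs(
            type="pixel_values_videos",
            data=bad,
            resolve_bindings={"h": 336, "w": 336},
        )
    except Exception as exc:
        print(f"schema rejected mismatched frames: {exc}")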