Migrate LlavaNextVideoPixelInputs to TensorSchema (#21843)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Benji Beck 2025-08-10 19:29:16 -07:00 committed by GitHub
parent d1af8b7be9
commit a554991748


@@ -3,7 +3,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union

 import torch
 import torch.nn as nn
@@ -25,6 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llava import init_vision_tower_for_llava
@@ -35,17 +36,25 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
 from .vision import get_vision_encoder_info


-class LlavaNextVideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Shape: `(batch_size, num_frames, num_channels, height, width)`
+class LlavaNextVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bs: Batch size
+        - nv: Number of videos
+        - nf: Number of frames
+        - nc: Number of channels (3)
+        - h: Height of each frame
+        - w: Width of each frame

     Note that `num_frames` may be different for each batch, in which case
     the data is passed as a list instead of a batched tensor.
+
+    Note that it only supports one video input for one batch.
     """

+    type: Literal["pixel_values_videos"] = "pixel_values_videos"
+
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("bs", "nv", "nf", 3, "h", "w")]


 class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
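For reference, a minimal sketch of how the new schema is exercised; the frame count and the 336x336 resolution are illustrative assumptions, and only the constructor pattern and the resolve_bindings keys come from this diff:

import torch

# Example input: bs=1, nv=1, nf=8 frames of 3 x 336 x 336 (values assumed for illustration)
pixel_values = torch.randn(1, 1, 8, 3, 336, 336)

# TensorSchema checks `data` against TensorShape("bs", "nv", "nf", 3, "h", "w")
# when the instance is constructed; "h" and "w" are bound via resolve_bindings.
video_input = LlavaNextVideoPixelInputs(
    type="pixel_values_videos",
    data=pixel_values,
    resolve_bindings={"h": 336, "w": 336},
)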
@@ -320,27 +329,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.model.make_empty_intermediate_tensors)

-    def _validate_video_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[2:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_frames", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values in each video frame "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_video_input(
             self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
         """
@@ -355,14 +343,13 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values_videos is None:
             return None

-        if not isinstance(pixel_values_videos, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel_values_videos. "
-                             f"Got type: {type(pixel_values_videos)}")
-
-        return LlavaNextVideoPixelInputs(
-            type="pixel_values_videos",
-            data=pixel_values_videos,
-        )
+        expected_h = expected_w = self.config.vision_config.image_size
+        return LlavaNextVideoPixelInputs(type="pixel_values_videos",
+                                         data=pixel_values_videos,
+                                         resolve_bindings={
+                                             "h": expected_h,
+                                             "w": expected_w,
+                                         })

     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
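With the manual `_validate_video_pixel_values` removed, a mismatched frame shape should now be rejected by the schema itself. A hedged sketch of that failure path (the exact exception type and message are assumptions, not shown in this diff):

import torch

bad = torch.randn(1, 1, 8, 3, 224, 224)  # 224x224 frames instead of the expected 336x336

try:
    LlavaNextVideoPixelInputs(
        type="pixel_values_videos",
        data=bad,
        resolve_bindings={"h": 336, "w": 336},
    )
except Exception as exc:  # shape mismatch surfaced by TensorSchema validation
    print(f"rejected: {exc}")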