From 467850347687f0ef76c1a57d79e2c0639eaa1456 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Mon, 11 Aug 2025 20:43:37 -0700
Subject: [PATCH] Migrate MiniCPMVImageInputs to TensorSchema (#21939)

Signed-off-by: Benji Beck
---
 vllm/model_executor/models/minicpmv.py | 65 ++++++++++++++------------
 1 file changed, 36 insertions(+), 29 deletions(-)

diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 3aa16bb9abe4..7db3a1bb90b4 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -27,7 +27,7 @@ import math
 from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -63,6 +63,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import flatten_2d_lists
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -74,36 +75,47 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
 _MAX_FRAMES_PER_VIDEO = 16
 
 
-class MiniCPMVImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: list[torch.Tensor]
+class MiniCPMVImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`
-
-    Note that the image size may vary, so we pass it as a list
-    instead of a batched tensor.
+    Dimensions:
+        - bns: Batch size * number of images * number of slices
+        - bn: Batch size * number of images
+        - c: Number of channels
+        - h: Height
+        - w: Width
     """
 
-    tgt_sizes: torch.Tensor
-    """
-    Shape: `(batch_size * num_images * num_slices, 2)`
+    type: Literal["pixel_values"] = "pixel_values"
 
-    This should be in `(height, width)` format.
+    # Note that the image size may vary, so we pass it as a list instead of a
+    # batched tensor.
+    pixel_values: Annotated[
+        list[torch.Tensor],
+        TensorShape("bns", "c", "h", "w"),
+    ]
+    tgt_sizes: Annotated[
+        torch.Tensor,
+        TensorShape("bns", 2),  # This should be in `(height, width)` format.
+    ]
+    num_slices: Annotated[
+        torch.Tensor,
+        TensorShape("bn"),
+    ]
+
+
+class MiniCPMVImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ns: Number of slices
+        - hs: Hidden size (must match language model backbone)
     """
 
-    num_slices: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
-
-
-class MiniCPMVImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    image_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Shape: `(batch_size * num_images, num_slices, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    instead of a batched tensor.
-    """
+    image_embeds: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "ns", "hs"),
+    ]
 
 
 MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
@@ -832,11 +844,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
         tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
 
-        if len(pixel_values_flat) != len(tgt_sizes_flat):
-            raise ValueError("Inconsistent flattened lengths, found: "
-                             f"{len(pixel_values_flat)} vs. "
-                             f"{len(tgt_sizes_flat)}")
-
         return MiniCPMVImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values_flat,