Migrate MiniCPMVImageInputs to TensorSchema (#21939)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Benji Beck 2025-08-11 20:43:37 -07:00 committed by GitHub
parent 93d0652433
commit 4678503476


@@ -27,7 +27,7 @@ import math
 from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -63,6 +63,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import flatten_2d_lists
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -74,36 +75,47 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
 _MAX_FRAMES_PER_VIDEO = 16
 
 
-class MiniCPMVImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: list[torch.Tensor]
-    """
-    Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`
-
-    Note that the image size may vary, so we pass it as a list
-    instead of a batched tensor.
-    """
-
-    tgt_sizes: torch.Tensor
-    """
-    Shape: `(batch_size * num_images * num_slices, 2)`
-
-    This should be in `(height, width)` format.
-    """
-
-    num_slices: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
-
-
-class MiniCPMVImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    image_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Shape: `(batch_size * num_images, num_slices, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    instead of a batched tensor.
-    """
+class MiniCPMVImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bns: Batch size * number of images * number of slices
+        - bn: Batch size * number of images
+        - c: Number of channels
+        - h: Height
+        - w: Width
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+
+    # Note that the image size may vary, so we pass it as a list instead of a
+    # batched tensor.
+    pixel_values: Annotated[
+        list[torch.Tensor],
+        TensorShape("bns", "c", "h", "w"),
+    ]
+
+    tgt_sizes: Annotated[
+        torch.Tensor,
+        TensorShape("bns", 2),  # This should be in `(height, width)` format.
+    ]
+
+    num_slices: Annotated[
+        torch.Tensor,
+        TensorShape("bn"),
+    ]
+
+
+class MiniCPMVImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ns: Number of slices
+        - hs: Hidden size (must match language model backbone)
+    """
+    type: Literal["image_embeds"]
+    image_embeds: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "ns", "hs"),
+    ]
 
 
 MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
@@ -832,11 +844,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
         tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
 
-        if len(pixel_values_flat) != len(tgt_sizes_flat):
-            raise ValueError("Inconsistent flattened lengths, found: "
-                             f"{len(pixel_values_flat)} vs. "
-                             f"{len(tgt_sizes_flat)}")
-
         return MiniCPMVImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values_flat,
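
Why the explicit length check could be dropped: pixel_values and tgt_sizes now share the symbolic leading dimension "bns" in their TensorShape annotations, so, assuming TensorSchema validates field shapes when an instance is constructed (as the keyword-style construction above suggests), inconsistent flattened lengths are rejected by the schema itself. Below is a minimal, self-contained sketch of that symbolic-dimension idea; it is illustrative only, not vLLM's implementation, and it simplifies pixel_values to a single batched tensor.

# Illustrative sketch of symbolic shape checking (NOT vLLM's TensorSchema).
# Dimension names ("bns", "c", ...) shared by several fields must all resolve
# to the same size; integer entries (e.g. 2) are exact sizes.
import torch


def check_dims(specs: dict[str, tuple], tensors: dict[str, torch.Tensor]) -> None:
    resolved: dict[str, int] = {}
    for field, dims in specs.items():
        shape = tensors[field].shape
        if len(shape) != len(dims):
            raise ValueError(f"{field}: expected {len(dims)} dims, got {len(shape)}")
        for dim, size in zip(dims, shape):
            if isinstance(dim, int):
                if size != dim:
                    raise ValueError(f"{field}: expected size {dim}, got {size}")
            elif resolved.setdefault(dim, size) != size:
                raise ValueError(f"{field}: '{dim}' is {size} here, "
                                 f"but {resolved[dim]} in another field")


specs = {"pixel_values": ("bns", "c", "h", "w"), "tgt_sizes": ("bns", 2)}

# Consistent: both fields agree that bns == 4, so this passes silently.
check_dims(specs, {"pixel_values": torch.zeros(4, 3, 14, 14),
                   "tgt_sizes": torch.zeros(4, 2)})

# Inconsistent: bns is 4 in one field and 3 in the other, so this raises --
# the mismatch the deleted if/raise block used to catch by hand.
try:
    check_dims(specs, {"pixel_values": torch.zeros(4, 3, 14, 14),
                       "tgt_sizes": torch.zeros(3, 2)})
except ValueError as err:
    print(err)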