Migrate MolmoImageInputs to TensorSchema (#22022)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck
Date:   2025-08-21 09:54:08 -07:00 (committed by GitHub)
parent e0b056e443
commit a482e4e769

vllm/model_executor/models/molmo.py

@@ -5,7 +5,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
 from functools import cached_property, partial
-from typing import Optional, TypedDict, Union
+from typing import Annotated, Optional, Union
 
 import numpy as np
 import torch
@@ -51,6 +51,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP, SupportsQuant)
@@ -70,23 +71,25 @@ IM_END_TOKEN = "<im_end>"
 POOLING_SIZE = 2
 
 
-class MolmoImageInputs(TypedDict):
-    images: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
-
-    image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
-    """Shape: `(batch_size * num_images, num_crops, num_patch)`"""
-
-    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+class MolmoImageInputs(TensorSchema):
     """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size * num_images, num_crops, num_patch)`
+    Dimensions:
+        - bn: Batch size * number of images
+        - nc: Number of crops
+        - np: Number of patches
+        - pd: Patch dimension
     """
+    images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                      TensorShape("bn", "nc", "np", "pd")]
 
-    num_crops: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+    image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]],
+                           TensorShape("bn", "nc", "np")]
+
+    feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                             TensorShape("bn", "nc", "np")]
+    # A boolean mask indicating which image features correspond to patch tokens.
+
+    num_crops: Annotated[torch.Tensor, TensorShape("bn")]
 
 
 @dataclass
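
For readers new to the pattern: below is a minimal, illustrative sketch of constructing the migrated schema, assuming TensorSchema resolves the symbolic dimensions ("bn", "nc", "np", "pd") consistently across fields when the object is built. The sizes are made up for the example.

    import torch

    inputs = MolmoImageInputs(
        images=torch.zeros(2, 5, 576, 588),   # ("bn", "nc", "np", "pd")
        image_masks=torch.zeros(2, 5, 576),   # ("bn", "nc", "np")
        feat_is_patch=torch.zeros(2, 5, 576, dtype=torch.bool),
        num_crops=torch.tensor([5, 5]),       # ("bn",)
    )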
@@ -1410,28 +1413,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         **kwargs: object,
     ) -> Optional[MolmoImageInputs]:
         images = kwargs.pop("images", None)
+        image_masks = kwargs.pop("image_masks", None)
+        feat_is_patch = kwargs.pop("feat_is_patch", None)
+        num_crops = kwargs.pop("num_crops", None)
         if images is None:
             return None
 
-        if not isinstance(images, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of images. "
-                             f"Got type: {type(images)}")
-
-        image_masks = kwargs.pop("image_masks", None)
-        if not (image_masks is None or isinstance(image_masks,
-                                                  (torch.Tensor, list))):
-            raise ValueError("Incorrect type of image_masks. "
-                             f"Got type: {type(image_masks)}")
-
-        feat_is_patch = kwargs.pop("feat_is_patch", None)
-        if not isinstance(feat_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of feat_is_patch. "
-                             f"Got type: {type(feat_is_patch)}")
-
-        num_crops = kwargs.pop("num_crops", None)
         if not isinstance(num_crops, (torch.Tensor, list)):
             raise ValueError("Incorrect type of num_crops. "
                              f"Got type: {type(num_crops)}")
+        num_crops = flatten_bn(num_crops, concat=True)
 
         img_patch_id = kwargs.pop("img_patch_id", None)
         if not isinstance(img_patch_id, torch.Tensor):
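
Note that flatten_bn(num_crops, concat=True) now runs before the schema object is built, so num_crops already has its flattened ("bn",) shape when it is validated. A rough sketch of the behavior this call relies on (a sketch, not the actual vLLM implementation):

    # num_crops may arrive as a list of per-request tensors; concat=True
    # joins them into a single 1-D tensor covering every image in the batch.
    def flatten_bn_sketch(num_crops, *, concat: bool = False):
        if isinstance(num_crops, torch.Tensor):
            return num_crops.flatten(0, 1)   # merge batch and per-item dims
        assert concat, "sketch covers only the concat=True list case"
        return torch.cat(num_crops)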
@@ -1439,8 +1431,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
                              f"Got type: {type(img_patch_id)}")
         self.img_patch_id = img_patch_id.flatten().unique().item()
 
-        num_crops = flatten_bn(num_crops, concat=True)
-
         return MolmoImageInputs(
             images=images,
             image_masks=image_masks,
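
Net effect: the per-field isinstance checks are gone, and malformed inputs are rejected declaratively instead. A hedged sketch of the kind of mistake this should now catch, assuming shape validation happens when the TensorSchema subclass is constructed:

    # A crop-count mismatch between fields previously passed the plain
    # isinstance checks; under TensorSchema it should fail at construction
    # because "nc" resolves to 5 for images but 4 for feat_is_patch.
    MolmoImageInputs(
        images=torch.zeros(2, 5, 576, 588),
        image_masks=None,                      # Optional field, None allowed
        feat_is_patch=torch.zeros(2, 4, 576),  # "nc" disagrees with images
        num_crops=torch.tensor([5, 5]),
    )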