diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 6a08d2793fd03..5fc28ed0e493e 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -5,7 +5,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
 from functools import cached_property, partial
-from typing import Optional, TypedDict, Union
+from typing import Annotated, Optional, Union
 
 import numpy as np
 import torch
@@ -51,6 +51,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP, SupportsQuant)
@@ -70,23 +71,25 @@ IM_END_TOKEN = "<im_end>"
 POOLING_SIZE = 2
 
 
-class MolmoImageInputs(TypedDict):
-    images: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
-
-    image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
-    """Shape: `(batch_size * num_images, num_crops, num_patch)`"""
-
-    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+class MolmoImageInputs(TensorSchema):
     """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size * num_images, num_crops, num_patch)`
+    Dimensions:
+        - bn: Batch size * number of images
+        - nc: Number of crops
+        - np: Number of patches
+        - pd: Patch dimension
     """
+    images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                      TensorShape("bn", "nc", "np", "pd")]
 
-    num_crops: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+    image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]],
+                           TensorShape("bn", "nc", "np")]
+
+    feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                             TensorShape("bn", "nc", "np")]
+    # A boolean mask indicating which image features correspond to patch tokens.
+
+    num_crops: Annotated[torch.Tensor, TensorShape("bn")]
 
 
 @dataclass
@@ -1410,28 +1413,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         **kwargs: object,
     ) -> Optional[MolmoImageInputs]:
         images = kwargs.pop("images", None)
+        image_masks = kwargs.pop("image_masks", None)
+        feat_is_patch = kwargs.pop("feat_is_patch", None)
+        num_crops = kwargs.pop("num_crops", None)
+
         if images is None:
             return None
 
-        if not isinstance(images, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of images. "
-                             f"Got type: {type(images)}")
-
-        image_masks = kwargs.pop("image_masks", None)
-        if not (image_masks is None or isinstance(image_masks,
-                                                  (torch.Tensor, list))):
-            raise ValueError("Incorrect type of image_masks. "
-                             f"Got type: {type(image_masks)}")
-
-        feat_is_patch = kwargs.pop("feat_is_patch", None)
-        if not isinstance(feat_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of feat_is_patch. "
-                             f"Got type: {type(feat_is_patch)}")
-
-        num_crops = kwargs.pop("num_crops", None)
         if not isinstance(num_crops, (torch.Tensor, list)):
             raise ValueError("Incorrect type of num_crops. "
                              f"Got type: {type(num_crops)}")
+        num_crops = flatten_bn(num_crops, concat=True)
 
         img_patch_id = kwargs.pop("img_patch_id", None)
         if not isinstance(img_patch_id, torch.Tensor):
@@ -1439,8 +1431,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
                              f"Got type: {type(img_patch_id)}")
         self.img_patch_id = img_patch_id.flatten().unique().item()
 
-        num_crops = flatten_bn(num_crops, concat=True)
-
         return MolmoImageInputs(
             images=images,
             image_masks=image_masks,