diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index bb3267ce5b00..2a60450de414 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -17,7 +17,7 @@
 """PyTorch Mllama model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -64,6 +64,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                         EncDecMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal, SupportsV0Only
@@ -73,15 +74,30 @@ from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 logger = init_logger(__name__)
 
 
-class MllamaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: """
-    """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)"""
-    aspect_ratio_ids: torch.Tensor
-    """Shape: `(batch_size, max_num_image)`"""
-    aspect_ratio_mask: torch.Tensor
-    """Shape: `(batch_size, max_num_image, max_num_tiles)`"""
+class MllamaImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - batch_size: Batch size
+        - max_num_image: Max number of images
+        - max_num_chunk: Max number of chunks
+        - max_num_tiles: Max number of tiles per image
+        - num_channel: Number of channels
+        - height: Height
+        - width: Width
+    """
+
+    type: Literal["pixel_values"] = "pixel_values"
+
+    data: Annotated[torch.Tensor,
+                    TensorShape("batch_size", "max_num_image", "max_num_chunk",
+                                "num_channel", "height", "width")]
+
+    aspect_ratio_ids: Annotated[torch.Tensor,
+                                TensorShape("batch_size", "max_num_image")]
+
+    aspect_ratio_mask: Annotated[
+        torch.Tensor,
+        TensorShape("batch_size", "max_num_image", "max_num_tiles")]
 
 
 # TODO: support LlamaImageEmbeddingInputs
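
For reference, a minimal usage sketch of the schema-based inputs above (not part of the diff). It assumes that TensorSchema subclasses are constructed with keyword arguments and validate each Annotated field against its declared TensorShape at construction time; the dummy sizes, dtypes, and the construction call below are illustrative assumptions, and the exact validation behavior of vllm.utils.tensor_schema may differ.

import torch

# Assumes MllamaImagePixelInputs (defined in the diff above) is in scope.
# Hypothetical dummy sizes: batch_size=1, max_num_image=1, max_num_chunk=4,
# num_channel=3, height=width=448, max_num_tiles=4.
pixel_values = torch.zeros(1, 1, 4, 3, 448, 448)
aspect_ratio_ids = torch.zeros(1, 1, dtype=torch.long)
aspect_ratio_mask = torch.ones(1, 1, 4, dtype=torch.long)

# Construction is where the shape checks are assumed to run, so a field with
# the wrong rank, or a shared dimension (e.g. batch_size) that disagrees
# between fields, should be rejected here rather than surfacing as an opaque
# error deep inside the model's forward pass.
inputs = MllamaImagePixelInputs(
    data=pixel_values,
    aspect_ratio_ids=aspect_ratio_ids,
    aspect_ratio_mask=aspect_ratio_mask,
)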