diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py
index 41fd272397e64..f1bb18716b40d 100644
--- a/vllm/model_executor/models/ovis.py
+++ b/vllm/model_executor/models/ovis.py
@@ -19,7 +19,7 @@
 """ PyTorch Ovis model."""
 import math
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -49,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis import OvisProcessor
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import merge_multimodal_embeddings
@@ -201,25 +202,22 @@ class VisualTokenizer(torch.nn.Module):
         return tokens
 
 
-class OvisImagePatchInputs(TypedDict):
+class OvisImagePatchInputs(TensorSchema):
+    """
+    Dimensions:
+        - batch_patches: Batch size * number of patches
+        - patch_size: patch_size_x * patch_size_y * num_channels
+        - patch_indicators: Batch size * (number of patches + 1)
+        - patches_per_image: List of number of total patches for each image
+          in the batch.
+    """
     type: Literal["image_patches"]
-    flat_data: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
-    """
-
-    indicator_tokens: torch.Tensor
-    """
-    Shape:
-    `(batch_size * (num_patches + 1))`
-    """
-
-    patches_per_image: list[int]
-    """
-    List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `flat_data`.
-    """
+    flat_data: Annotated[torch.Tensor,
+                         TensorShape("batch_patches", "patch_size")]
+    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
+    patches_per_image: Annotated[list[int],
+                                 TensorShape("num_patches_per_image")]
+    # This is used to restore the first two dimensions of `flat_data`.
 
 
 class VisualEmbedding(torch.nn.Embedding):
@@ -458,9 +456,12 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of indicator_tokens. "
                                  f"Got type: {type(pixel_values)}")
 
+            flat_data = flatten_bn(pixel_values, concat=True)
+            if flat_data.ndim >= 3:
+                flat_data = flat_data.flatten(start_dim=1)
             return OvisImagePatchInputs(
                 type="image_patches",
-                flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
+                flat_data=flat_data,
                 patches_per_image=[
                     x.shape[0] for x in flatten_bn(pixel_values)
                 ],
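
For readers unfamiliar with the `TensorSchema` pattern used above, the sketch below illustrates the general idea behind shape-named `Annotated` fields: each tensor dimension gets a symbolic name, and a validator binds those names to concrete sizes and checks them for consistency. This is a minimal toy stand-in written for this note; `ToyShape`, `check_shapes`, and `ToyImagePatchInputs` are hypothetical helpers and do not reflect the actual implementation in `vllm.utils.tensor_schema`.

```python
from typing import Annotated, get_type_hints

import torch


class ToyShape:
    """Stand-in for TensorShape: records a symbolic name per tensor dim."""

    def __init__(self, *dims: str) -> None:
        self.dims = dims


def check_shapes(obj: object) -> None:
    """Bind named dims to concrete sizes and verify consistency across fields."""
    bound: dict[str, int] = {}
    hints = get_type_hints(type(obj), include_extras=True)
    for name, hint in hints.items():
        spec = next((m for m in getattr(hint, "__metadata__", ())
                     if isinstance(m, ToyShape)), None)
        value = getattr(obj, name, None)
        if spec is None or not isinstance(value, torch.Tensor):
            continue
        if value.ndim != len(spec.dims):
            raise ValueError(f"{name}: expected {len(spec.dims)} dims, "
                             f"got {value.ndim}")
        for dim_name, size in zip(spec.dims, value.shape):
            if bound.setdefault(dim_name, size) != size:
                raise ValueError(f"{name}: dim {dim_name!r} is {size}, "
                                 f"previously bound to {bound[dim_name]}")


class ToyImagePatchInputs:
    """Shape-annotated fields mirroring the new OvisImagePatchInputs layout."""
    flat_data: Annotated[torch.Tensor, ToyShape("batch_patches", "patch_size")]
    indicator_tokens: Annotated[torch.Tensor, ToyShape("patch_indicators")]

    def __init__(self, flat_data: torch.Tensor,
                 indicator_tokens: torch.Tensor) -> None:
        self.flat_data = flat_data
        self.indicator_tokens = indicator_tokens
        check_shapes(self)


# A (9, 588) flat patch tensor and 12 indicator tokens satisfy the checks;
# a 3-D flat_data tensor would be rejected, which is why the parsing code
# above flattens pixel_values to 2-D before building the inputs.
ToyImagePatchInputs(torch.zeros(9, 588), torch.zeros(12))
```

The practical gain over the previous `TypedDict` is that the declared shapes become machine-checkable when the multimodal inputs are parsed, instead of living only in docstrings that can drift from the code.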