diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py index b1f6a0af6b3de..c00db52371b68 100644 --- a/vllm/model_executor/models/donut.py +++ b/vllm/model_executor/models/donut.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo, PromptIndexTargets, PromptInsertion, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.utils.tensor_schema import TensorSchema, TensorShape class MBartDecoderWrapper(nn.Module): @@ -132,10 +133,16 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): return loaded_params -class DonutImagePixelInputs(TypedDict): +class DonutImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - c: Number of channels (3) + - h: Height + - w: Width + """ type: Literal["pixel_values"] - data: torch.Tensor - """Shape: (batch_size, num_channel, height, width)""" + data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")] class DonutProcessingInfo(BaseProcessingInfo): @@ -275,27 +282,6 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, ) self.pad_token_id = config.pad_token_id - def _validate_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - # size = self.processor_config["size"] - h, w = self.config.encoder.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - raise ValueError( - "The expected shape of pixel values per batch " f"is {expected_dims}. 
You supplied {actual_dims}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input(self, **kwargs: object): pixel_values: Optional[Union[list[list[torch.Tensor]], list[torch.Tensor], @@ -314,11 +300,14 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, "Both pixel values and image embeds are provided.") if pixel_values is not None: - return DonutImagePixelInputs( - type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), - ) + h, w = self.config.encoder.image_size + return DonutImagePixelInputs(type="pixel_values", + data=flatten_bn(pixel_values, + concat=True), + resolve_bindings={ + "h": h, + "w": w, + }) if image_embeds is not None: raise NotImplementedError