Migrate DonutImagePixelInputs to TensorSchema (#23509)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck, 2025-08-24 22:02:15 -07:00 (committed by GitHub)
Commit: 787cdb3829, parent a5203d04df


@@ -3,7 +3,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 import torch
 import torch.nn as nn
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                          PromptIndexTargets, PromptInsertion,
                                          PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 class MBartDecoderWrapper(nn.Module):
@@ -132,10 +133,16 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
         return loaded_params
-class DonutImagePixelInputs(TypedDict):
+class DonutImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+    """
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, num_channel, height, width)"""
+    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
 class DonutProcessingInfo(BaseProcessingInfo):
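
The hunk above moves the expected layout out of a prose docstring and into the Annotated metadata on the field. Below is a self-contained sketch of that pattern using stand-in classes (not the actual vllm.utils.tensor_schema implementation): once the shape spec lives in the annotation, a generic checker can enforce it instead of a hand-written per-model validator.

# Stand-in sketch: the shape spec is carried by Annotated metadata and can be
# read back with typing.get_type_hints. TensorShape/PixelSchema/validate here
# are illustrative, not vLLM's API.
from typing import Annotated, get_type_hints

import torch


class TensorShape:
    """Stand-in: records symbolic (str) and fixed (int) dimension sizes."""

    def __init__(self, *dims):
        self.dims = dims


class PixelSchema:
    """Stand-in mirroring the annotated `data` field of DonutImagePixelInputs."""

    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]


def validate(tensor: torch.Tensor, bindings: dict[str, int]) -> None:
    """Check `tensor` against the TensorShape annotation on PixelSchema.data."""
    (spec,) = get_type_hints(PixelSchema, include_extras=True)["data"].__metadata__
    expected = tuple(
        bindings.get(d, -1) if isinstance(d, str) else d for d in spec.dims)
    if len(tensor.shape) != len(expected) or any(
            e != -1 and a != e for a, e in zip(tensor.shape, expected)):
        raise ValueError(
            f"expected {expected} (-1 means unbound), got {tuple(tensor.shape)}")


validate(torch.zeros(2, 3, 32, 32), {"h": 32, "w": 32})    # passes
# validate(torch.zeros(2, 4, 32, 32), {"h": 32, "w": 32})  # would raise ValueError
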
@@ -275,27 +282,6 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
         )
         self.pad_token_id = config.pad_token_id
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        # size = self.processor_config["size"]
-        h, w = self.config.encoder.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                raise ValueError(
-                    "The expected shape of pixel values per batch "
-                    f"is {expected_dims}. You supplied {actual_dims}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(self, **kwargs: object):
         pixel_values: Optional[Union[list[list[torch.Tensor]],
                                      list[torch.Tensor],
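
The deleted method compared each slice of the flattened pixel tensor against (3, h, w) taken from the encoder config; the leading batch dimension is what the symbolic "b" in TensorShape("b", 3, "h", "w") now describes. A small sketch of that layout, assuming flatten_bn is the helper from vllm.model_executor.models.utils and that it concatenates per-prompt tensors along the batch dimension when concat=True (sizes are illustrative):

# Illustrative only: shows the (b, 3, h, w) layout the schema annotation
# describes. The module path and concat semantics of flatten_bn are assumptions.
import torch

from vllm.model_executor.models.utils import flatten_bn

# Two prompts, one 3-channel image each (height/width values are made up).
per_prompt = [torch.zeros(1, 3, 224, 224), torch.zeros(1, 3, 224, 224)]

batched = flatten_bn(per_prompt, concat=True)
assert batched.shape == (2, 3, 224, 224)  # leading dim is the symbolic "b"
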
@@ -314,11 +300,14 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
                 "Both pixel values and image embeds are provided.")
         if pixel_values is not None:
-            return DonutImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
-            )
+            h, w = self.config.encoder.image_size
+            return DonutImagePixelInputs(type="pixel_values",
+                                         data=flatten_bn(pixel_values,
+                                                         concat=True),
+                                         resolve_bindings={
+                                             "h": h,
+                                             "w": w,
+                                         })
         if image_embeds is not None:
             raise NotImplementedError
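
A hedged usage sketch of the new construction path, assuming DonutImagePixelInputs is importable from vllm.model_executor.models.donut and that TensorSchema checks `data` against the annotated TensorShape("b", 3, "h", "w") at construction time, with "h" and "w" resolved through resolve_bindings (hard-coded here; the model takes them from config.encoder.image_size):

# Sketch, not a test from this PR: constructing the migrated schema directly.
import torch

from vllm.model_executor.models.donut import DonutImagePixelInputs  # assumed path

h, w = 224, 224  # illustrative stand-ins for config.encoder.image_size

inputs = DonutImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, h, w),        # matches the annotated (b, 3, h, w)
    resolve_bindings={"h": h, "w": w},
)
# A wrong channel count or height/width in `data` should now be rejected by the
# schema itself, replacing the ValueError from the deleted _validate_pixel_values.
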