diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index 8ae1680a71f3..e1368a3f6478 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Optional, TypedDict, Union
+from typing import Annotated, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                          PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 # yapf: disable
 from .idefics2_vision_model import Idefics2VisionConfig
@@ -42,15 +43,26 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     merge_multimodal_embeddings)
 
 
-class AriaImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    pixel_mask: Optional[torch.Tensor]
+class AriaImagePixelInputs(TensorSchema):
     """
-    Shape:
-        pixel_values: `(batch_size * num_images, num_channels, height, width)`
-        pixel_mask: `(batch_size * num_images, height, width)`
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - c: Number of channels
+        - h: Height of each image
+        - w: Width of each image
     """
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 3, "h", "w"),
+    ]
+
+    pixel_mask: Annotated[
+        Optional[torch.Tensor],
+        TensorShape("bn", "h", "w"),
+    ]
+
 
 
 class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
@@ -540,12 +552,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                  self.vocab_size, logit_scale)
 
-    def _validate_image_sizes(
-            self, images: list[torch.Tensor]) -> list[torch.Tensor]:
-        if not all(img.shape == images[0].shape for img in images):
-            raise ValueError("All images must be the same size")
-        return images
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -554,23 +560,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = self._validate_image_sizes(pixel_values)
-        pixel_values = flatten_bn(pixel_values, concat=True)
-
-        if pixel_mask is not None:
-            if not isinstance(pixel_mask, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel mask. "
-                                 f"Got type: {type(pixel_mask)}")
-
-            pixel_mask = flatten_bn(pixel_mask, concat=True)
-
         return AriaImagePixelInputs(
-            pixel_values=pixel_values,
-            pixel_mask=pixel_mask,
+            pixel_values=flatten_bn(pixel_values, concat=True),
+            pixel_mask=flatten_bn(pixel_mask, concat=True),
         )
 
     def _create_patch_attention_mask(
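
Why the hand-rolled checks could be dropped: TensorSchema fields annotated with TensorShape are validated when the schema object is constructed, so a wrong channel count or inconsistent "bn"/"h"/"w" dimensions across pixel_values and pixel_mask are rejected by the schema rather than by model-specific code. A minimal usage sketch under that assumption (the 490x490 image size and the exact exception type are illustrative, not taken from this patch):

    import torch
    from vllm.model_executor.models.aria import AriaImagePixelInputs

    # Two images flattened along the shared "bn" (batch * num_images) dimension;
    # "h"/"w" must agree between pixel_values and pixel_mask.
    ok = AriaImagePixelInputs(
        pixel_values=torch.zeros(2, 3, 490, 490),
        pixel_mask=torch.ones(2, 490, 490, dtype=torch.bool),
    )

    # A channel count other than the literal 3 in TensorShape("bn", 3, "h", "w")
    # is expected to fail shape validation at construction time.
    try:
        AriaImagePixelInputs(
            pixel_values=torch.zeros(2, 4, 490, 490),
            pixel_mask=torch.ones(2, 490, 490, dtype=torch.bool),
        )
    except Exception as exc:
        print(f"rejected: {exc}")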