From de10ff0b7cc757af3d0374d82c1a2130196af496 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Sat, 26 Jul 2025 06:08:18 -0700
Subject: [PATCH] Migrate AyaVisionImagePixelInputs to TensorSchema for shape validation (#21622)

Signed-off-by: Benji Beck
---
 vllm/model_executor/models/aya_vision.py | 67 ++++++++++--------------
 1 file changed, 29 insertions(+), 38 deletions(-)

diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
index 45dd660c89375..a3eee9f065aea 100644
--- a/vllm/model_executor/models/aya_vision.py
+++ b/vllm/model_executor/models/aya_vision.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Literal, Optional, Union, cast
 
 import torch
 from torch import nn
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
@@ -37,18 +38,28 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     merge_multimodal_embeddings)
 
 
-class AyaVisionImagePixelInputs(TypedDict):
+class AyaVisionImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - c: Number of channels
+        - h: Height of each image patch
+        - w: Width of each image patch
+        - bn: Batch size * number of images
+    """
+
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """
-    Shape: `(num_patches_total, num_channels, height, width)`
-
-    `num_patches_total` is the total number of patches over each image over each
-    prompt in the batch.
-    """
+
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("np", 3, "h", "w"),
+    ]
 
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+    num_patches: Annotated[
+        torch.Tensor,
+        TensorShape("bn"),
+    ]
 
 
 class AyaVisionMultiModalProjector(nn.Module):
@@ -383,21 +394,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
         ]
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            if d.shape != expected_dims:
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_dims}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -405,22 +401,17 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Aya Vision does not support image_embeds."
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
" - f"Got type: {type(pixel_values)}") - if num_patches is not None and not isinstance(num_patches, - (torch.Tensor, list)): - raise ValueError("Incorrect type of num_patches. " - f"Got type: {type(num_patches)}") - - pixel_values = flatten_bn(pixel_values, concat=True) - num_patches = flatten_bn(num_patches, concat=True) + if pixel_values is None: + return None return AyaVisionImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values(pixel_values), - num_patches=num_patches, - ) + pixel_values=flatten_bn(pixel_values, concat=True), + num_patches=flatten_bn(num_patches, concat=True), + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size, + }) def get_language_model(self) -> torch.nn.Module: return self.language_model