Migrate AriaImagePixelInputs to TensorSchema for shape validation (#21620)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck <benjibeck@meta.com> (committed by GitHub)
Date: 2025-07-26 06:08:15 -07:00
Parent: e98def439c
Commit: 9d197280fa

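Context for reviewers (mine, not part of the commit): TensorSchema lets a multimodal-inputs class declare its expected tensor shapes with typing.Annotated and TensorShape, replacing hand-written isinstance and shape checks. Below is a minimal sketch of the pattern, under the assumptions that TensorSchema validates annotated fields when an instance is constructed and that string dimension names ("bn", "h", "w") must resolve to the same sizes across fields; the exact error raised on mismatch is also an assumption.

# Hedged sketch (assumptions: TensorSchema validates at __init__ time and
# symbolic dims must agree across fields; the error type is not specified here).
from typing import Annotated, Optional

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleImageInputs(TensorSchema):
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
    pixel_mask: Annotated[Optional[torch.Tensor], TensorShape("bn", "h", "w")]


# Passes: "bn"=2 and "h"="w"=490 resolve consistently across both fields.
ExampleImageInputs(
    pixel_values=torch.zeros(2, 3, 490, 490),
    pixel_mask=torch.zeros(2, 490, 490),
)

# Should fail validation: the mask's spatial dims (448) contradict
# the 490 resolved from pixel_values.
ExampleImageInputs(
    pixel_values=torch.zeros(2, 3, 490, 490),
    pixel_mask=torch.zeros(2, 448, 448),
)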
vllm/model_executor/models/aria.py

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Optional, TypedDict, Union
+from typing import Annotated, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 # yapf: disable
 from .idefics2_vision_model import Idefics2VisionConfig
@@ -42,15 +43,26 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     merge_multimodal_embeddings)
 
 
-class AriaImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    pixel_mask: Optional[torch.Tensor]
+class AriaImagePixelInputs(TensorSchema):
     """
-    Shape:
-        pixel_values: `(batch_size * num_images, num_channels, height, width)`
-        pixel_mask: `(batch_size * num_images, height, width)`
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - c: Number of channels
+        - h: Height of each image
+        - w: Width of each image
     """
+
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 3, "h", "w"),
+    ]
+
+    pixel_mask: Annotated[
+        Optional[torch.Tensor],
+        TensorShape("bn", "h", "w"),
+    ]
 
 
 class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
@@ -540,12 +552,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 self.vocab_size, logit_scale)
 
-    def _validate_image_sizes(
-            self, images: list[torch.Tensor]) -> list[torch.Tensor]:
-        if not all(img.shape == images[0].shape for img in images):
-            raise ValueError("All images must be the same size")
-        return images
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -554,23 +560,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = self._validate_image_sizes(pixel_values)
-        pixel_values = flatten_bn(pixel_values, concat=True)
-
-        if pixel_mask is not None:
-            if not isinstance(pixel_mask, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel mask. "
-                                 f"Got type: {type(pixel_mask)}")
-
-            pixel_mask = flatten_bn(pixel_mask, concat=True)
-
         return AriaImagePixelInputs(
-            pixel_values=pixel_values,
-            pixel_mask=pixel_mask,
+            pixel_values=flatten_bn(pixel_values, concat=True),
+            pixel_mask=flatten_bn(pixel_mask, concat=True),
         )
 
     def _create_patch_attention_mask(
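Why the deletions above hold together (my reading, not stated in the commit): flatten_bn(..., concat=True) normalizes both layouts the parser can receive, a batched tensor or a list of per-prompt tensors, into one (bn, ...) tensor, and the AriaImagePixelInputs constructor then enforces the shape and cross-field consistency that the removed isinstance checks and _validate_image_sizes used to cover. A rough plain-torch illustration, under my assumptions about flatten_bn's semantics:

# Assumed semantics of flatten_bn(x, concat=True), illustrated with torch:
import torch

batched = torch.zeros(2, 4, 3, 490, 490)   # (b, n, c, h, w)
flat = batched.flatten(0, 1)               # -> (8, 3, 490, 490), i.e. (bn, c, h, w)

ragged = [torch.zeros(2, 3, 490, 490),     # prompt 1: 2 images
          torch.zeros(3, 3, 490, 490)]     # prompt 2: 3 images
flat = torch.cat(ragged, dim=0)            # -> (5, 3, 490, 490)

With shape validation moved into the schema itself, the parser shrinks to normalization plus construction.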