Migrate AriaImagePixelInputs to TensorSchema for shape validation (#21620)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck 2025-07-26 06:08:15 -07:00 committed by GitHub
parent e98def439c
commit 9d197280fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, TypedDict, Union from typing import Annotated, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
# yapf: disable # yapf: disable
from .idefics2_vision_model import Idefics2VisionConfig from .idefics2_vision_model import Idefics2VisionConfig
@ -42,15 +43,26 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
merge_multimodal_embeddings) merge_multimodal_embeddings)
class AriaImagePixelInputs(TypedDict): class AriaImagePixelInputs(TensorSchema):
pixel_values: torch.Tensor
pixel_mask: Optional[torch.Tensor]
""" """
Shape: Dimensions:
pixel_values: `(batch_size * num_images, num_channels, height, width)` - b: Batch size
pixel_mask: `(batch_size * num_images, height, width)` - n: Number of images
- c: Number of channels
- h: Height of each image
- w: Width of each image
""" """
pixel_values: Annotated[
torch.Tensor,
TensorShape("bn", 3, "h", "w"),
]
pixel_mask: Annotated[
Optional[torch.Tensor],
TensorShape("bn", "h", "w"),
]
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
@ -540,12 +552,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale) self.vocab_size, logit_scale)
def _validate_image_sizes(
self, images: list[torch.Tensor]) -> list[torch.Tensor]:
if not all(img.shape == images[0].shape for img in images):
raise ValueError("All images must be the same size")
return images
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[AriaImagePixelInputs]: self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
@ -554,23 +560,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
if pixel_values is None: if pixel_values is None:
return None return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = self._validate_image_sizes(pixel_values)
pixel_values = flatten_bn(pixel_values, concat=True)
if pixel_mask is not None:
if not isinstance(pixel_mask, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel mask. "
f"Got type: {type(pixel_mask)}")
pixel_mask = flatten_bn(pixel_mask, concat=True)
return AriaImagePixelInputs( return AriaImagePixelInputs(
pixel_values=pixel_values, pixel_values=flatten_bn(pixel_values, concat=True),
pixel_mask=pixel_mask, pixel_mask=flatten_bn(pixel_mask, concat=True),
) )
def _create_patch_attention_mask( def _create_patch_attention_mask(