From 3ea57a56d9113ffb81673918931a69058fce4ae1 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sun, 27 Jul 2025 22:37:23 -0700 Subject: [PATCH] =?UTF-8?q?Migrate=20Idefics3ImagePixelInputs=20and=20Idef?= =?UTF-8?q?ics3ImageEmbeddingInputs=20to=20=E2=80=A6=20(#21683)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Benji Beck --- vllm/model_executor/models/idefics3.py | 73 +++++++++++--------------- 1 file changed, 30 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index de216a81e9344..6e991d99b9638 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -18,7 +18,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -45,6 +45,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape # yapf: disable from .idefics2_vision_model import ( @@ -56,26 +57,30 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, merge_multimodal_embeddings) -class Idefics3ImagePixelInputs(TypedDict): +class Idefics3ImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - bnp: Batch size * number of images * number of patches + - c: Number of channels (3) + - h: Height + - w: Width + """ type: Literal["pixel_values"] - pixel_values: torch.Tensor - """ - Shape: `(batch_size * num_images * num_patches, - num_channels, height, width)` - """ + pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] pixel_attention_mask: torch.Tensor - - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" + num_patches: Annotated[torch.Tensor, TensorShape("bn")] -class Idefics3ImageEmbeddingInputs(TypedDict): +class Idefics3ImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - f: Image feature size + - h: Hidden size (must match the hidden size of language model backbone) + """ type: Literal["image_embeds"] - data: torch.Tensor - """ - Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. - """ + data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")] ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] @@ -614,25 +619,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, self.lm_head.weight = self.model.text_model.wte.weight self.logits_processor = LogitsProcessor(config.text_config.vocab_size) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. " - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -666,16 +652,17 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError("Incorrect type of num_patches. " f"Got type: {type(num_patches)}") - pixel_values = flatten_bn(pixel_values, concat=True) - pixel_attention_mask = flatten_bn(pixel_attention_mask, - concat=True) - num_patches = flatten_bn(num_patches, concat=True) - + expected_h = expected_w = self.config.vision_config.image_size return Idefics3ImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values(pixel_values), - pixel_attention_mask=pixel_attention_mask, - num_patches=num_patches, + pixel_values=flatten_bn(pixel_values, concat=True), + pixel_attention_mask=flatten_bn(pixel_attention_mask, + concat=True), + num_patches=flatten_bn(num_patches, concat=True), + resolve_bindings={ + "h": expected_h, + "w": expected_w + }, ) raise AssertionError("This line should be unreachable.")