diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 27a9208107871..8e3505f872eb2 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -22,6 +22,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptInsertion, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .blip import BlipVisionModel
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
@@ -34,19 +35,27 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
 _IMAGE_TOKEN_ID = 50265
 
 
-class Blip2ImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
-
-
-class Blip2ImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
+class Blip2ImagePixelInputs(TensorSchema):
     """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each image
+        - w: Width of each image
+    """
+    type: Literal["pixel_values"]
+    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
+
+
+class Blip2ImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
+    type: Literal["image_embeds"]
+    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
 
 
 Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
@@ -551,21 +560,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
-    def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
+    def _create_image_input(self,
+                            **kwargs: object) -> Optional[Blip2ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
@@ -573,27 +569,19 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             return None
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
-            pixel_values = flatten_bn(pixel_values, concat=True)
-
-            return Blip2ImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-            )
+            expected_h = expected_w = self.config.vision_config.image_size
+            return Blip2ImagePixelInputs(type="pixel_values",
+                                         data=flatten_bn(pixel_values,
+                                                         concat=True),
+                                         resolve_bindings={
+                                             "h": expected_h,
+                                             "w": expected_w
+                                         })
 
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
-            image_embeds = flatten_bn(image_embeds, concat=True)
-
             return Blip2ImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds, concat=True),
             )
 
         raise AssertionError("This line should be unreachable.")