diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 04fb6b5736a5..a63c18493df5 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -3,7 +3,7 @@
 
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
                     Union)
 
 import torch
@@ -11,7 +11,6 @@ import torch.nn as nn
 from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
-from typing_extensions import NotRequired
 
 from vllm.config import VllmConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -19,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.multimodal.parse import ImageSize
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -30,32 +30,36 @@ from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
                     flatten_bn, init_vllm_registered_model, maybe_prefix)
 
 
-class LlavaNextImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class LlavaNextImagePixelInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - np: Number of patches + 1
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+
     Note that `num_patches` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]
 
-    image_sizes: NotRequired[torch.Tensor]
+    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
+    # This should be in `(height, width)` format.
+
+
+class LlavaNextImageEmbeddingInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, 2)`
-
-    This should be in `(height, width)` format.
-    """
-
-
-class LlavaNextImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match language model backbone)
     """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
 
 
 LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
@@ -269,44 +273,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        expected_dims = (2, )
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    f"The expected shape of image sizes per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[1:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -325,13 +291,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")
 
+            expected_h = expected_w = self.config.vision_config.image_size
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values)),
-                image_sizes=self._validate_image_sizes(
-                    flatten_bn(image_sizes, concat=True)),
-            )
+                pixel_values=flatten_bn(pixel_values),
+                image_sizes=flatten_bn(image_sizes, concat=True),
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w,
+                })
 
         if image_embeds is not None:
             if not isinstance(image_embeds, torch.Tensor):
diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py
index 343df71e1058..4c3acf0094c7 100644
--- a/vllm/utils/tensor_schema.py
+++ b/vllm/utils/tensor_schema.py
@@ -60,6 +60,9 @@ class TensorSchema:
     def __getitem__(self, item) -> Any:
         return getattr(self, item)
 
+    def get(self, item, default=None) -> Any:
+        return getattr(self, item, default)
+
     def _match_shape_with_dynamic(self, actual: tuple[int, ...],
                                   reference: tuple[int, ...],
                                   expected_shape: tuple[Union[int, str], ...],
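
The deleted _validate_* helpers are subsumed by the TensorShape annotations: assuming TensorSchema checks the declared shapes when an instance is constructed and resolve_bindings binds the symbolic "h"/"w" dims to vision_config.image_size, a rough usage sketch of the new input class looks like this (values such as 336 are illustrative, not taken from the patch):

    # Sketch only: assumes TensorSchema validates the annotated TensorShapes
    # at construction time, which is what lets the hand-rolled
    # _validate_pixel_values/_validate_image_sizes helpers above be deleted.
    import torch

    from vllm.model_executor.models.llava_next import LlavaNextImagePixelInputs

    # Illustrative value; in the model code this comes from
    # config.vision_config.image_size.
    h = w = 336

    inputs = LlavaNextImagePixelInputs(
        type="pixel_values",
        # "np" is declared dynamic, so each image may carry a different
        # number of patches and the data can be a list rather than one
        # batched tensor.
        pixel_values=[
            torch.rand(5, 3, h, w),  # image 0: 1 + 4 patches
            torch.rand(3, 3, h, w),  # image 1: 1 + 2 patches
        ],
        image_sizes=torch.tensor([[672, 672], [336, 672]]),  # (height, width)
        resolve_bindings={"h": h, "w": w},
    )

    # Fields stay dict-style accessible via __getitem__ / the new get(),
    # so call sites that treated the old TypedDict like a mapping keep working.
    assert inputs["type"] == "pixel_values"
    assert inputs.get("image_sizes") is not None

A tensor whose channel dim is not 3, or whose height/width do not match the bound "h"/"w", would then be rejected when the schema object is built rather than by the per-model validator methods.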