diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index cfc6ffd99af6..708ca9899522 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union, cast) import torch @@ -33,6 +33,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -44,35 +45,46 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .vision import get_vision_encoder_info -class LlavaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class LlavaImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, height, width)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + Note that `height` or `width` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class PixtralHFImagePixelInputs(TypedDict): - type: Literal["pixel_values_pixtral"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class PixtralHFImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, height, width)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels + - h: Height + - w: Width + Note that `height` or `width` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral" + pixel_values: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "c", "h", "w")] -class LlavaImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. +class LlavaImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs, @@ -547,19 +559,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -579,10 +578,14 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): pixel_values=flatten_bn(pixel_values), ) + expected_h = expected_w = self.config.vision_config.image_size return LlavaImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=flatten_bn(pixel_values, concat=True), + resolve_bindings={ + "h": expected_h, + "w": expected_w + }, ) if image_embeds is not None: