diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 74b18df7214b..8d705f40ce8f 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -3,7 +3,7 @@
 
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -38,6 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
@@ -48,10 +49,16 @@ from .utils import (flatten_bn, is_pp_missing_parameter,
 logger = init_logger(__name__)
 
 
-class ChameleonImagePixelInputs(TypedDict):
+class ChameleonImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each image
+        - w: Width of each image
+    """
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
 class ChameleonProcessingInfo(BaseProcessingInfo):
@@ -962,19 +969,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        vq_config: ChameleonVQVAEConfig = self.config.vq_config
-        expected_dims = (3, vq_config.resolution, vq_config.resolution)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -982,16 +976,16 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        vq_config: ChameleonVQVAEConfig = self.config.vq_config
+        expected_h = expected_w = vq_config.resolution
 
-        pixel_values = flatten_bn(pixel_values, concat=True)
-
-        return ChameleonImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+        return ChameleonImagePixelInputs(type="pixel_values",
+                                         data=flatten_bn(pixel_values,
+                                                         concat=True),
+                                         resolve_bindings={
+                                             "h": expected_h,
+                                             "w": expected_w
+                                         })
 
     def get_language_model(self) -> torch.nn.Module:
         return self.model
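
Note: for readers unfamiliar with vllm.utils.tensor_schema, the sketch below illustrates the mechanism this diff leans on. It is a simplified, hypothetical stand-in, not the actual vLLM implementation: a TensorShape annotation declares each dimension as either a fixed integer or a symbolic name, and resolve_bindings pins symbolic names (here "h" and "w") to concrete values at construction time.

# Minimal sketch of the TensorSchema idea; NOT the real
# vllm.utils.tensor_schema, just an illustration of how a schema
# subclass could validate tensors against a declared TensorShape.
from typing import Annotated, Literal, Optional, Union, get_type_hints

import torch


class TensorShape:
    """Expected shape: a mix of fixed ints and symbolic dimension names."""

    def __init__(self, *dims: Union[str, int]) -> None:
        self.dims = dims


class TensorSchema:

    def __init__(self,
                 *,
                 resolve_bindings: Optional[dict] = None,
                 **fields) -> None:
        bindings = dict(resolve_bindings or {})
        for name, value in fields.items():
            setattr(self, name, value)
        self._validate(bindings)

    def _validate(self, bindings: dict) -> None:
        hints = get_type_hints(type(self), include_extras=True)
        for name, hint in hints.items():
            metadata = getattr(hint, "__metadata__", ())
            shapes = [m for m in metadata if isinstance(m, TensorShape)]
            if not shapes:
                continue  # e.g. the `type` field carries no shape info
            tensor, dims = getattr(self, name), shapes[0].dims
            if tensor.ndim != len(dims):
                raise ValueError(f"{name}: expected {len(dims)} dims, "
                                 f"got {tensor.ndim}")
            for dim, actual in zip(dims, tensor.shape):
                expected = bindings.get(dim) if isinstance(dim, str) else dim
                if expected is None:
                    bindings[dim] = actual  # unbound name: infer from data
                elif actual != expected:
                    raise ValueError(
                        f"{name}: dim {dim!r} should be {expected}, got "
                        f"{actual} (full shape: {tuple(tensor.shape)})")


class ChameleonImagePixelInputs(TensorSchema):
    type: Literal["pixel_values"]
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


# Mirrors the call in the diff: "h"/"w" are pinned to the VQ-VAE
# resolution (512 here is an illustrative value), while "bn" is
# inferred from the data itself.
inputs = ChameleonImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, 512, 512),
    resolve_bindings={"h": 512, "w": 512},
)

Under this model, constructing the schema object is what performs the check, which is why the diff can delete both the hand-written _validate_pixel_values method and the isinstance guard: the expected shape now lives declaratively next to the field it constrains.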