diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 3e1c64bb62eab..36e57b5e4f46a 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -3,7 +3,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -46,6 +46,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import (
     cached_image_processor_from_config)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -102,77 +103,62 @@ def smart_resize(
     return h_bar, w_bar
 
 
-class KeyeImagePixelInputs(TypedDict):
+class KeyeImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - cps: Number of channels * patch_size * patch_size
+        - ni: Number of images
+        - g: Grid dimensions (3 for t, h, w)
+    """
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
+
+
+class KeyeImageEmbeddingInputs(TensorSchema):
     """
-
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size (must match the hidden size of language model
+          backbone)
+        - ni: Number of images
+        - g: Grid dimensions (3 for t, h, w)
     """
-
-
-class KeyeImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all images' features.
-        Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-        (concatenation of all images' feature tensors).
-
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the images.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
 
 
-class KeyeVideoPixelInputs(TypedDict):
+class KeyeVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - ctps: Number of channels * temporal_patch_size * patch_size *
+          patch_size
+        - nv: Number of videos
+        - g: Grid dimensions (3 for t, h, w)
+    """
     type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-     num_channels * temporal_patch_size * patch_size * patch_size)`
+    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]
+    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
+
+
+class KeyeVideoEmbeddingInputs(TensorSchema):
     """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-
-    This should be in `(grid_t, grid_h, grid_w)` format.
+    Dimensions:
+        - nf: Number of video features
+        - hs: Hidden size (must match the hidden size of language model
+          backbone)
+        - nv: Number of videos
+        - g: Grid dimensions (3 for t, h, w)
     """
-
-
-class KeyeVideoEmbeddingInputs(TypedDict):
     type: Literal["video_embeds"]
-    video_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all videos' features.
-        Each tensor holds an video's features.
-    - `torch.Tensor`: A tensor holding all videos' features
-        (concatenation of all videos' feature tensors).
-
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the videos.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
+    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
 
 
 KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs]
@@ -1420,10 +1406,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
         image_grid_thw = self._validate_and_reshape_mm_tensor(
             image_grid_thw, "image grid_thw")
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of image pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
         return KeyeImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values,
@@ -1436,9 +1418,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
         image_grid_thw = self._validate_and_reshape_mm_tensor(
             image_grid_thw, "image grid_thw")
 
-        if not isinstance(image_embeds, torch.Tensor):
-            raise ValueError("Incorrect type of image embeddings. "
-                             f"Got type: {type(image_embeds)}")
         return KeyeImageEmbeddingInputs(
             type="image_embeds",
             image_embeds=image_embeds,
@@ -1474,9 +1453,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
         video_grid_thw = self._validate_and_reshape_mm_tensor(
             video_grid_thw, "video grid_thw")
 
-        if not isinstance(video_embeds, torch.Tensor):
-            raise ValueError("Incorrect type of video embeddings. "
-                             f"Got type: {type(video_embeds)}")
         return KeyeVideoEmbeddingInputs(
            type="video_embeds",
            video_embeds=video_embeds,
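
Note for reviewers unfamiliar with TensorSchema: the manual isinstance checks above can be dropped because a TensorSchema subclass is expected to validate its Annotated[..., TensorShape(...)] fields at construction time. The following is a minimal sketch of that idea, assuming vllm and torch are importable and that construction raises on a shape mismatch; ExampleImagePixelInputs and the concrete tensor sizes are hypothetical, chosen only to mirror KeyeImagePixelInputs in the diff.

# Illustrative sketch only, not part of the patch. Assumes vllm's
# TensorSchema checks the Annotated[..., TensorShape(...)] fields when the
# object is built, which is why the isinstance checks became redundant.
from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - np: Number of patches
        - cps: Number of channels * patch_size * patch_size
        - ni: Number of images
    """
    type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


# Well-formed inputs: 64 patches of 3 * 14 * 14 values and one grid entry.
ok = ExampleImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(64, 3 * 14 * 14),
    image_grid_thw=torch.tensor([[1, 8, 8]]),
)
print(ok.pixel_values.shape)

# Malformed inputs: image_grid_thw has 4 columns instead of the declared 3,
# so construction is expected to fail during schema validation.
try:
    ExampleImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.zeros(64, 3 * 14 * 14),
        image_grid_thw=torch.tensor([[1, 8, 8, 0]]),
    )
except Exception as exc:  # exact exception type raised is not pinned down here
    print(f"rejected by TensorSchema: {exc}")

Centralizing the shape contract in the schema classes keeps the _parse_and_validate_*_input methods focused on reshaping the incoming multimodal kwargs rather than on ad-hoc type checking.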