Migrate KeyeImageInputs and KeyeVideoInputs to TensorSchema (#21686)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Authored by Benji Beck on 2025-07-28 01:16:35 -07:00, committed by GitHub
parent a6c050286a
commit d128d0d554

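Context for the diff below: the commit swaps the hand-written TypedDict input definitions for vLLM's TensorSchema helper, declaring each tensor field's expected shape with an Annotated[..., TensorShape(...)] hint whose symbolic dimension names are spelled out in the class docstring. A minimal sketch of that pattern, assuming TensorSchema and TensorShape behave the way the annotations in this diff suggest (the DemoInputs class and its fields are hypothetical; only the import path comes from the diff):

    from typing import Annotated

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class DemoInputs(TensorSchema):
        """
        Dimensions:
            - n: Number of rows
            - d: Feature size
        """
        # A symbolic name ("n", "d") names an expected dimension and fixes the
        # tensor's rank; a literal int (the 3 below) pins a dimension to a
        # fixed size.
        features: Annotated[torch.Tensor, TensorShape("n", "d")]
        grid_thw: Annotated[torch.Tensor, TensorShape("m", 3)]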

@@ -3,7 +3,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -46,6 +46,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import (
     cached_image_processor_from_config)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -102,77 +103,62 @@ def smart_resize(
     return h_bar, w_bar
 
 
-class KeyeImagePixelInputs(TypedDict):
+class KeyeImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - cps: Number of channels * patch_size * patch_size
+        - ni: Number of images
+        - g: Grid dimensions (3 for t, h, w)
+    """
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
-    """
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
-class KeyeImageEmbeddingInputs(TypedDict):
+class KeyeImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size (must match the hidden size of language model
+          backbone)
+        - ni: Number of images
+        - g: Grid dimensions (3 for t, h, w)
+    """
     type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all images' features.
-        Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-        (concatenation of all images' feature tensors).
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the images.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
 
 
-class KeyeVideoPixelInputs(TypedDict):
+class KeyeVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - ctps: Number of channels * temporal_patch_size * patch_size *
+          patch_size
+        - nv: Number of videos
+        - g: Grid dimensions (3 for t, h, w)
+    """
     type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-    num_channels * temporal_patch_size * patch_size * patch_size)`
-    """
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]
+    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
 
 
-class KeyeVideoEmbeddingInputs(TypedDict):
+class KeyeVideoEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of video features
+        - hs: Hidden size (must match the hidden size of language model
+          backbone)
+        - nv: Number of videos
+        - g: Grid dimensions (3 for t, h, w)
+    """
     type: Literal["video_embeds"]
-    video_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all videos' features.
-        Each tensor holds an video's features.
-    - `torch.Tensor`: A tensor holding all videos' features
-        (concatenation of all videos' feature tensors).
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the videos.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
+    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
 
 
 KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs]
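Under the new definitions above, the shape contract is expected to be checked when one of these schema classes is instantiated, rather than described in a free-form docstring. A sketch of a valid construction, with made-up sizes (64 patches, 3 * 14 * 14 values per patch, a single 1 x 8 x 8 grid) chosen only so the dimensions line up; the exact validation behaviour of TensorSchema is assumed, not shown in this diff:

    import torch

    # Hypothetical sizes, for illustration only.
    num_patches = 64       # "np"
    cps = 3 * 14 * 14      # "cps": channels * patch_size * patch_size

    inputs = KeyeImagePixelInputs(
        type="pixel_values",
        # Must be 2-D to satisfy TensorShape("np", "cps").
        pixel_values=torch.randn(num_patches, cps),
        # One (t, h, w) row per image to satisfy TensorShape("ni", 3);
        # 1 * 8 * 8 = 64 patches, matching pixel_values above.
        image_grid_thw=torch.tensor([[1, 8, 8]]),
    )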
@@ -1420,10 +1406,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
             image_grid_thw = self._validate_and_reshape_mm_tensor(
                 image_grid_thw, "image grid_thw")
 
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return KeyeImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1436,9 +1418,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
             image_grid_thw = self._validate_and_reshape_mm_tensor(
                 image_grid_thw, "image grid_thw")
 
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return KeyeImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1474,9 +1453,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
             video_grid_thw = self._validate_and_reshape_mm_tensor(
                 video_grid_thw, "video grid_thw")
 
-            if not isinstance(video_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of video embeddings. "
-                                 f"Got type: {type(video_embeds)}")
             return KeyeVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
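The three hunks above drop the manual isinstance guards that used to reject non-tensor pixel values and embeddings. With the fields now annotated via TensorShape, the apparent intent is that TensorSchema performs those type and shape checks itself when the inputs object is constructed, so the call sites keep only the plain constructor call. A hedged sketch of the resulting call-site shape (parse_image_embeds is a made-up stand-in for the surrounding _parse_and_validate_image_input logic):

    import torch


    def parse_image_embeds(image_embeds: torch.Tensor,
                           image_grid_thw: torch.Tensor) -> KeyeImageEmbeddingInputs:
        # No explicit isinstance check any more; a wrongly typed or wrongly
        # shaped value is assumed to be rejected inside the TensorSchema
        # constructor instead.
        return KeyeImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=image_embeds,      # expected shape ("nf", "hs")
            image_grid_thw=image_grid_thw,  # expected shape ("ni", 3)
        )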