Migrate tarsier inputs to TensorSchema (#23500)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck <benjibeck@meta.com>
Date: 2025-08-24 21:42:36 -07:00 (committed by GitHub)
commit 99f8094400 (parent 170e8ea9ea)

vllm/model_executor/models/tarsier.py

@@ -3,7 +3,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
-                    Union, cast)
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
+                    Union, cast)
 
 import torch
@@ -34,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.jsontree import json_map_leaves
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -43,14 +44,28 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
 from .vision import VisionEncoderInfo, get_vision_encoder_info
 
 
-class TarsierImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
+class TarsierImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
-class TarsierImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
+class TarsierImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match the hidden size of language model
+          backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
 
 
 TarsierImageInputs = Union[TarsierImagePixelInputs,
@@ -432,18 +447,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)  # Assuming 3 channels
-        actual_dims = tuple(data.shape[1:])
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[TarsierImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -459,8 +462,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
             return TarsierImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
+                pixel_values=flatten_bn(pixel_values, concat=True),
             )
 
         if image_embeds is not None:
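
Note: the change replaces the model's hand-rolled shape check with a declarative schema. Below is a minimal, self-contained sketch of how such a schema is declared and exercised. The class name and tensor sizes are illustrative, and it assumes (as dropping _validate_pixel_values implies) that TensorSchema validates the TensorShape annotations when an instance is constructed; the exact exception type raised on a mismatch is also an assumption here.

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExamplePixelInputs(TensorSchema):  # hypothetical name, for illustration
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    # Symbolic dims ("bn", "h", "w") are resolved per instance; the
    # literal 3 pins the channel count, as in the Tarsier schema above.
    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


# A (2 images, 3 channels, 336 x 336) batch matches the declared shape.
ok = ExamplePixelInputs(type="pixel_values",
                        pixel_values=torch.zeros(2, 3, 336, 336))

# A wrong channel count should now be rejected by the schema itself,
# doing the job of the manual check this commit deletes.
try:
    ExamplePixelInputs(type="pixel_values",
                       pixel_values=torch.zeros(2, 4, 336, 336))
except Exception as exc:  # exception type assumed, see note above
    print(f"schema rejected bad shape: {exc}")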