Migrate OvisImagePatchInputs to TensorSchema (#22024)

Signed-off-by: Benji Beck <benjibeck@meta.com>
commit 1fa1d6a9a0 (parent d59c986444)
Author: Benji Beck <benjibeck@meta.com>
Date:   2025-09-01 21:01:36 -07:00 (committed by GitHub)

@@ -19,7 +19,7 @@
 """ PyTorch Ovis model."""
 import math
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -49,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis import OvisProcessor
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import merge_multimodal_embeddings
@@ -201,25 +202,22 @@ class VisualTokenizer(torch.nn.Module):
         return tokens
 
 
-class OvisImagePatchInputs(TypedDict):
+class OvisImagePatchInputs(TensorSchema):
+    """
+    Dimensions:
+        - batch_patches: Batch size * number of patches
+        - patch_size: patch_size_x * patch_size_y * num_channels
+        - patch_indicators: Batch size * (number of patches + 1)
+        - patches_per_image: List of number of total patches for each image
+          in the batch.
+    """
     type: Literal["image_patches"]
-    flat_data: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
-    """
-
-    indicator_tokens: torch.Tensor
-    """
-    Shape:
-    `(batch_size * (num_patches + 1))`
-    """
-
-    patches_per_image: list[int]
-    """
-    List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `flat_data`.
-    """
+    flat_data: Annotated[torch.Tensor,
+                         TensorShape("batch_patches", "patch_size")]
+    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
+    patches_per_image: Annotated[list[int],
+                                 TensorShape("num_patches_per_image")]
+    # This is used to restore the first two dimensions of `flat_data`.
 
 
 class VisualEmbedding(torch.nn.Embedding):
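
Aside: the practical effect of this migration is that shape errors surface when the inputs object is built rather than deep inside the forward pass. Below is a minimal sketch of exercising the new class, assuming TensorSchema checks each Annotated field against its TensorShape at construction and raises on mismatch; the import path follows the file edited above, and all sizes are hypothetical:

import torch

from vllm.model_executor.models.ovis import OvisImagePatchInputs

# Two images with 3 and 5 patches; each patch flattened to
# patch_size_x * patch_size_y * num_channels = 14 * 14 * 3 = 588.
inputs = OvisImagePatchInputs(
    type="image_patches",
    flat_data=torch.randn(8, 588),  # (batch_patches, patch_size)
    indicator_tokens=torch.zeros(10, dtype=torch.long),  # 8 patches + 1 per image
    patches_per_image=[3, 5],
)

# Assumption: a 3-D flat_data no longer satisfies
# TensorShape("batch_patches", "patch_size"), so construction fails fast:
#   OvisImagePatchInputs(type="image_patches",
#                        flat_data=torch.randn(8, 588, 1), ...)
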
@@ -458,9 +456,12 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of indicator_tokens. "
                                  f"Got type: {type(indicator_tokens)}")
 
+            flat_data = flatten_bn(pixel_values, concat=True)
+            if flat_data.ndim >= 3:
+                flat_data = flat_data.flatten(start_dim=1)
             return OvisImagePatchInputs(
                 type="image_patches",
-                flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
+                flat_data=flat_data,
                 patches_per_image=[
                     x.shape[0] for x in flatten_bn(pixel_values)
                 ],
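
For context on the normalization added above: per-image patch tensors arrive as (num_patches_i, patch_size_x, patch_size_y, num_channels), are concatenated along the patch axis, and flattened to the 2-D (batch_patches, patch_size) layout the schema declares; patches_per_image then lets downstream code split the flat tensor back into per-image groups. A self-contained sketch of that round trip, with plain torch ops standing in for flatten_bn and hypothetical sizes:

import torch

# Two images with 3 and 5 patches of 14x14x3.
per_image = [torch.randn(3, 14, 14, 3), torch.randn(5, 14, 14, 3)]
patches_per_image = [x.shape[0] for x in per_image]

# Mirror the new code path: concatenate, then flatten trailing dims to 2-D.
flat_data = torch.cat(per_image, dim=0)         # (8, 14, 14, 3)
if flat_data.ndim >= 3:
    flat_data = flat_data.flatten(start_dim=1)  # (8, 588)

# `patches_per_image` is what restores the per-image grouping later, e.g.:
restored = flat_data.split(patches_per_image, dim=0)
assert [t.shape[0] for t in restored] == patches_per_image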