Migrate OvisImagePatchInputs to TensorSchema (#22024)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck (committed via GitHub)
Date:   2025-09-01 21:01:36 -07:00
Parent: d59c986444
Commit: 1fa1d6a9a0


@@ -19,7 +19,7 @@
 """ PyTorch Ovis model."""
 import math
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -49,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis import OvisProcessor
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import merge_multimodal_embeddings
@@ -201,25 +202,22 @@ class VisualTokenizer(torch.nn.Module):
         return tokens
 
 
-class OvisImagePatchInputs(TypedDict):
-    type: Literal["image_patches"]
-    flat_data: torch.Tensor
+class OvisImagePatchInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
+    Dimensions:
+        - batch_patches: Batch size * number of patches
+        - patch_size: patch_size_x * patch_size_y * num_channels
+        - patch_indicators: Batch size * (number of patches + 1)
+        - patches_per_image: List of number of total patches for each image
+          in the batch.
     """
-
-    indicator_tokens: torch.Tensor
-    """
-    Shape:
-    `(batch_size * (num_patches + 1))`
-    """
-
-    patches_per_image: list[int]
-    """
-    List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `flat_data`.
-    """
+    type: Literal["image_patches"]
+    flat_data: Annotated[torch.Tensor,
+                         TensorShape("batch_patches", "patch_size")]
+    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
+    patches_per_image: Annotated[list[int],
+                                 TensorShape("num_patches_per_image")]
+    # This is used to restore the first two dimensions of `flat_data`.
 
 
 class VisualEmbedding(torch.nn.Embedding):
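
For context, a minimal sketch of how the migrated schema is exercised, assuming (per vLLM's TensorSchema utility) that each Annotated field's TensorShape is checked when the instance is constructed; the sizes below are invented for illustration:

    import torch

    # Two images contributing 2 and 3 patches; each patch flattens to
    # 3 * 14 * 14 = 588 values.
    patches_per_image = [2, 3]
    flat_data = torch.randn(5, 588)                      # ("batch_patches", "patch_size")
    indicator_tokens = torch.zeros(7, dtype=torch.long)  # ("patch_indicators",)

    inputs = OvisImagePatchInputs(
        type="image_patches",
        flat_data=flat_data,
        indicator_tokens=indicator_tokens,
        patches_per_image=patches_per_image,
    )
    # A mis-shaped input (e.g. a 3-D flat_data) would now fail here at
    # construction, rather than surfacing later in the forward pass as
    # the old TypedDict version allowed.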
@@ -458,9 +456,12 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
             raise ValueError("Incorrect type of indicator_tokens. "
                              f"Got type: {type(pixel_values)}")
 
+        flat_data = flatten_bn(pixel_values, concat=True)
+        if flat_data.ndim >= 3:
+            flat_data = flat_data.flatten(start_dim=1)
         return OvisImagePatchInputs(
             type="image_patches",
-            flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
+            flat_data=flat_data,
             patches_per_image=[
                 x.shape[0] for x in flatten_bn(pixel_values)
             ],
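
The added guard matters because the processor can hand over patch tensors whose trailing dimensions are not yet folded together; a standalone sketch of the shape normalization (toy re-implementation, not the vLLM helper itself):

    import torch

    def to_flat_patches(pixel_values: list[torch.Tensor]) -> torch.Tensor:
        # Mirrors flatten_bn(..., concat=True): concatenate the per-image
        # patch tensors along dim 0.
        flat = torch.cat(list(pixel_values), dim=0)
        # If patches arrive unflattened, e.g. (num_patches, C, H, W), fold
        # the trailing dims so the result is always 2-D and matches
        # TensorShape("batch_patches", "patch_size").
        if flat.ndim >= 3:
            flat = flat.flatten(start_dim=1)
        return flat

    imgs = [torch.randn(2, 3, 14, 14), torch.randn(3, 3, 14, 14)]
    print(to_flat_patches(imgs).shape)  # torch.Size([5, 588])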