mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 02:25:42 +08:00
Migrate OvisImagePatchInputs to TensorSchema (#22024)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent
d59c986444
commit
1fa1d6a9a0
@ -19,7 +19,7 @@
|
||||
""" PyTorch Ovis model."""
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -49,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.processors.ovis import OvisProcessor
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import merge_multimodal_embeddings
|
||||
@ -201,25 +202,22 @@ class VisualTokenizer(torch.nn.Module):
|
||||
return tokens
|
||||
|
||||
|
||||
class OvisImagePatchInputs(TypedDict):
|
||||
class OvisImagePatchInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- batch_patches: Batch size * number of patches
|
||||
- patch_size: patch_size_x * patch_size_y * num_channels
|
||||
- patch_indicators: Batch size * (number of patches + 1)
|
||||
- patches_per_image: List of number of total patches for each image
|
||||
in the batch.
|
||||
"""
|
||||
type: Literal["image_patches"]
|
||||
flat_data: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
"""
|
||||
|
||||
indicator_tokens: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * (num_patches + 1))`
|
||||
"""
|
||||
|
||||
patches_per_image: list[int]
|
||||
"""
|
||||
List of number of total patches for each image in the batch.
|
||||
This is used to restore the first two dimensions of `flat_data`.
|
||||
"""
|
||||
flat_data: Annotated[torch.Tensor,
|
||||
TensorShape("batch_patches", "patch_size")]
|
||||
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
|
||||
patches_per_image: Annotated[list[int],
|
||||
TensorShape("num_patches_per_image")]
|
||||
# This is used to restore the first two dimensions of `flat_data`.
|
||||
|
||||
|
||||
class VisualEmbedding(torch.nn.Embedding):
|
||||
@ -458,9 +456,12 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of indicator_tokens. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
flat_data = flatten_bn(pixel_values, concat=True)
|
||||
if flat_data.ndim >= 3:
|
||||
flat_data = flat_data.flatten(start_dim=1)
|
||||
return OvisImagePatchInputs(
|
||||
type="image_patches",
|
||||
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
|
||||
flat_data=flat_data,
|
||||
patches_per_image=[
|
||||
x.shape[0] for x in flatten_bn(pixel_values)
|
||||
],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user