Migrate OvisImagePatchInputs to TensorSchema (#22024)
Signed-off-by: Benji Beck <benjibeck@meta.com>
commit 1fa1d6a9a0
parent d59c986444
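This change converts `OvisImagePatchInputs` from a `TypedDict` into a
`TensorSchema` subclass: the per-field shape docstrings become
`Annotated[..., TensorShape(...)]` declarations, making the expected
dimensions machine-checkable instead of comment-only. The sketch below is a
minimal illustration of that pattern, not vLLM's actual `TensorSchema`
implementation; `MiniTensorSchema`, `PatchInputs`, and the shapes used are
invented for the example.

# Minimal sketch of the Annotated/TensorShape pattern (illustrative only,
# not vLLM's real implementation).
from typing import Annotated, get_type_hints

import torch


class TensorShape:
    """Records the symbolic dimension names declared for a tensor field."""

    def __init__(self, *dims: str) -> None:
        self.dims = dims


class MiniTensorSchema:
    """Checks annotated tensor fields against their declared ranks on init."""

    def __init__(self, **kwargs) -> None:
        hints = get_type_hints(type(self), include_extras=True)
        sizes: dict[str, int] = {}  # symbolic dim name -> bound size
        for name, value in kwargs.items():
            setattr(self, name, value)
            meta = getattr(hints.get(name), "__metadata__", ())
            shape = next((m for m in meta if isinstance(m, TensorShape)), None)
            if shape is None or not isinstance(value, torch.Tensor):
                continue  # non-tensor or unannotated fields are stored as-is
            if value.ndim != len(shape.dims):
                raise ValueError(f"{name}: expected {len(shape.dims)}-D, "
                                 f"got {value.ndim}-D")
            for dim, size in zip(shape.dims, value.shape):
                if sizes.setdefault(dim, size) != size:
                    raise ValueError(f"{name}: dim {dim!r} is {size}, "
                                     f"already bound to {sizes[dim]}")


class PatchInputs(MiniTensorSchema):
    flat_data: Annotated[torch.Tensor,
                         TensorShape("batch_patches", "patch_size")]


PatchInputs(flat_data=torch.randn(6, 588))       # passes: 2-D, as declared
# PatchInputs(flat_data=torch.randn(6, 14, 14))  # would raise: 3-D tensor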
diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py
@@ -19,7 +19,7 @@
 """ PyTorch Ovis model."""
 import math
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -49,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis import OvisProcessor
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import merge_multimodal_embeddings
@@ -201,25 +202,22 @@ class VisualTokenizer(torch.nn.Module):
         return tokens
 
 
-class OvisImagePatchInputs(TypedDict):
+class OvisImagePatchInputs(TensorSchema):
+    """
+    Dimensions:
+        - batch_patches: Batch size * number of patches
+        - patch_size: patch_size_x * patch_size_y * num_channels
+        - patch_indicators: Batch size * (number of patches + 1)
+        - patches_per_image: List of number of total patches for each image
+          in the batch.
+    """
     type: Literal["image_patches"]
-    flat_data: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
-    """
-
-    indicator_tokens: torch.Tensor
-    """
-    Shape:
-    `(batch_size * (num_patches + 1))`
-    """
-
-    patches_per_image: list[int]
-    """
-    List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `flat_data`.
-    """
+    flat_data: Annotated[torch.Tensor,
+                         TensorShape("batch_patches", "patch_size")]
+    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
+    patches_per_image: Annotated[list[int],
+                                 TensorShape("num_patches_per_image")]
+    # This is used to restore the first two dimensions of `flat_data`.
 
 
 class VisualEmbedding(torch.nn.Embedding):
@@ -458,9 +456,12 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
                raise ValueError("Incorrect type of indicator_tokens. "
                                 f"Got type: {type(pixel_values)}")
 
+        flat_data = flatten_bn(pixel_values, concat=True)
+        if flat_data.ndim >= 3:
+            flat_data = flat_data.flatten(start_dim=1)
         return OvisImagePatchInputs(
             type="image_patches",
-            flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
+            flat_data=flat_data,
             patches_per_image=[
                 x.shape[0] for x in flatten_bn(pixel_values)
             ],
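Note on the construction change: the old call flattened twice
(`flatten_bn(flatten_bn(pixel_values), concat=True)`), while the new code
concatenates once and then merges any remaining per-patch dimensions so that
`flat_data` matches the 2-D ("batch_patches", "patch_size") declaration.
A small plain-PyTorch illustration of that final flatten, with made-up
shapes:

import torch

# Made-up example: 6 patches, each 14x14 pixels with 3 channels.
patches = torch.randn(6, 14, 14, 3)

# flatten(start_dim=1) merges everything after the patch axis into a single
# "patch_size" dimension, which is what the schema's 2-D shape requires.
flat = patches.flatten(start_dim=1)
assert flat.shape == (6, 14 * 14 * 3)  # (batch_patches, patch_size)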