Migrate MllamaImagePixelInputs to TensorSchema (#22020)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
commit 0b9cc56fac
parent 8896eb72eb
Author: Benji Beck
Date:   2025-08-21 20:28:49 -07:00 (committed by GitHub)

vllm/model_executor/models/mllama.py

@@ -17,7 +17,7 @@
 """PyTorch Mllama model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union

 import numpy as np
 import torch
@@ -64,6 +64,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                         EncDecMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal, SupportsV0Only
@@ -73,15 +74,30 @@ from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 logger = init_logger(__name__)

-class MllamaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: """
-    """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)"""
-    aspect_ratio_ids: torch.Tensor
-    """Shape: `(batch_size, max_num_image)`"""
-    aspect_ratio_mask: torch.Tensor
-    """Shape: `(batch_size, max_num_image, max_num_tiles)`"""
+class MllamaImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - batch_size: Batch size
+        - max_num_image: Max number of images
+        - max_num_chunk: Max number of chunks
+        - max_num_tiles: Max number of tiles per image
+        - num_channel: Number of channels
+        - height: Height
+        - width: Width
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+
+    data: Annotated[torch.Tensor,
+                    TensorShape("batch_size", "max_num_image", "max_num_chunk",
+                                "num_channel", "height", "width")]
+
+    aspect_ratio_ids: Annotated[torch.Tensor,
+                                TensorShape("batch_size", "max_num_image")]
+
+    aspect_ratio_mask: Annotated[
+        torch.Tensor,
+        TensorShape("batch_size", "max_num_image", "max_num_tiles")]

 # TODO: support LlamaImageEmbeddingInputs
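
With the TypedDict replaced by a TensorSchema subclass, the declared shapes are no longer documentation only: the symbolic dimensions in each TensorShape annotation can be checked against the actual tensors. A minimal sketch of the new usage, assuming TensorSchema validates its annotated fields at construction time; the dimension sizes below are hypothetical, chosen only to satisfy the declared shapes:

import torch

from vllm.model_executor.models.mllama import MllamaImagePixelInputs

# Hypothetical sizes for illustration only.
batch_size, max_num_image, max_num_chunk = 2, 1, 4
num_channel, height, width = 3, 448, 448
max_num_tiles = 4

# Constructing the schema object is where shape checking happens:
# each field must match the rank of its TensorShape annotation, and
# shared symbolic dims (e.g. batch_size) must agree across fields.
inputs = MllamaImagePixelInputs(
    data=torch.zeros(batch_size, max_num_image, max_num_chunk,
                     num_channel, height, width),
    aspect_ratio_ids=torch.zeros(batch_size, max_num_image,
                                 dtype=torch.long),
    aspect_ratio_mask=torch.ones(batch_size, max_num_image, max_num_tiles,
                                 dtype=torch.long),
)

Under that assumption, a mismatched input, say aspect_ratio_ids built with a batch of 3 while data has a batch of 2, is rejected here rather than surfacing later as an opaque indexing error inside the vision encoder.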