diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index ac9b968f7a0cd..ecbbb5f57bec8 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -19,7 +19,7 @@
 import math
 from collections.abc import Iterable, Mapping
 from itertools import tee
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -53,6 +53,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llama4 import Llama4ForCausalLM
@@ -60,28 +61,34 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
 
 
-class Llama4ImagePatchInputs(TypedDict):
-    type: Literal["pixel_values"]
-    flat_data: torch.Tensor
+class Llama4ImagePatchInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_chunks, num_channels, image size, image size)`
+    Dimensions:
+        - batch_size: Batch size
+        - total_num_chunks: Batch size * number of chunks
+        - num_channels: Number of channels
+        - image_size: Size of each image
     """
-    patches_per_image: torch.Tensor
+
+    type: Literal["pixel_values"] = "pixel_values"
+
+    flat_data: Annotated[torch.Tensor,
+                         TensorShape("total_num_chunks", "num_channels",
+                                     "image_size", "image_size")]
+
+    patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
     """
     The number of total patches for each image in the batch.
-
+
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """
 
-    aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
+    aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
    """
     A list of aspect ratios corresponding to the number of tiles in each
     dimension that each image in the batch corresponds to.
-
-    Shape:
-    `(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
+    Each aspect ratio is a pair (ratio_h, ratio_w).
     """
 
 
@@ -623,7 +630,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
                 for (r_h, r_w) in aspect_ratios
             ]
 
-            processed_outputs["aspect_ratios"] = aspect_ratios
+            processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
 
             processed_outputs["patches_per_image"] = torch.tensor(
                 patches_per_image)
@@ -770,11 +777,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         # TODO: confirm handling for variable lengths
         flat_pixel_values = flatten_bn(pixel_values, concat=True)
         patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
-
-        aspect_ratios = kwargs.pop("aspect_ratios", None)
-        if not isinstance(aspect_ratios, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of aspect_ratios. "
-                             f"Got type: {type(aspect_ratios)}")
+        aspect_ratios = kwargs.pop("aspect_ratios")
+        if aspect_ratios.ndim == 3:
+            aspect_ratios = aspect_ratios.squeeze(1)
 
         return Llama4ImagePatchInputs(
             type="pixel_values",
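
Note on the pattern used above (not part of the patch): TensorSchema and TensorShape come from vllm.utils.tensor_schema, and the Annotated metadata declares named dimensions that the schema can check for consistency across fields, replacing the old docstring-only shape comments. The sketch below is a minimal, self-contained illustration of that idea only; the TensorShape and validate_shapes defined here are local stand-ins, not vLLM's actual implementation, and the 3/336/336 sizes are made-up example values.

# Illustrative sketch only -- not vLLM's TensorSchema implementation.
# It demonstrates the idea behind Annotated[torch.Tensor, TensorShape(...)]:
# symbolic dimension names must be consistent across all annotated fields.
from typing import Annotated, get_type_hints

import torch


class TensorShape:
    def __init__(self, *dims):
        self.dims = dims  # symbolic names (str) or fixed sizes (int)


def validate_shapes(obj) -> None:
    """Check every Annotated tensor field against its declared TensorShape."""
    bindings = {}  # symbolic dim name -> concrete size seen so far
    for name, hint in get_type_hints(type(obj), include_extras=True).items():
        metadata = getattr(hint, "__metadata__", ())
        shape_spec = next((m for m in metadata if isinstance(m, TensorShape)), None)
        if shape_spec is None:
            continue
        tensor = getattr(obj, name)
        if tensor.ndim != len(shape_spec.dims):
            raise ValueError(f"{name}: expected {len(shape_spec.dims)} dims, "
                             f"got {tensor.ndim}")
        for dim, size in zip(shape_spec.dims, tensor.shape):
            if isinstance(dim, int):
                if size != dim:
                    raise ValueError(f"{name}: expected size {dim}, got {size}")
            elif bindings.setdefault(dim, size) != size:
                raise ValueError(f"{name}: dim '{dim}' = {size}, but earlier "
                                 f"fields bound it to {bindings[dim]}")


class ImagePatchInputs:
    # Hypothetical schema mirroring Llama4ImagePatchInputs, with example sizes.
    flat_data: Annotated[torch.Tensor, TensorShape("total_num_chunks", 3, 336, 336)]
    patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
    aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
        validate_shapes(self)


# Passes: batch_size binds to 2 in both patches_per_image and aspect_ratios.
inputs = ImagePatchInputs(
    flat_data=torch.zeros(5, 3, 336, 336),
    patches_per_image=torch.tensor([2, 3]),
    aspect_ratios=torch.tensor([[1, 2], [1, 3]]),
)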