Migrate Llama4ImagePatchInputs to TensorSchema (#22021)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Benji Beck 2025-08-28 10:29:37 -07:00 committed by GitHub
parent 8805ad9fa9
commit f32a5bc505


@@ -19,7 +19,7 @@
import math
from collections.abc import Iterable, Mapping
from itertools import tee
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union
import torch
from torch import nn
@@ -53,6 +53,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.utils import run_dp_sharded_vision_model
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llama4 import Llama4ForCausalLM
@@ -60,28 +61,34 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
class Llama4ImagePatchInputs(TypedDict):
type: Literal["pixel_values"]
flat_data: torch.Tensor
class Llama4ImagePatchInputs(TensorSchema):
"""
Shape:
`(batch_size * num_chunks, num_channels, image size, image size)`
Dimensions:
- batch_size: Batch size
- total_num_chunks: Batch size * number of chunks
- num_channels: Number of channels
- image_size: Size of each image
"""
patches_per_image: torch.Tensor
type: Literal["pixel_values"] = "pixel_values"
flat_data: Annotated[torch.Tensor,
TensorShape("total_num_chunks", "num_channels",
"image_size", "image_size")]
patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
"""
The total number of patches for each image in the batch.
This is used to split the embeddings, which have their first two dimensions
flattened just like `flat_data`.
"""
aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
"""
The aspect ratio of each image in the batch, expressed as the number of
tiles in each dimension.
Shape:
`(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
Each aspect ratio is a pair (ratio_h, ratio_w).
"""
@@ -623,7 +630,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
for (r_h, r_w) in aspect_ratios
]
processed_outputs["aspect_ratios"] = aspect_ratios
processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
processed_outputs["patches_per_image"] = torch.tensor(
patches_per_image)
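
A small sketch of what this conversion produces, using toy values rather than real processor output: the list of `(ratio_h, ratio_w)` pairs becomes a single integer tensor whose shape matches the `TensorShape("batch_size", 2)` annotation above.

```python
import torch

# Three images whose tile grids are 1x1, 2x2 and 1x3 (illustrative values).
aspect_ratios = [(1, 1), (2, 2), (1, 3)]

aspect_ratios_t = torch.tensor(aspect_ratios)
print(aspect_ratios_t.shape)  # torch.Size([3, 2])
print(aspect_ratios_t.dtype)  # torch.int64
```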
@@ -770,11 +777,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
# TODO: confirm handling for variable lengths
flat_pixel_values = flatten_bn(pixel_values, concat=True)
patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
aspect_ratios = kwargs.pop("aspect_ratios", None)
if not isinstance(aspect_ratios, (torch.Tensor, list)):
raise ValueError("Incorrect type of aspect_ratios. "
f"Got type: {type(aspect_ratios)}")
aspect_ratios = kwargs.pop("aspect_ratios")
if aspect_ratios.ndim == 3:
aspect_ratios = aspect_ratios.squeeze(1)
return Llama4ImagePatchInputs(
type="pixel_values",