Migrate Llama4ImagePatchInputs to TensorSchema (#22021)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck 2025-08-28 10:29:37 -07:00 committed by GitHub
parent 8805ad9fa9
commit f32a5bc505
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,7 +19,7 @@
import math import math
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from itertools import tee from itertools import tee
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
from torch import nn from torch import nn
@ -53,6 +53,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.multimodal.utils import run_dp_sharded_vision_model
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llama4 import Llama4ForCausalLM from .llama4 import Llama4ForCausalLM
@ -60,28 +61,34 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
class Llama4ImagePatchInputs(TypedDict): class Llama4ImagePatchInputs(TensorSchema):
type: Literal["pixel_values"]
flat_data: torch.Tensor
""" """
Shape: Dimensions:
`(batch_size * num_chunks, num_channels, image size, image size)` - batch_size: Batch size
- total_num_chunks: Batch size * number of chunks
- num_channels: Number of channels
- image_size: Size of each image
""" """
patches_per_image: torch.Tensor
type: Literal["pixel_values"] = "pixel_values"
flat_data: Annotated[torch.Tensor,
TensorShape("total_num_chunks", "num_channels",
"image_size", "image_size")]
patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
""" """
The number of total patches for each image in the batch. The number of total patches for each image in the batch.
This is used to split the embeddings, which have the first two dimensions This is used to split the embeddings, which have the first two dimensions
flattened just like `flat_data`. flattened just like `flat_data`.
""" """
aspect_ratios: Union[torch.Tensor, list[torch.Tensor]] aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
""" """
A list of aspect ratios corresponding to the number of tiles A list of aspect ratios corresponding to the number of tiles
in each dimension that each image in the batch corresponds to. in each dimension that each image in the batch corresponds to.
Each aspect ratio is a pair (ratio_h, ratio_w).
Shape:
`(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
""" """
@ -623,7 +630,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
for (r_h, r_w) in aspect_ratios for (r_h, r_w) in aspect_ratios
] ]
processed_outputs["aspect_ratios"] = aspect_ratios processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
processed_outputs["patches_per_image"] = torch.tensor( processed_outputs["patches_per_image"] = torch.tensor(
patches_per_image) patches_per_image)
@ -770,11 +777,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
# TODO: confirm handling for variable lengths # TODO: confirm handling for variable lengths
flat_pixel_values = flatten_bn(pixel_values, concat=True) flat_pixel_values = flatten_bn(pixel_values, concat=True)
patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
aspect_ratios = kwargs.pop("aspect_ratios")
aspect_ratios = kwargs.pop("aspect_ratios", None) if aspect_ratios.ndim == 3:
if not isinstance(aspect_ratios, (torch.Tensor, list)): aspect_ratios = aspect_ratios.squeeze(1)
raise ValueError("Incorrect type of aspect_ratios. "
f"Got type: {type(aspect_ratios)}")
return Llama4ImagePatchInputs( return Llama4ImagePatchInputs(
type="pixel_values", type="pixel_values",