Migrate Llama4ImagePatchInputs to TensorSchema (#22021)

Signed-off-by: Benji Beck <benjibeck@meta.com>

Parent: 8805ad9fa9
Commit: f32a5bc505
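Note: the commit replaces a plain TypedDict with a TensorSchema subclass whose fields carry symbolic shape annotations via Annotated[..., TensorShape(...)], so dimension names shared across fields (e.g. batch_size) can be checked for consistency when the input object is built. As a rough mental model only (this is NOT vLLM's actual TensorSchema implementation), the checking behaves like this standalone sketch:

# Minimal, standalone sketch of named-dimension checking (illustration only;
# not vLLM's TensorSchema implementation).
from dataclasses import dataclass

import torch


@dataclass
class PatchInputsSketch:
    flat_data: torch.Tensor          # ("total_num_chunks", "num_channels", "image_size", "image_size")
    patches_per_image: torch.Tensor  # ("batch_size",)
    aspect_ratios: torch.Tensor      # ("batch_size", 2)

    def __post_init__(self) -> None:
        dims: dict = {}

        def bind(tensor, names):
            # Each symbolic name must resolve to the same size everywhere it
            # appears; integer entries are fixed sizes.
            if tensor.ndim != len(names):
                raise ValueError(f"expected {len(names)} dims, got {tensor.ndim}")
            for name, size in zip(names, tensor.shape):
                if isinstance(name, int):
                    if size != name:
                        raise ValueError(f"expected size {name}, got {size}")
                elif dims.setdefault(name, size) != size:
                    raise ValueError(f"dim {name!r}: {dims[name]} != {size}")

        bind(self.flat_data,
             ("total_num_chunks", "num_channels", "image_size", "image_size"))
        bind(self.patches_per_image, ("batch_size",))
        bind(self.aspect_ratios, ("batch_size", 2))


# Consistent shapes pass; a mismatched batch dimension would raise.
PatchInputsSketch(
    flat_data=torch.randn(8, 3, 336, 336),
    patches_per_image=torch.tensor([5, 3]),
    aspect_ratios=torch.tensor([[2, 2], [1, 2]]),
)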
@@ -19,7 +19,7 @@
 import math
 from collections.abc import Iterable, Mapping
 from itertools import tee
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -53,6 +53,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llama4 import Llama4ForCausalLM
@@ -60,28 +61,34 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
 
 
-class Llama4ImagePatchInputs(TypedDict):
-    type: Literal["pixel_values"]
-    flat_data: torch.Tensor
+class Llama4ImagePatchInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_chunks, num_channels, image size, image size)`
+    Dimensions:
+        - batch_size: Batch size
+        - total_num_chunks: Batch size * number of chunks
+        - num_channels: Number of channels
+        - image_size: Size of each image
     """
-    patches_per_image: torch.Tensor
+    type: Literal["pixel_values"] = "pixel_values"
+
+    flat_data: Annotated[torch.Tensor,
+                         TensorShape("total_num_chunks", "num_channels",
+                                     "image_size", "image_size")]
+
+    patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
     """
     The number of total patches for each image in the batch.
 
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """
-    aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
+
+    aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
     """
     A list of aspect ratios corresponding to the number of tiles
     in each dimension that each image in the batch corresponds to.
 
-    Shape:
-    `(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
+    Each aspect ratio is a pair (ratio_h, ratio_w).
     """
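Note: after this hunk the parsed image input is a shape-checked container rather than a plain dict. A hedged usage sketch (the import path and the chunk counts are illustrative assumptions, not taken from the diff; the keyword-argument construction mirrors the _parse_and_validate_image_input change further down):

import torch
from vllm.model_executor.models.mllama4 import Llama4ImagePatchInputs  # assumed module path

image_input = Llama4ImagePatchInputs(
    type="pixel_values",
    flat_data=torch.randn(8, 3, 336, 336),        # total_num_chunks x C x H x W
    patches_per_image=torch.tensor([5, 3]),        # chunks per image, batch_size = 2
    aspect_ratios=torch.tensor([[2, 2], [1, 2]]),  # (batch_size, 2)
)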
@@ -623,7 +630,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
             for (r_h, r_w) in aspect_ratios
         ]
 
-        processed_outputs["aspect_ratios"] = aspect_ratios
+        processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
         processed_outputs["patches_per_image"] = torch.tensor(
             patches_per_image)
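Note: the processor now materializes aspect_ratios as a tensor so it matches the TensorShape("batch_size", 2) annotation; the conversion from a list of (ratio_h, ratio_w) pairs is direct:

import torch

aspect_ratios = [(2, 2), (1, 2)]             # one (ratio_h, ratio_w) pair per image
aspect_ratios = torch.tensor(aspect_ratios)  # torch.Size([2, 2]), i.e. (batch_size, 2)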
@@ -770,11 +777,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         # TODO: confirm handling for variable lengths
         flat_pixel_values = flatten_bn(pixel_values, concat=True)
         patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
 
-        aspect_ratios = kwargs.pop("aspect_ratios", None)
-        if not isinstance(aspect_ratios, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of aspect_ratios. "
-                             f"Got type: {type(aspect_ratios)}")
+        aspect_ratios = kwargs.pop("aspect_ratios")
+        if aspect_ratios.ndim == 3:
+            aspect_ratios = aspect_ratios.squeeze(1)
 
         return Llama4ImagePatchInputs(
             type="pixel_values",
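Note: the isinstance check is dropped because the schema now guarantees a tensor; the only remaining normalization is squeezing a possible singleton dimension. One plausible reading (an assumption, not stated in the diff) is that batched multimodal kwargs can arrive as (batch_size, 1, 2), in which case:

import torch

aspect_ratios = torch.tensor([[[2, 2]], [[1, 2]]])  # hypothetical (batch_size, 1, 2) input
if aspect_ratios.ndim == 3:
    aspect_ratios = aspect_ratios.squeeze(1)         # -> (batch_size, 2)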