mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-04 11:11:19 +08:00
Migrate Llama4ImagePatchInputs to TensorSchema (#22021)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent
8805ad9fa9
commit
f32a5bc505
@ -19,7 +19,7 @@
|
|||||||
import math
|
import math
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from itertools import tee
|
from itertools import tee
|
||||||
from typing import Literal, Optional, TypedDict, Union
|
from typing import Annotated, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -53,6 +53,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .llama4 import Llama4ForCausalLM
|
from .llama4 import Llama4ForCausalLM
|
||||||
@ -60,14 +61,22 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
|||||||
merge_multimodal_embeddings)
|
merge_multimodal_embeddings)
|
||||||
|
|
||||||
|
|
||||||
class Llama4ImagePatchInputs(TypedDict):
|
class Llama4ImagePatchInputs(TensorSchema):
|
||||||
type: Literal["pixel_values"]
|
|
||||||
flat_data: torch.Tensor
|
|
||||||
"""
|
"""
|
||||||
Shape:
|
Dimensions:
|
||||||
`(batch_size * num_chunks, num_channels, image size, image size)`
|
- batch_size: Batch size
|
||||||
|
- total_num_chunks: Batch size * number of chunks
|
||||||
|
- num_channels: Number of channels
|
||||||
|
- image_size: Size of each image
|
||||||
"""
|
"""
|
||||||
patches_per_image: torch.Tensor
|
|
||||||
|
type: Literal["pixel_values"] = "pixel_values"
|
||||||
|
|
||||||
|
flat_data: Annotated[torch.Tensor,
|
||||||
|
TensorShape("total_num_chunks", "num_channels",
|
||||||
|
"image_size", "image_size")]
|
||||||
|
|
||||||
|
patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
|
||||||
"""
|
"""
|
||||||
The number of total patches for each image in the batch.
|
The number of total patches for each image in the batch.
|
||||||
|
|
||||||
@ -75,13 +84,11 @@ class Llama4ImagePatchInputs(TypedDict):
|
|||||||
flattened just like `flat_data`.
|
flattened just like `flat_data`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
|
aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
|
||||||
"""
|
"""
|
||||||
A list of aspect ratios corresponding to the number of tiles
|
A list of aspect ratios corresponding to the number of tiles
|
||||||
in each dimension that each image in the batch corresponds to.
|
in each dimension that each image in the batch corresponds to.
|
||||||
|
Each aspect ratio is a pair (ratio_h, ratio_w).
|
||||||
Shape:
|
|
||||||
`(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -623,7 +630,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
|||||||
for (r_h, r_w) in aspect_ratios
|
for (r_h, r_w) in aspect_ratios
|
||||||
]
|
]
|
||||||
|
|
||||||
processed_outputs["aspect_ratios"] = aspect_ratios
|
processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
|
||||||
processed_outputs["patches_per_image"] = torch.tensor(
|
processed_outputs["patches_per_image"] = torch.tensor(
|
||||||
patches_per_image)
|
patches_per_image)
|
||||||
|
|
||||||
@ -770,11 +777,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
# TODO: confirm handling for variable lengths
|
# TODO: confirm handling for variable lengths
|
||||||
flat_pixel_values = flatten_bn(pixel_values, concat=True)
|
flat_pixel_values = flatten_bn(pixel_values, concat=True)
|
||||||
patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
|
patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
|
||||||
|
aspect_ratios = kwargs.pop("aspect_ratios")
|
||||||
aspect_ratios = kwargs.pop("aspect_ratios", None)
|
if aspect_ratios.ndim == 3:
|
||||||
if not isinstance(aspect_ratios, (torch.Tensor, list)):
|
aspect_ratios = aspect_ratios.squeeze(1)
|
||||||
raise ValueError("Incorrect type of aspect_ratios. "
|
|
||||||
f"Got type: {type(aspect_ratios)}")
|
|
||||||
|
|
||||||
return Llama4ImagePatchInputs(
|
return Llama4ImagePatchInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user