Migrate FuyuImagePatchInputs to TensorSchema (#21662)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Benji Beck · 2025-07-26 19:34:14 -07:00 · committed by GitHub
parent 0b8caf9095
commit 3339cba3ff
3 changed files with 56 additions and 44 deletions

View File

@@ -4,6 +4,7 @@
 import pytest
 import torch
 
+from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
 from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
@@ -124,3 +125,24 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
             "w": 336
         },
     )
+
+
+def test_tensor_schema_with_list_of_symbolic_dim():
+    flat_data = torch.stack([torch.randn(768) for _ in range(3)])  # (bn=3, fn)
+    patches_per_image = [64, 64, 64]  # len = bn = 3
+
+    FuyuImagePatchInputs(
+        flat_data=flat_data,
+        patches_per_image=patches_per_image,
+    )
+
+
+def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
+    flat_data = torch.stack([torch.randn(768) for _ in range(4)])  # (bn=4, fn)
+    patches_per_image = [64, 64, 64]  # len = 3 ≠ bn
+
+    with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
+        FuyuImagePatchInputs(
+            flat_data=flat_data,
+            patches_per_image=patches_per_image,
+        )

View File: vllm/model_executor/models/fuyu.py

@@ -19,7 +19,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict
+from typing import Annotated, Literal, Optional
 
 import torch
 import torch.nn as nn
@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
@@ -50,18 +51,24 @@ _IMAGE_TOKEN_ID = 71011
 _NEWLINE_TOKEN_ID = 71019
 
 
-class FuyuImagePatchInputs(TypedDict):
-    type: Literal["image_patches"]
-    flat_data: torch.Tensor
+class FuyuImagePatchInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
+    Dimensions:
+        - bn: Batch size * number of images
+        - fn: Num channels * patch_size_x * patch_size_y
     """
 
-    patches_per_image: list[int]
+    type: Literal["image_patches"] = "image_patches"
+
+    flat_data: Annotated[
+        torch.Tensor,
+        TensorShape("bn", "fn"),
+    ]
+
+    patches_per_image: Annotated[list[int], TensorShape("bn")]
     """
     The number of total patches for each image in the batch.
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """
@@ -297,42 +304,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.patch_size
-        num_channels = self.config.num_channels
-        expected_dims = num_channels * h * w
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = d.size(-1)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"per patch is {expected_expr}. "
-                    f"You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data.to(self.vision_embed_tokens.weight.dtype)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
         if image_patches is not None:
-            if not isinstance(image_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image patches. "
-                                 f"Got type: {type(image_patches)}")
-
             image_patches_flat = flatten_bn(image_patches)
+            flat_data = flatten_bn(image_patches, concat=True).data.to(
+                self.vision_embed_tokens.weight.dtype)
 
             return FuyuImagePatchInputs(
                 type="image_patches",
-                flat_data=self._validate_pixel_values(
-                    flatten_bn(image_patches_flat, concat=True)),
+                flat_data=flat_data,
                 patches_per_image=[x.size(0) for x in image_patches_flat],
+                resolve_bindings={"fn": self.image_feature_size},
             )
 
         return None
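
For reference, a sketch (shapes assumed for illustration) of what the two flatten_bn calls produce here: flatten_bn flattens the batch dimension of the multimodal input, and concat=True concatenates the result into a single tensor.

# Assumed example: a batch with two images of 64 and 32 patches, fn = 768.
image_patches = [torch.randn(64, 768), torch.randn(32, 768)]

image_patches_flat = flatten_bn(image_patches)       # list of 2 patch tensors
flat_data = flatten_bn(image_patches, concat=True)   # one (96, 768) tensor

patches_per_image = [x.size(0) for x in image_patches_flat]  # [64, 32]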

View File: vllm/utils/tensor_schema.py

@@ -86,9 +86,6 @@ class TensorSchema:
                                  expected_shape: tuple[Union[int, str], ...],
                                  dynamic_dims: set[str, ...]) -> tuple[int, ...]:
         """Validate a list/tuple of tensors and return the actual shape."""
-        if not value:
-            raise ValueError(f"{field_name} is an empty list")
-
         # Ensure all tensors in the list have the same
         # shape, besides dynamic dimensions
         first = value[0]
@@ -117,6 +114,7 @@ class TensorSchema:
                                                            int],
                                dynamic_dims: set[str, ...]) -> None:
         """Validate that the actual tensor shape matches the expected shape."""
+
         if len(actual_shape) != len(expected_shape):
             raise ValueError(f"{field_name} has rank {len(actual_shape)} "
                              f"but expected {len(expected_shape)}")
@@ -160,12 +158,11 @@ class TensorSchema:
                 # Skip validation when Union contains None
                 if type(None) in args:
                     continue
-                # If not optional, raise error
+                # Otherwise field is required, raise error
                 raise ValueError(f"Required field '{field_name}' is missing")
-
             # Field exists, proceed with validation
             value = getattr(self, field_name)
 
             if get_origin(field_type) is not None:
                 args = get_args(field_type)
@@ -173,13 +170,23 @@ class TensorSchema:
                     if isinstance(arg, TensorShape):
                         expected_shape = arg.resolve(**self._resolve_bindings)
 
                         if isinstance(value, (list, tuple)):
-                            actual_shape = self._validate_nested_tensors(
-                                value, field_name, expected_shape,
-                                arg.dynamic_dims)
+                            # list/tuple of Tensors → shape = (len(value), ...)
+                            if value and isinstance(value[0], torch.Tensor):
+                                actual_shape = self._validate_nested_tensors(
+                                    value, field_name, expected_shape,
+                                    arg.dynamic_dims)
+                            elif value:
+                                # list/tuple of scalars → shape = (len(value),)
+                                actual_shape = (len(value), )
+                            else:
+                                raise ValueError(
+                                    f"{field_name} is an empty list")
+                        # Tensor → shape = tensor.shape
                         elif isinstance(value, torch.Tensor):
                             actual_shape = value.shape
+                        # Otherwise, it's an unsupported type
                         else:
                             type_names = []
                             for arg in args:
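
With the new scalar-list branch above, a plain list[int] field participates in shape checking by its length alone. A toy schema (illustrative, not from this commit) exercising both branches:

from typing import Annotated

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ToySchema(TensorSchema):
    data: Annotated[torch.Tensor, TensorShape("bn", "d")]
    counts: Annotated[list[int], TensorShape("bn")]


ToySchema(data=torch.randn(2, 4), counts=[7, 9])  # ok: len(counts) == bn == 2
# ToySchema(data=torch.randn(2, 4), counts=[7])   # ValueError: expected 'bn'=2, got 1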