Migrate FuyuImagePatchInputs to TensorSchema (#21662)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck 2025-07-26 19:34:14 -07:00 committed by GitHub
parent 0b8caf9095
commit 3339cba3ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 56 additions and 44 deletions

View File

@ -4,6 +4,7 @@
import pytest import pytest
import torch import torch
from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
@ -124,3 +125,24 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
"w": 336 "w": 336
}, },
) )
def test_tensor_schema_with_list_of_symbolic_dim():
flat_data = torch.stack([torch.randn(768) for _ in range(3)]) # (bn=3, fn)
patches_per_image = [64, 64, 64] # len = bn = 3
FuyuImagePatchInputs(
flat_data=flat_data,
patches_per_image=patches_per_image,
)
def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
flat_data = torch.stack([torch.randn(768) for _ in range(4)]) # (bn=4, fn)
patches_per_image = [64, 64, 64] # len = 3 ≠ bn
with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
FuyuImagePatchInputs(
flat_data=flat_data,
patches_per_image=patches_per_image,
)

View File

@ -19,7 +19,7 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict from typing import Annotated, Literal, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
@ -50,18 +51,24 @@ _IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019 _NEWLINE_TOKEN_ID = 71019
class FuyuImagePatchInputs(TypedDict): class FuyuImagePatchInputs(TensorSchema):
type: Literal["image_patches"]
flat_data: torch.Tensor
""" """
Shape: Dimensions:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` - bn: Batch size * number of images
- fn: Num channels * patch_size_x * patch_size_y
""" """
patches_per_image: list[int] type: Literal["image_patches"] = "image_patches"
flat_data: Annotated[
torch.Tensor,
TensorShape("bn", "fn"),
]
patches_per_image: Annotated[list[int], TensorShape("bn")]
""" """
The number of total patches for each image in the batch. The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`. flattened just like `flat_data`.
""" """
@ -297,42 +304,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.patch_size
num_channels = self.config.num_channels
expected_dims = num_channels * h * w
def _validate_shape(d: torch.Tensor):
actual_dims = d.size(-1)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f"per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data.to(self.vision_embed_tokens.weight.dtype)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]: self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
image_patches = kwargs.pop("image_patches", None) image_patches = kwargs.pop("image_patches", None)
if image_patches is not None: if image_patches is not None:
if not isinstance(image_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}")
image_patches_flat = flatten_bn(image_patches) image_patches_flat = flatten_bn(image_patches)
flat_data = flatten_bn(image_patches, concat=True).data.to(
self.vision_embed_tokens.weight.dtype)
return FuyuImagePatchInputs( return FuyuImagePatchInputs(
type="image_patches", type="image_patches",
flat_data=self._validate_pixel_values( flat_data=flat_data,
flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat], patches_per_image=[x.size(0) for x in image_patches_flat],
resolve_bindings={"fn": self.image_feature_size},
) )
return None return None

View File

@ -86,9 +86,6 @@ class TensorSchema:
expected_shape: tuple[Union[int, str], ...], expected_shape: tuple[Union[int, str], ...],
dynamic_dims: set[str, ...]) -> tuple[int, ...]: dynamic_dims: set[str, ...]) -> tuple[int, ...]:
"""Validate a list/tuple of tensors and return the actual shape.""" """Validate a list/tuple of tensors and return the actual shape."""
if not value:
raise ValueError(f"{field_name} is an empty list")
# Ensure all tensors in the list have the same # Ensure all tensors in the list have the same
# shape, besides dynamic dimensions # shape, besides dynamic dimensions
first = value[0] first = value[0]
@ -117,6 +114,7 @@ class TensorSchema:
int], int],
dynamic_dims: set[str, ...]) -> None: dynamic_dims: set[str, ...]) -> None:
"""Validate that the actual tensor shape matches the expected shape.""" """Validate that the actual tensor shape matches the expected shape."""
if len(actual_shape) != len(expected_shape): if len(actual_shape) != len(expected_shape):
raise ValueError(f"{field_name} has rank {len(actual_shape)} " raise ValueError(f"{field_name} has rank {len(actual_shape)} "
f"but expected {len(expected_shape)}") f"but expected {len(expected_shape)}")
@ -160,12 +158,11 @@ class TensorSchema:
# Skip validation when Union contains None # Skip validation when Union contains None
if type(None) in args: if type(None) in args:
continue continue
# If not optional, raise error # Otherwise field is required, raise error
raise ValueError(f"Required field '{field_name}' is missing") raise ValueError(f"Required field '{field_name}' is missing")
# Field exists, proceed with validation # Field exists, proceed with validation
value = getattr(self, field_name) value = getattr(self, field_name)
if get_origin(field_type) is not None: if get_origin(field_type) is not None:
args = get_args(field_type) args = get_args(field_type)
@ -173,13 +170,23 @@ class TensorSchema:
if isinstance(arg, TensorShape): if isinstance(arg, TensorShape):
expected_shape = arg.resolve(**self._resolve_bindings) expected_shape = arg.resolve(**self._resolve_bindings)
if isinstance(value, (list, tuple)): if isinstance(value, (list, tuple)):
actual_shape = self._validate_nested_tensors( # list/tuple of Tensors → shape = (len(value), ...)
value, field_name, expected_shape, if value and isinstance(value[0], torch.Tensor):
arg.dynamic_dims) actual_shape = self._validate_nested_tensors(
value, field_name, expected_shape,
arg.dynamic_dims)
elif value:
# list/tuple of scalars → shape = (len(value),)
actual_shape = (len(value), )
else:
raise ValueError(
f"{field_name} is an empty list")
# Tensor → shape = tensor.shape
elif isinstance(value, torch.Tensor): elif isinstance(value, torch.Tensor):
actual_shape = value.shape actual_shape = value.shape
# Otherwise, it's an unsupported type
else: else:
type_names = [] type_names = []
for arg in args: for arg in args: