From 3339cba3ff4de0d4a516b31076dc73673510c227 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sat, 26 Jul 2025 19:34:14 -0700 Subject: [PATCH] Migrate FuyuImagePatchInputs to TensorSchema (#21662) Signed-off-by: Benji Beck --- tests/standalone_tests/test_tensor_schema.py | 22 ++++++++ vllm/model_executor/models/fuyu.py | 55 +++++++------------- vllm/utils/tensor_schema.py | 23 +++++--- 3 files changed, 56 insertions(+), 44 deletions(-) diff --git a/tests/standalone_tests/test_tensor_schema.py b/tests/standalone_tests/test_tensor_schema.py index c5b77bb09bbb..b276b88fac1f 100644 --- a/tests/standalone_tests/test_tensor_schema.py +++ b/tests/standalone_tests/test_tensor_schema.py @@ -4,6 +4,7 @@ import pytest import torch +from vllm.model_executor.models.fuyu import FuyuImagePatchInputs from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs @@ -124,3 +125,24 @@ def test_tensor_schema_with_invalid_resolve_binding_dims(): "w": 336 }, ) + + +def test_tensor_schema_with_list_of_symbolic_dim(): + flat_data = torch.stack([torch.randn(768) for _ in range(3)]) # (bn=3, fn) + patches_per_image = [64, 64, 64] # len = bn = 3 + + FuyuImagePatchInputs( + flat_data=flat_data, + patches_per_image=patches_per_image, + ) + + +def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length(): + flat_data = torch.stack([torch.randn(768) for _ in range(4)]) # (bn=4, fn) + patches_per_image = [64, 64, 64] # len = 3 ≠ bn + + with pytest.raises(ValueError, match="expected 'bn'=4, got 3"): + FuyuImagePatchInputs( + flat_data=flat_data, + patches_per_image=patches_per_image, + ) \ No newline at end of file diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 558d4fbb4de1..4fb571122abb 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -19,7 +19,7 @@ """ PyTorch Fuyu model.""" import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict +from typing import Annotated, Literal, Optional import torch import torch.nn as nn @@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, @@ -50,18 +51,24 @@ _IMAGE_TOKEN_ID = 71011 _NEWLINE_TOKEN_ID = 71019 -class FuyuImagePatchInputs(TypedDict): - type: Literal["image_patches"] - flat_data: torch.Tensor +class FuyuImagePatchInputs(TensorSchema): """ - Shape: - `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` + Dimensions: + - bn: Batch size * number of images + - fn: Num channels * patch_size_x * patch_size_y """ - patches_per_image: list[int] + type: Literal["image_patches"] = "image_patches" + + flat_data: Annotated[ + torch.Tensor, + TensorShape("bn", "fn"), + ] + + patches_per_image: Annotated[list[int], TensorShape("bn")] """ The number of total patches for each image in the batch. - + This is used to split the embeddings which has the first two dimensions flattened just like `flat_data`. """ @@ -297,42 +304,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - h = w = self.config.patch_size - num_channels = self.config.num_channels - expected_dims = num_channels * h * w - - def _validate_shape(d: torch.Tensor): - actual_dims = d.size(-1) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f"per patch is {expected_expr}. " - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data.to(self.vision_embed_tokens.weight.dtype) - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[FuyuImagePatchInputs]: image_patches = kwargs.pop("image_patches", None) if image_patches is not None: - if not isinstance(image_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image patches. " - f"Got type: {type(image_patches)}") - image_patches_flat = flatten_bn(image_patches) - + flat_data = flatten_bn(image_patches, concat=True).data.to( + self.vision_embed_tokens.weight.dtype) return FuyuImagePatchInputs( type="image_patches", - flat_data=self._validate_pixel_values( - flatten_bn(image_patches_flat, concat=True)), + flat_data=flat_data, patches_per_image=[x.size(0) for x in image_patches_flat], + resolve_bindings={"fn": self.image_feature_size}, ) return None diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index 485a0a72ddca..343df71e1058 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -86,9 +86,6 @@ class TensorSchema: expected_shape: tuple[Union[int, str], ...], dynamic_dims: set[str, ...]) -> tuple[int, ...]: """Validate a list/tuple of tensors and return the actual shape.""" - if not value: - raise ValueError(f"{field_name} is an empty list") - # Ensure all tensors in the list have the same # shape, besides dynamic dimensions first = value[0] @@ -117,6 +114,7 @@ class TensorSchema: int], dynamic_dims: set[str, ...]) -> None: """Validate that the actual tensor shape matches the expected shape.""" + if len(actual_shape) != len(expected_shape): raise ValueError(f"{field_name} has rank {len(actual_shape)} " f"but expected {len(expected_shape)}") @@ -160,12 +158,11 @@ class TensorSchema: # Skip validation when Union contains None if type(None) in args: continue - # If not optional, raise error + # Otherwise field is required, raise error raise ValueError(f"Required field '{field_name}' is missing") # Field exists, proceed with validation value = getattr(self, field_name) - if get_origin(field_type) is not None: args = get_args(field_type) @@ -173,13 +170,23 @@ class TensorSchema: if isinstance(arg, TensorShape): expected_shape = arg.resolve(**self._resolve_bindings) if isinstance(value, (list, tuple)): - actual_shape = self._validate_nested_tensors( - value, field_name, expected_shape, - arg.dynamic_dims) + # list/tuple of Tensors → shape = (len(value), ...) + if value and isinstance(value[0], torch.Tensor): + actual_shape = self._validate_nested_tensors( + value, field_name, expected_shape, + arg.dynamic_dims) + elif value: + # list/tuple of scalars → shape = (len(value),) + actual_shape = (len(value), ) + else: + raise ValueError( + f"{field_name} is an empty list") + # Tensor → shape = tensor.shape elif isinstance(value, torch.Tensor): actual_shape = value.shape + # Otherwise, it's an unsupported type else: type_names = [] for arg in args: