From 965bc71b0445d7010bd40b3808d423d49ee68e58 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Thu, 24 Jul 2025 21:43:52 -0700 Subject: [PATCH] Integrate TensorSchema with shape validation for Phi3VImagePixelInputs (#21232) Signed-off-by: Benji Beck --- tests/standalone_tests/test_tensor_schema.py | 126 +++++++++++ vllm/model_executor/models/phi3v.py | 114 ++++------ vllm/utils/tensor_schema.py | 210 +++++++++++++++++++ 3 files changed, 375 insertions(+), 75 deletions(-) create mode 100644 tests/standalone_tests/test_tensor_schema.py create mode 100644 vllm/utils/tensor_schema.py diff --git a/tests/standalone_tests/test_tensor_schema.py b/tests/standalone_tests/test_tensor_schema.py new file mode 100644 index 000000000000..c5b77bb09bbb --- /dev/null +++ b/tests/standalone_tests/test_tensor_schema.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs + + +def test_tensor_schema_valid_tensor(): + Phi3VImagePixelInputs( + data=torch.randn(16, 64, 3, 32, 32), + image_sizes=torch.randint(0, 256, (16, 2)), + ) + + +def test_tensor_schema_optional_fields(): + Phi3VImagePixelInputs( + data=torch.randn(16, 64, 3, 32, 32), + image_sizes=None, + ) + + Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), ) + + +def test_tensor_schema_constant_dim_failure(): + with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"): + Phi3VImagePixelInputs( + data=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4 + image_sizes=torch.randint(0, 256, (16, 2)), + ) + + +def test_tensor_schema_symbolic_dim_mismatch(): + with pytest.raises(ValueError, match="expected 'bn'=12, got 16"): + Phi3VImagePixelInputs( + data=torch.randn(12, 64, 3, 32, 32), + image_sizes=torch.randint(0, 256, (16, 2)), + ) + + +def test_tensor_schema_list_tensor_valid(): + Phi3VImagePixelInputs( + data=[torch.randn(64, 3, 32, 32) for _ in range(16)], + image_sizes=torch.randint(0, 256, (16, 2)), + ) + + +def test_tensor_schema_variable_patch_counts_valid(): + # Each image has a different number of patches (p) + # Each tensor has shape (p, 3, 32, 32) + data = [ + torch.randn(16, 3, 32, 32), # p = 16 + torch.randn(32, 3, 32, 32), # p = 32 + torch.randn(64, 3, 32, 32), # p = 64 + ] + image_sizes = torch.randint(0, 256, (3, 2)) # bn = 3 + Phi3VImagePixelInputs( + data=data, + image_sizes=image_sizes, + ) + + +def test_tensor_schema_tuple_tensor_valid(): + Phi3VImagePixelInputs( + data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)), + image_sizes=torch.randint(0, 256, (16, 2)), + ) + + +def test_tensor_schema_inconsistent_shapes_in_list(): + with pytest.raises(ValueError, match="contains inconsistent shapes"): + Phi3VImagePixelInputs( + data=[torch.randn(64, 3, 32, 32), + torch.randn(64, 3, 16, 16)] + + [torch.randn(64, 3, 32, 32) for _ in range(14)], + image_sizes=torch.randint(0, 256, (16, 2)), + ) + + +def test_tensor_schema_empty_list(): + with pytest.raises(ValueError, match="is an empty list"): + Phi3VImagePixelInputs( + data=[], + image_sizes=torch.randint(0, 256, (0, 2)), + ) + + +def test_tensor_schema_validation_disabled_skips_shape_check(): + # This should NOT raise, because validation is turned off + # This would normally fail (dim[2] should be 3, not 4) + Phi3VImagePixelInputs( + data=torch.randn(16, 64, 4, 32, 32), + image_sizes=torch.randint(0, 256, (16, 2)), + validate=False, + ) + + +def test_tensor_schema_with_valid_resolve_binding_dims(): + 
data = torch.randn(16, 64, 3, 336, 336) # h=336, w=336 + image_sizes = torch.randint(0, 256, (16, 2)) + + Phi3VImagePixelInputs( + data=data, + image_sizes=image_sizes, + resolve_bindings={ + "h": 336, + "w": 336 + }, + ) + + +def test_tensor_schema_with_invalid_resolve_binding_dims(): + data = torch.randn(16, 64, 3, 36, 36) # h=36, w=36 + image_sizes = torch.randint(0, 256, (16, 2)) + + # Should raise because 'h' and 'w' don't match resolve bindings + with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"): + Phi3VImagePixelInputs( + data=data, + image_sizes=image_sizes, + resolve_bindings={ + "h": 336, + "w": 336 + }, + ) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 745cf7aa2512..aa739f22fd7b 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import regex as re import torch @@ -45,6 +45,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, @@ -93,32 +94,42 @@ def _init_img_processor(hf_config: PretrainedConfig, return img_processor -class Phi3VImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: Union[torch.Tensor, list[torch.Tensor]] +class Phi3VImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + Dimensions: + - b: Batch size + - n: Number of images + - p: Number of patches + - h: Height of each patch + - w: Width of each patch """ - image_sizes: torch.Tensor + type: Literal["pixel_values", "image_embeds"] = "pixel_values" + + # Supports either a stacked tensor or a list of (p, 3, h, w) tensors + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # 'p' may vary across items + ] + + # Stacked tensor with height and width for each image + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + + +class Phi3VImageEmbeddingInputs(TensorSchema): """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. - """ - - -class Phi3VImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. 
+ Dimensions: + - b: Batch size + - n: Number of images + - f: Image feature size (e.g., number of tokens per image) + - h: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", "h"), + ] Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs] @@ -563,44 +574,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - expected_dims = (2, ) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - f"The expected shape of image sizes per image per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - - def _validate_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("num_patches", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values per image per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Phi3VImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -611,25 +584,16 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_sizes, (torch.Tensor, list)): - raise ValueError("Incorrect type of image sizes. " - f"Got type: {type(image_sizes)}") - return Phi3VImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(flatten_bn(pixel_values)), - image_sizes=self._validate_image_sizes( - flatten_bn(image_sizes, concat=True))) + data=flatten_bn(pixel_values), + image_sizes=flatten_bn(image_sizes, concat=True), + resolve_bindings={ + "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, + "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size + }) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return Phi3VImageEmbeddingInputs( type="image_embeds", data=flatten_bn(image_embeds), diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py new file mode 100644 index 000000000000..485a0a72ddca --- /dev/null +++ b/vllm/utils/tensor_schema.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class TensorShape: + + def __init__(self, + *dims: Union[int, str], + dynamic_dims: set[str, ...] 
= None) -> None: + self.dims = dims + self.dynamic_dims = dynamic_dims if dynamic_dims else set() + + def resolve(self, **bindings: dict[str, + int]) -> tuple[Union[int, str], ...]: + resolved = [] + for dim in self.dims: + if isinstance(dim, str) and dim in bindings: + resolved.append(bindings[dim]) + else: + resolved.append(dim) + return tuple(resolved) + + def __str__(self) -> str: + """Return a string representation of the tensor shape.""" + dim_strs = [] + for dim in self.dims: + if isinstance(dim, str): + if dim in self.dynamic_dims: + dim_strs.append( + f"{dim}*") # Mark dynamic dimensions with * + else: + dim_strs.append(dim) + else: + dim_strs.append(str(dim)) + return f"({', '.join(dim_strs)})" + + +class TensorSchema: + + def __init__(self, + *, + validate: bool = True, + resolve_bindings: dict[str, int] = None, + **kwargs: Any) -> None: + self._resolve_bindings = resolve_bindings if resolve_bindings else {} + + for key, value in kwargs.items(): + setattr(self, key, value) + + if validate: + self.validate() + + def __getitem__(self, item) -> Any: + return getattr(self, item) + + def _match_shape_with_dynamic(self, actual: tuple[int, ...], + reference: tuple[int, ...], + expected_shape: tuple[Union[int, str], ...], + dynamic_dims: set[str, ...]) -> bool: + if len(actual) != len(reference) or len(actual) > len(expected_shape): + return False + + for i, (a, r) in enumerate(zip(actual, reference)): + # When validating list inputs, we match shape suffixes only + # (e.g. "p", 3, "h", "w"), assuming the list length corresponds + # to the leading symbolic dim (e.g. "bn"). This allows comparing + # only the trailing dimensions of each element in the list. + dim = expected_shape[-len(actual) + i] + # Skip this dimension if it's marked dynamic + if dim in dynamic_dims: + continue + if a != r: + return False + return True + + def _validate_nested_tensors( + self, value: Union[list[torch.Tensor, ...], + tuple[torch.Tensor, ...]], field_name: str, + expected_shape: tuple[Union[int, str], ...], + dynamic_dims: set[str, ...]) -> tuple[int, ...]: + """Validate a list/tuple of tensors and return the actual shape.""" + if not value: + raise ValueError(f"{field_name} is an empty list") + + # Ensure all tensors in the list have the same + # shape, besides dynamic dimensions + first = value[0] + for i, v in enumerate(value): + if not isinstance(v, torch.Tensor): + raise ValueError(f"{field_name}[{i}] is not a " + f"torch.Tensor") + if not self._match_shape_with_dynamic( + v.shape, + first.shape, + expected_shape, + dynamic_dims, + ): + raise ValueError(f"{field_name} contains inconsistent " + f"shapes: {first.shape} vs {v.shape} " + f"at index {i}") + + # Treat the list as a stacked tensor: + # shape = (len(list), *tensor.shape) + return (len(value), ) + first.shape + + def _validate_tensor_shape_expected(self, actual_shape: tuple[int, ...], + expected_shape: tuple[Union[int, str], + ...], + field_name: str, shape_env: dict[str, + int], + dynamic_dims: set[str, ...]) -> None: + """Validate that the actual tensor shape matches the expected shape.""" + if len(actual_shape) != len(expected_shape): + raise ValueError(f"{field_name} has rank {len(actual_shape)} " + f"but expected {len(expected_shape)}") + + for i, dim in enumerate(expected_shape): + if dim in dynamic_dims: + continue + elif isinstance(dim, int): + if actual_shape[i] != dim: + raise ValueError(f"{field_name} dim[{i}] expected " + f"{dim}, got {actual_shape[i]}") + elif isinstance(dim, str): + if dim in shape_env: + if actual_shape[i] != 
shape_env[dim]: + raise ValueError(f"{field_name} dim[{i}] expected " + f"'{dim}'={shape_env[dim]}, got " + f"{actual_shape[i]}") + else: + shape_env[dim] = actual_shape[i] + else: + raise TypeError(f"{field_name} dim[{i}] has unsupported " + f"type: {type(dim)}") + + def validate(self) -> None: + type_hints = get_type_hints(self.__class__, include_extras=True) + shape_env = {} + + for field_name, field_type in type_hints.items(): + # Check if field is missing + if (not hasattr(self, field_name) + or getattr(self, field_name) is None): + # Check if field is marked as optional + actual_type = field_type + if get_origin(field_type) is Annotated: + args = get_args(field_type) + actual_type = args[0] + + # Check arg was provided as Union + if get_origin(actual_type) is Union: + args = get_args(actual_type) + # Skip validation when Union contains None + if type(None) in args: + continue + # If not optional, raise error + raise ValueError(f"Required field '{field_name}' is missing") + + # Field exists, proceed with validation + value = getattr(self, field_name) + + if get_origin(field_type) is not None: + args = get_args(field_type) + + for arg in args: + if isinstance(arg, TensorShape): + expected_shape = arg.resolve(**self._resolve_bindings) + if isinstance(value, (list, tuple)): + actual_shape = self._validate_nested_tensors( + value, field_name, expected_shape, + arg.dynamic_dims) + + elif isinstance(value, torch.Tensor): + actual_shape = value.shape + + else: + type_names = [] + for arg in args: + if hasattr(arg, "__name__"): + type_names.append(str(arg.__name__)) + else: + type_names.append(str(arg)) + + expected_types = ", ".join(type_names) + raise ValueError( + f"{field_name} is not one of the expected " + f"types: {expected_types}") + + self._validate_tensor_shape_expected( + actual_shape, expected_shape, field_name, + shape_env, arg.dynamic_dims) + + def print_shapes(self) -> None: + """Print TensorShape annotations for debugging.""" + logger.debug("Shapes in %s:", self.__class__.__name__) + type_hints = get_type_hints(self.__class__, include_extras=True) + + for field_name, field_type in type_hints.items(): + if get_origin(field_type) is not None: + args = get_args(field_type) + for arg in args: + if isinstance(arg, TensorShape): + logger.debug(" %s: %s", field_name, str(arg))
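
Illustrative usage sketch (not part of the patch): the snippet below shows how a model module could declare its own multimodal input schema with the TensorSchema / TensorShape API added in vllm/utils/tensor_schema.py, mirroring the Phi3VImagePixelInputs definition and the resolve_bindings call in phi3v.py above. The class name MyVisionPixelInputs and the concrete shapes are hypothetical, chosen only for demonstration.

    from typing import Annotated, Literal, Optional, Union

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class MyVisionPixelInputs(TensorSchema):
        """
        Dimensions:
            - bn: Batch size * number of images
            - p: Number of patches (may vary per image)
            - h: Height of each patch
            - w: Width of each patch
        """
        type: Literal["pixel_values"] = "pixel_values"

        # Accepts a stacked (bn, p, 3, h, w) tensor or a list of
        # (p, 3, h, w) tensors; "p" is declared dynamic, so it may
        # differ between list items.
        data: Annotated[
            Union[torch.Tensor, list[torch.Tensor]],
            TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}),
        ]

        # One (height, width) pair per image.
        image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]


    # Shape checks run in __init__ unless validate=False is passed.
    # Symbolic dims such as "h" and "w" can be pinned via resolve_bindings,
    # as _parse_and_validate_image_input does for the CLIP image size.
    inputs = MyVisionPixelInputs(
        data=torch.randn(4, 16, 3, 336, 336),
        image_sizes=torch.randint(0, 512, (4, 2)),
        resolve_bindings={"h": 336, "w": 336},
    )

A shape mismatch (for example, a channel dimension of 4 instead of 3) raises a ValueError naming the offending field and dimension, which replaces the hand-written _validate_pixel_values / _validate_image_sizes helpers removed from phi3v.py in this patch.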