Integrate TensorSchema with shape validation for Phi3VImagePixelInputs (#21232)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Authored by Benji Beck on 2025-07-24 21:43:52 -07:00; committed by GitHub
parent 807a328bb6
commit 965bc71b04
3 changed files with 375 additions and 75 deletions


@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs


def test_tensor_schema_valid_tensor():
    Phi3VImagePixelInputs(
        data=torch.randn(16, 64, 3, 32, 32),
        image_sizes=torch.randint(0, 256, (16, 2)),
    )


def test_tensor_schema_optional_fields():
    Phi3VImagePixelInputs(
        data=torch.randn(16, 64, 3, 32, 32),
        image_sizes=None,
    )
    Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), )


def test_tensor_schema_constant_dim_failure():
    with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
        Phi3VImagePixelInputs(
            data=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
            image_sizes=torch.randint(0, 256, (16, 2)),
        )


def test_tensor_schema_symbolic_dim_mismatch():
    with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
        Phi3VImagePixelInputs(
            data=torch.randn(12, 64, 3, 32, 32),
            image_sizes=torch.randint(0, 256, (16, 2)),
        )


def test_tensor_schema_list_tensor_valid():
    Phi3VImagePixelInputs(
        data=[torch.randn(64, 3, 32, 32) for _ in range(16)],
        image_sizes=torch.randint(0, 256, (16, 2)),
    )


def test_tensor_schema_variable_patch_counts_valid():
    # Each image has a different number of patches (p)
    # Each tensor has shape (p, 3, 32, 32)
    data = [
        torch.randn(16, 3, 32, 32),  # p = 16
        torch.randn(32, 3, 32, 32),  # p = 32
        torch.randn(64, 3, 32, 32),  # p = 64
    ]
    image_sizes = torch.randint(0, 256, (3, 2))  # bn = 3
    Phi3VImagePixelInputs(
        data=data,
        image_sizes=image_sizes,
    )


def test_tensor_schema_tuple_tensor_valid():
    Phi3VImagePixelInputs(
        data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
        image_sizes=torch.randint(0, 256, (16, 2)),
    )


def test_tensor_schema_inconsistent_shapes_in_list():
    with pytest.raises(ValueError, match="contains inconsistent shapes"):
        Phi3VImagePixelInputs(
            data=[torch.randn(64, 3, 32, 32),
                  torch.randn(64, 3, 16, 16)] +
                 [torch.randn(64, 3, 32, 32) for _ in range(14)],
            image_sizes=torch.randint(0, 256, (16, 2)),
        )


def test_tensor_schema_empty_list():
    with pytest.raises(ValueError, match="is an empty list"):
        Phi3VImagePixelInputs(
            data=[],
            image_sizes=torch.randint(0, 256, (0, 2)),
        )


def test_tensor_schema_validation_disabled_skips_shape_check():
    # This should NOT raise, because validation is turned off
    # This would normally fail (dim[2] should be 3, not 4)
    Phi3VImagePixelInputs(
        data=torch.randn(16, 64, 4, 32, 32),
        image_sizes=torch.randint(0, 256, (16, 2)),
        validate=False,
    )


def test_tensor_schema_with_valid_resolve_binding_dims():
    data = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
    image_sizes = torch.randint(0, 256, (16, 2))
    Phi3VImagePixelInputs(
        data=data,
        image_sizes=image_sizes,
        resolve_bindings={
            "h": 336,
            "w": 336
        },
    )


def test_tensor_schema_with_invalid_resolve_binding_dims():
    data = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
    image_sizes = torch.randint(0, 256, (16, 2))
    # Should raise because 'h' and 'w' don't match resolve bindings
    with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
        Phi3VImagePixelInputs(
            data=data,
            image_sizes=image_sizes,
            resolve_bindings={
                "h": 336,
                "w": 336
            },
        )
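The tests above only exercise construction; once validation passes, the resulting object is used like a plain record. A minimal sketch of that consumption pattern (illustrative shapes, not part of the diff): fields passed to the constructor become attributes, class-level defaults such as `type` are visible too, and `TensorSchema.__getitem__` also allows dict-style access, presumably so call sites that still index the object like the old TypedDict keep working.

import torch
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs

inputs = Phi3VImagePixelInputs(
    data=torch.randn(2, 5, 3, 32, 32),          # (bn, p, 3, h, w)
    image_sizes=torch.randint(0, 256, (2, 2)),  # (bn, 2)
)
# Constructor kwargs are set as attributes on the instance.
assert inputs.data.shape[0] == 2
# Dict-style access is routed through TensorSchema.__getitem__ (getattr),
# so the class default for `type` is also reachable this way.
assert inputs["type"] == "pixel_values"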


vllm/model_executor/models/phi3v.py
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import regex as re
import torch
@@ -45,6 +45,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
@@ -93,32 +94,42 @@ def _init_img_processor(hf_config: PretrainedConfig,
    return img_processor


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, list[torch.Tensor]]
class Phi3VImagePixelInputs(TensorSchema):
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - p: Number of patches
        - h: Height of each patch
        - w: Width of each patch
    """
    image_sizes: torch.Tensor
    type: Literal["pixel_values", "image_embeds"] = "pixel_values"

    # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
                    ),  # 'p' may vary across items
    ]

    # Stacked tensor with height and width for each image
    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]


class Phi3VImageEmbeddingInputs(TensorSchema):
    """
    Shape: `(batch_size * num_images, 2)`
    This should be in `(height, width)` format.
    """
class Phi3VImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, list[torch.Tensor]]
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - f: Image feature size (e.g., number of tokens per image)
        - h: Hidden size (must match language model backbone)
    """
    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "f", "h"),
    ]


Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
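To make the annotations above concrete: symbolic names such as "bn" and "h" are bound to the first size they see and must then agree across all annotated fields, literal ints like 3 are fixed, and dynamic_dims={"p"} lets the patch count differ per image when the data arrives as a list. A rough sketch of the two accepted layouts, with arbitrarily chosen shapes (not part of the diff):

import torch
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs

# Stacked layout: one tensor of shape (bn, p, 3, h, w).
Phi3VImagePixelInputs(
    data=torch.randn(4, 17, 3, 336, 336),
    image_sizes=torch.randint(0, 2048, (4, 2)),  # "bn" must also be 4 here
)

# List layout: bn tensors of shape (p_i, 3, h, w); "p" is dynamic,
# so each image may carry a different number of patches.
Phi3VImagePixelInputs(
    data=[torch.randn(5, 3, 336, 336), torch.randn(9, 3, 336, 336)],
    image_sizes=torch.randint(0, 2048, (2, 2)),
)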
@@ -563,44 +574,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
@@ -611,25 +584,16 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)))
                data=flatten_bn(pixel_values),
                image_sizes=flatten_bn(image_sizes, concat=True),
                resolve_bindings={
                    "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
                    "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
                })

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
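The net effect of this hunk is that the hand-written _validate_pixel_values / _validate_image_sizes helpers are replaced by constructor-time validation, with resolve_bindings pinning the symbolic "h" and "w" dims to the CLIP ViT-L/14-336 patch size. A small sketch of the failure this catches; 336 mirrors CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size and the tensors are made up for illustration:

import torch
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs

try:
    Phi3VImagePixelInputs(
        data=torch.randn(1, 4, 3, 224, 224),  # wrong resolution for Phi-3-V
        image_sizes=torch.randint(0, 2048, (1, 2)),
        resolve_bindings={"h": 336, "w": 336},
    )
except ValueError as e:
    # e.g. "data dim[3] expected 336, got 224"
    print(e)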

vllm/utils/tensor_schema.py (new file, 210 additions)
@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class TensorShape:

    def __init__(self,
                 *dims: Union[int, str],
                 dynamic_dims: set[str, ...] = None) -> None:
        self.dims = dims
        self.dynamic_dims = dynamic_dims if dynamic_dims else set()

    def resolve(self, **bindings: dict[str,
                                       int]) -> tuple[Union[int, str], ...]:
        resolved = []
        for dim in self.dims:
            if isinstance(dim, str) and dim in bindings:
                resolved.append(bindings[dim])
            else:
                resolved.append(dim)
        return tuple(resolved)

    def __str__(self) -> str:
        """Return a string representation of the tensor shape."""
        dim_strs = []
        for dim in self.dims:
            if isinstance(dim, str):
                if dim in self.dynamic_dims:
                    dim_strs.append(
                        f"{dim}*")  # Mark dynamic dimensions with *
                else:
                    dim_strs.append(dim)
            else:
                dim_strs.append(str(dim))
        return f"({', '.join(dim_strs)})"


class TensorSchema:

    def __init__(self,
                 *,
                 validate: bool = True,
                 resolve_bindings: dict[str, int] = None,
                 **kwargs: Any) -> None:
        self._resolve_bindings = resolve_bindings if resolve_bindings else {}

        for key, value in kwargs.items():
            setattr(self, key, value)

        if validate:
            self.validate()

    def __getitem__(self, item) -> Any:
        return getattr(self, item)

    def _match_shape_with_dynamic(self, actual: tuple[int, ...],
                                  reference: tuple[int, ...],
                                  expected_shape: tuple[Union[int, str], ...],
                                  dynamic_dims: set[str, ...]) -> bool:
        if len(actual) != len(reference) or len(actual) > len(expected_shape):
            return False

        for i, (a, r) in enumerate(zip(actual, reference)):
            # When validating list inputs, we match shape suffixes only
            # (e.g. "p", 3, "h", "w"), assuming the list length corresponds
            # to the leading symbolic dim (e.g. "bn"). This allows comparing
            # only the trailing dimensions of each element in the list.
            dim = expected_shape[-len(actual) + i]
            # Skip this dimension if it's marked dynamic
            if dim in dynamic_dims:
                continue
            if a != r:
                return False
        return True

    def _validate_nested_tensors(
            self, value: Union[list[torch.Tensor, ...],
                               tuple[torch.Tensor, ...]], field_name: str,
            expected_shape: tuple[Union[int, str], ...],
            dynamic_dims: set[str, ...]) -> tuple[int, ...]:
        """Validate a list/tuple of tensors and return the actual shape."""
        if not value:
            raise ValueError(f"{field_name} is an empty list")

        # Ensure all tensors in the list have the same
        # shape, besides dynamic dimensions
        first = value[0]
        for i, v in enumerate(value):
            if not isinstance(v, torch.Tensor):
                raise ValueError(f"{field_name}[{i}] is not a "
                                 f"torch.Tensor")
            if not self._match_shape_with_dynamic(
                    v.shape,
                    first.shape,
                    expected_shape,
                    dynamic_dims,
            ):
                raise ValueError(f"{field_name} contains inconsistent "
                                 f"shapes: {first.shape} vs {v.shape} "
                                 f"at index {i}")

        # Treat the list as a stacked tensor:
        # shape = (len(list), *tensor.shape)
        return (len(value), ) + first.shape

    def _validate_tensor_shape_expected(self, actual_shape: tuple[int, ...],
                                        expected_shape: tuple[Union[int, str],
                                                              ...],
                                        field_name: str, shape_env: dict[str,
                                                                         int],
                                        dynamic_dims: set[str, ...]) -> None:
        """Validate that the actual tensor shape matches the expected shape."""
        if len(actual_shape) != len(expected_shape):
            raise ValueError(f"{field_name} has rank {len(actual_shape)} "
                             f"but expected {len(expected_shape)}")

        for i, dim in enumerate(expected_shape):
            if dim in dynamic_dims:
                continue
            elif isinstance(dim, int):
                if actual_shape[i] != dim:
                    raise ValueError(f"{field_name} dim[{i}] expected "
                                     f"{dim}, got {actual_shape[i]}")
            elif isinstance(dim, str):
                if dim in shape_env:
                    if actual_shape[i] != shape_env[dim]:
                        raise ValueError(f"{field_name} dim[{i}] expected "
                                         f"'{dim}'={shape_env[dim]}, got "
                                         f"{actual_shape[i]}")
                else:
                    shape_env[dim] = actual_shape[i]
            else:
                raise TypeError(f"{field_name} dim[{i}] has unsupported "
                                f"type: {type(dim)}")

    def validate(self) -> None:
        type_hints = get_type_hints(self.__class__, include_extras=True)
        shape_env = {}

        for field_name, field_type in type_hints.items():
            # Check if field is missing
            if (not hasattr(self, field_name)
                    or getattr(self, field_name) is None):
                # Check if field is marked as optional
                actual_type = field_type
                if get_origin(field_type) is Annotated:
                    args = get_args(field_type)
                    actual_type = args[0]

                # Check arg was provided as Union
                if get_origin(actual_type) is Union:
                    args = get_args(actual_type)
                    # Skip validation when Union contains None
                    if type(None) in args:
                        continue

                # If not optional, raise error
                raise ValueError(f"Required field '{field_name}' is missing")

            # Field exists, proceed with validation
            value = getattr(self, field_name)

            if get_origin(field_type) is not None:
                args = get_args(field_type)
                for arg in args:
                    if isinstance(arg, TensorShape):
                        expected_shape = arg.resolve(**self._resolve_bindings)
                        if isinstance(value, (list, tuple)):
                            actual_shape = self._validate_nested_tensors(
                                value, field_name, expected_shape,
                                arg.dynamic_dims)
                        elif isinstance(value, torch.Tensor):
                            actual_shape = value.shape
                        else:
                            type_names = []
                            for arg in args:
                                if hasattr(arg, "__name__"):
                                    type_names.append(str(arg.__name__))
                                else:
                                    type_names.append(str(arg))
                            expected_types = ", ".join(type_names)
                            raise ValueError(
                                f"{field_name} is not one of the expected "
                                f"types: {expected_types}")

                        self._validate_tensor_shape_expected(
                            actual_shape, expected_shape, field_name,
                            shape_env, arg.dynamic_dims)

    def print_shapes(self) -> None:
        """Print TensorShape annotations for debugging."""
        logger.debug("Shapes in %s:", self.__class__.__name__)
        type_hints = get_type_hints(self.__class__, include_extras=True)
        for field_name, field_type in type_hints.items():
            if get_origin(field_type) is not None:
                args = get_args(field_type)
                for arg in args:
                    if isinstance(arg, TensorShape):
                        logger.debug(" %s: %s", field_name, str(arg))
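For other models that want to adopt this utility, the pattern is: subclass TensorSchema, annotate each field with Annotated[..., TensorShape(...)], and construct with keyword arguments. Shared symbolic names must agree across fields, validate=False skips the checks, and resolve_bindings pins symbolic dims to known values. A minimal sketch of a hypothetical schema (ExampleAudioInputs and its field names are invented for illustration and are not part of this change):

from typing import Annotated, Literal, Optional, Union

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleAudioInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of audio clips
        - t: Number of frames (may vary per clip)
        - f: Feature size
    """
    type: Literal["audio_features"] = "audio_features"
    features: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "t", "f", dynamic_dims={"t"}),
    ]
    lengths: Annotated[Optional[torch.Tensor], TensorShape("bn")]


# Valid: "bn" binds to 2 on `features` and must match on `lengths`;
# "t" is dynamic, so the two clips may have different frame counts.
ExampleAudioInputs(
    features=[torch.randn(80, 128), torch.randn(95, 128)],
    lengths=torch.tensor([80, 95]),
    resolve_bindings={"f": 128},
)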