Integrate TensorSchema with shape validation for Phi3VImagePixelInputs (#21232)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent 807a328bb6
commit 965bc71b04
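In short: multimodal input classes can now declare their expected tensor shapes with Annotated[..., TensorShape(...)] and have them checked when the object is constructed, instead of hand-rolling _validate_* helpers. A minimal sketch of the pattern (MyImageInputs is a made-up class for illustration only; the actual change below applies this to Phi3VImagePixelInputs and Phi3VImageEmbeddingInputs):

from typing import Annotated, Literal, Optional, Union

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class MyImageInputs(TensorSchema):  # hypothetical example, not part of this diff
    type: Literal["pixel_values"] = "pixel_values"
    # (batch * num_images, patches, 3, h, w); "p" may vary per list element
    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                    TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]


# Symbolic dims like "bn" must agree across fields, constants like 3 are
# checked directly, and "h"/"w" can be pinned via resolve_bindings.
MyImageInputs(
    data=torch.randn(4, 16, 3, 336, 336),
    image_sizes=torch.randint(0, 512, (4, 2)),
    resolve_bindings={"h": 336, "w": 336},
)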
tests/standalone_tests/test_tensor_schema.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs


def test_tensor_schema_valid_tensor():
    Phi3VImagePixelInputs(
        data=torch.randn(16, 64, 3, 32, 32),
        image_sizes=torch.randint(0, 256, (16, 2)),
    )


def test_tensor_schema_optional_fields():
    Phi3VImagePixelInputs(
        data=torch.randn(16, 64, 3, 32, 32),
        image_sizes=None,
    )

    Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), )


def test_tensor_schema_constant_dim_failure():
    with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
        Phi3VImagePixelInputs(
            data=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
            image_sizes=torch.randint(0, 256, (16, 2)),
        )


def test_tensor_schema_symbolic_dim_mismatch():
    with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
        Phi3VImagePixelInputs(
            data=torch.randn(12, 64, 3, 32, 32),
            image_sizes=torch.randint(0, 256, (16, 2)),
        )


def test_tensor_schema_list_tensor_valid():
    Phi3VImagePixelInputs(
        data=[torch.randn(64, 3, 32, 32) for _ in range(16)],
        image_sizes=torch.randint(0, 256, (16, 2)),
    )


def test_tensor_schema_variable_patch_counts_valid():
    # Each image has a different number of patches (p)
    # Each tensor has shape (p, 3, 32, 32)
    data = [
        torch.randn(16, 3, 32, 32),  # p = 16
        torch.randn(32, 3, 32, 32),  # p = 32
        torch.randn(64, 3, 32, 32),  # p = 64
    ]
    image_sizes = torch.randint(0, 256, (3, 2))  # bn = 3
    Phi3VImagePixelInputs(
        data=data,
        image_sizes=image_sizes,
    )


def test_tensor_schema_tuple_tensor_valid():
    Phi3VImagePixelInputs(
        data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
        image_sizes=torch.randint(0, 256, (16, 2)),
    )


def test_tensor_schema_inconsistent_shapes_in_list():
    with pytest.raises(ValueError, match="contains inconsistent shapes"):
        Phi3VImagePixelInputs(
            data=[torch.randn(64, 3, 32, 32),
                  torch.randn(64, 3, 16, 16)] +
            [torch.randn(64, 3, 32, 32) for _ in range(14)],
            image_sizes=torch.randint(0, 256, (16, 2)),
        )


def test_tensor_schema_empty_list():
    with pytest.raises(ValueError, match="is an empty list"):
        Phi3VImagePixelInputs(
            data=[],
            image_sizes=torch.randint(0, 256, (0, 2)),
        )


def test_tensor_schema_validation_disabled_skips_shape_check():
    # This should NOT raise, because validation is turned off
    # This would normally fail (dim[2] should be 3, not 4)
    Phi3VImagePixelInputs(
        data=torch.randn(16, 64, 4, 32, 32),
        image_sizes=torch.randint(0, 256, (16, 2)),
        validate=False,
    )


def test_tensor_schema_with_valid_resolve_binding_dims():
    data = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
    image_sizes = torch.randint(0, 256, (16, 2))

    Phi3VImagePixelInputs(
        data=data,
        image_sizes=image_sizes,
        resolve_bindings={
            "h": 336,
            "w": 336
        },
    )


def test_tensor_schema_with_invalid_resolve_binding_dims():
    data = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
    image_sizes = torch.randint(0, 256, (16, 2))

    # Should raise because 'h' and 'w' don't match resolve bindings
    with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
        Phi3VImagePixelInputs(
            data=data,
            image_sizes=image_sizes,
            resolve_bindings={
                "h": 336,
                "w": 336
            },
        )
vllm/model_executor/models/phi3v.py
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union

import regex as re
import torch
@@ -45,6 +45,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
@@ -93,32 +94,42 @@ def _init_img_processor(hf_config: PretrainedConfig,
    return img_processor


-class Phi3VImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Shape:
-    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different per batch and image,
-    in which case the data is passed as a list instead of a batched tensor.
-    """
-
-    image_sizes: torch.Tensor
-    """
-    Shape: `(batch_size * num_images, 2)`
-
-    This should be in `(height, width)` format.
-    """
-
-
-class Phi3VImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+class Phi3VImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - p: Number of patches
+        - h: Height of each patch
+        - w: Width of each patch
+    """
+
+    type: Literal["pixel_values", "image_embeds"] = "pixel_values"
+
+    # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
+                    ),  # 'p' may vary across items
+    ]
+
+    # Stacked tensor with height and width for each image
+    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
+
+
+class Phi3VImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - f: Image feature size (e.g., number of tokens per image)
+        - h: Hidden size (must match language model backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "f", "h"),
+    ]


Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
@@ -563,44 +574,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

-    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        expected_dims = (2, )
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    f"The expected shape of image sizes per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
-        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[1:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
@@ -611,25 +584,16 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
            return None

        if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
-            if not isinstance(image_sizes, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image sizes. "
-                                 f"Got type: {type(image_sizes)}")
-
            return Phi3VImagePixelInputs(
                type="pixel_values",
-                data=self._validate_pixel_values(flatten_bn(pixel_values)),
-                image_sizes=self._validate_image_sizes(
-                    flatten_bn(image_sizes, concat=True)))
+                data=flatten_bn(pixel_values),
+                image_sizes=flatten_bn(image_sizes, concat=True),
+                resolve_bindings={
+                    "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
+                    "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
+                })

        if image_embeds is not None:
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
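For context on the hunk above: because the pixel-value schema binds "h" and "w" to CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size via resolve_bindings, data with the wrong patch resolution now fails inside the Phi3VImagePixelInputs constructor rather than in the removed _validate_pixel_values helper. A rough illustration (the 448x448 input is an arbitrary wrong resolution chosen for the example):

import pytest
import torch

from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs

# CLIP ViT-L/14 @ 336 expects 336x336 patches, so 448x448 data with the same
# bindings should raise from the TensorSchema shape check on dim[3].
with pytest.raises(ValueError, match="dim\\[3\\] expected 336"):
    Phi3VImagePixelInputs(
        data=torch.randn(2, 5, 3, 448, 448),
        image_sizes=torch.randint(0, 1024, (2, 2)),
        resolve_bindings={"h": 336, "w": 336},
    )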
vllm/utils/tensor_schema.py (new file, 210 lines)
@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class TensorShape:

    def __init__(self,
                 *dims: Union[int, str],
                 dynamic_dims: set[str, ...] = None) -> None:
        self.dims = dims
        self.dynamic_dims = dynamic_dims if dynamic_dims else set()

    def resolve(self, **bindings: dict[str,
                                       int]) -> tuple[Union[int, str], ...]:
        resolved = []
        for dim in self.dims:
            if isinstance(dim, str) and dim in bindings:
                resolved.append(bindings[dim])
            else:
                resolved.append(dim)
        return tuple(resolved)

    def __str__(self) -> str:
        """Return a string representation of the tensor shape."""
        dim_strs = []
        for dim in self.dims:
            if isinstance(dim, str):
                if dim in self.dynamic_dims:
                    dim_strs.append(
                        f"{dim}*")  # Mark dynamic dimensions with *
                else:
                    dim_strs.append(dim)
            else:
                dim_strs.append(str(dim))
        return f"({', '.join(dim_strs)})"


class TensorSchema:

    def __init__(self,
                 *,
                 validate: bool = True,
                 resolve_bindings: dict[str, int] = None,
                 **kwargs: Any) -> None:
        self._resolve_bindings = resolve_bindings if resolve_bindings else {}

        for key, value in kwargs.items():
            setattr(self, key, value)

        if validate:
            self.validate()

    def __getitem__(self, item) -> Any:
        return getattr(self, item)

    def _match_shape_with_dynamic(self, actual: tuple[int, ...],
                                  reference: tuple[int, ...],
                                  expected_shape: tuple[Union[int, str], ...],
                                  dynamic_dims: set[str, ...]) -> bool:
        if len(actual) != len(reference) or len(actual) > len(expected_shape):
            return False

        for i, (a, r) in enumerate(zip(actual, reference)):
            # When validating list inputs, we match shape suffixes only
            # (e.g. "p", 3, "h", "w"), assuming the list length corresponds
            # to the leading symbolic dim (e.g. "bn"). This allows comparing
            # only the trailing dimensions of each element in the list.
            dim = expected_shape[-len(actual) + i]
            # Skip this dimension if it's marked dynamic
            if dim in dynamic_dims:
                continue
            if a != r:
                return False
        return True

    def _validate_nested_tensors(
            self, value: Union[list[torch.Tensor, ...],
                               tuple[torch.Tensor, ...]], field_name: str,
            expected_shape: tuple[Union[int, str], ...],
            dynamic_dims: set[str, ...]) -> tuple[int, ...]:
        """Validate a list/tuple of tensors and return the actual shape."""
        if not value:
            raise ValueError(f"{field_name} is an empty list")

        # Ensure all tensors in the list have the same
        # shape, besides dynamic dimensions
        first = value[0]
        for i, v in enumerate(value):
            if not isinstance(v, torch.Tensor):
                raise ValueError(f"{field_name}[{i}] is not a "
                                 f"torch.Tensor")
            if not self._match_shape_with_dynamic(
                    v.shape,
                    first.shape,
                    expected_shape,
                    dynamic_dims,
            ):
                raise ValueError(f"{field_name} contains inconsistent "
                                 f"shapes: {first.shape} vs {v.shape} "
                                 f"at index {i}")

        # Treat the list as a stacked tensor:
        # shape = (len(list), *tensor.shape)
        return (len(value), ) + first.shape

    def _validate_tensor_shape_expected(self, actual_shape: tuple[int, ...],
                                        expected_shape: tuple[Union[int, str],
                                                              ...],
                                        field_name: str, shape_env: dict[str,
                                                                         int],
                                        dynamic_dims: set[str, ...]) -> None:
        """Validate that the actual tensor shape matches the expected shape."""
        if len(actual_shape) != len(expected_shape):
            raise ValueError(f"{field_name} has rank {len(actual_shape)} "
                             f"but expected {len(expected_shape)}")

        for i, dim in enumerate(expected_shape):
            if dim in dynamic_dims:
                continue
            elif isinstance(dim, int):
                if actual_shape[i] != dim:
                    raise ValueError(f"{field_name} dim[{i}] expected "
                                     f"{dim}, got {actual_shape[i]}")
            elif isinstance(dim, str):
                if dim in shape_env:
                    if actual_shape[i] != shape_env[dim]:
                        raise ValueError(f"{field_name} dim[{i}] expected "
                                         f"'{dim}'={shape_env[dim]}, got "
                                         f"{actual_shape[i]}")
                else:
                    shape_env[dim] = actual_shape[i]
            else:
                raise TypeError(f"{field_name} dim[{i}] has unsupported "
                                f"type: {type(dim)}")

    def validate(self) -> None:
        type_hints = get_type_hints(self.__class__, include_extras=True)
        shape_env = {}

        for field_name, field_type in type_hints.items():
            # Check if field is missing
            if (not hasattr(self, field_name)
                    or getattr(self, field_name) is None):
                # Check if field is marked as optional
                actual_type = field_type
                if get_origin(field_type) is Annotated:
                    args = get_args(field_type)
                    actual_type = args[0]

                # Check arg was provided as Union
                if get_origin(actual_type) is Union:
                    args = get_args(actual_type)
                    # Skip validation when Union contains None
                    if type(None) in args:
                        continue
                # If not optional, raise error
                raise ValueError(f"Required field '{field_name}' is missing")

            # Field exists, proceed with validation
            value = getattr(self, field_name)

            if get_origin(field_type) is not None:
                args = get_args(field_type)

                for arg in args:
                    if isinstance(arg, TensorShape):
                        expected_shape = arg.resolve(**self._resolve_bindings)
                        if isinstance(value, (list, tuple)):
                            actual_shape = self._validate_nested_tensors(
                                value, field_name, expected_shape,
                                arg.dynamic_dims)

                        elif isinstance(value, torch.Tensor):
                            actual_shape = value.shape

                        else:
                            type_names = []
                            for arg in args:
                                if hasattr(arg, "__name__"):
                                    type_names.append(str(arg.__name__))
                                else:
                                    type_names.append(str(arg))

                            expected_types = ", ".join(type_names)
                            raise ValueError(
                                f"{field_name} is not one of the expected "
                                f"types: {expected_types}")

                        self._validate_tensor_shape_expected(
                            actual_shape, expected_shape, field_name,
                            shape_env, arg.dynamic_dims)

    def print_shapes(self) -> None:
        """Print TensorShape annotations for debugging."""
        logger.debug("Shapes in %s:", self.__class__.__name__)
        type_hints = get_type_hints(self.__class__, include_extras=True)

        for field_name, field_type in type_hints.items():
            if get_origin(field_type) is not None:
                args = get_args(field_type)
                for arg in args:
                    if isinstance(arg, TensorShape):
                        logger.debug(" %s: %s", field_name, str(arg))
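One detail worth calling out from the implementation above: when a field is a list or tuple, _validate_nested_tensors treats it as a stacked tensor of shape (len(list), *element.shape) and only skips dimensions named in dynamic_dims, which is how per-image patch counts can differ while channel, height, and width stay pinned. A small standalone sketch of those two code paths (the Patches class is hypothetical, just for illustration):

from typing import Annotated, Union

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class Patches(TensorSchema):  # hypothetical schema, not part of this commit
    # Leading "n" is bound to len(data) when a list is passed; "p" is dynamic,
    # so each element may carry a different number of patches.
    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                    TensorShape("n", "p", 3, 32, 32, dynamic_dims={"p"})]


# List input: elements may differ only in the dynamic "p" dimension.
Patches(data=[torch.randn(7, 3, 32, 32), torch.randn(9, 3, 32, 32)])

# Equivalent stacked input: a single (n, p, 3, 32, 32) tensor.
Patches(data=torch.randn(2, 8, 3, 32, 32))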