diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 7db3a1bb90b47..88dd1a57626f2 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -85,6 +85,23 @@ class MiniCPMVImagePixelInputs(TensorSchema): - w: Width """ + def _validate_nested_tensors( + self, + value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], + field_name: str, + expected_shape: tuple[Union[int, str], ...], + dynamic_dims: set[str], + ) -> tuple[int, ...]: + # value[0] is the scaled image, + # and value[1:] is a collection of image slices. + # It is ensured that all slices in the collection + # have the same shape. + if field_name == "pixel_values": + value = value[1:] if len(value) > 1 else value + + return super()._validate_nested_tensors(value, field_name, + expected_shape, dynamic_dims) + type: Literal["pixel_values"] = "pixel_values" # Note that the image size may vary, so we pass it as a list instead of a diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index 4c3acf0094c74..21d3249fe1547 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints +from typing import (Annotated, Any, Optional, Union, get_args, get_origin, + get_type_hints) import torch @@ -11,9 +12,13 @@ logger = init_logger(__name__) class TensorShape: - def __init__(self, - *dims: Union[int, str], - dynamic_dims: set[str, ...] = None) -> None: + def __init__( + self, + *dims: Union[int, str], + dynamic_dims: Optional[set[str]] = None, + ) -> None: + super().__init__() + self.dims = dims self.dynamic_dims = dynamic_dims if dynamic_dims else set() @@ -44,11 +49,15 @@ class TensorShape: class TensorSchema: - def __init__(self, - *, - validate: bool = True, - resolve_bindings: dict[str, int] = None, - **kwargs: Any) -> None: + def __init__( + self, + *, + validate: bool = True, + resolve_bindings: Optional[dict[str, int]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + self._resolve_bindings = resolve_bindings if resolve_bindings else {} for key, value in kwargs.items(): @@ -57,16 +66,19 @@ class TensorSchema: if validate: self.validate() - def __getitem__(self, item) -> Any: - return getattr(self, item) + def __getitem__(self, key: str) -> Any: + return getattr(self, key) - def get(self, item, default=None) -> Any: - return getattr(self, item, default) + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) - def _match_shape_with_dynamic(self, actual: tuple[int, ...], - reference: tuple[int, ...], - expected_shape: tuple[Union[int, str], ...], - dynamic_dims: set[str, ...]) -> bool: + def _match_shape_with_dynamic( + self, + actual: tuple[int, ...], + reference: tuple[int, ...], + expected_shape: tuple[Union[int, str], ...], + dynamic_dims: set[str], + ) -> bool: if len(actual) != len(reference) or len(actual) > len(expected_shape): return False @@ -84,10 +96,12 @@ class TensorSchema: return True def _validate_nested_tensors( - self, value: Union[list[torch.Tensor, ...], - tuple[torch.Tensor, ...]], field_name: str, - expected_shape: tuple[Union[int, str], ...], - dynamic_dims: set[str, ...]) -> tuple[int, ...]: + self, + value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], + field_name: str, + expected_shape: tuple[Union[int, str], ...], + dynamic_dims: set[str], + ) -> tuple[int, ...]: """Validate a list/tuple of tensors and return the actual shape.""" # Ensure all tensors in the list have the same # shape, besides dynamic dimensions @@ -110,12 +124,14 @@ class TensorSchema: # shape = (len(list), *tensor.shape) return (len(value), ) + first.shape - def _validate_tensor_shape_expected(self, actual_shape: tuple[int, ...], - expected_shape: tuple[Union[int, str], - ...], - field_name: str, shape_env: dict[str, - int], - dynamic_dims: set[str, ...]) -> None: + def _validate_tensor_shape_expected( + self, + actual_shape: tuple[int, ...], + expected_shape: tuple[Union[int, str], ...], + field_name: str, + shape_env: dict[str, int], + dynamic_dims: set[str], + ) -> None: """Validate that the actual tensor shape matches the expected shape.""" if len(actual_shape) != len(expected_shape):