mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:45:01 +08:00
Migrate FuyuImagePatchInputs to TensorSchema (#21662)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
parent
0b8caf9095
commit
3339cba3ff
@ -4,6 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
|
||||
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
|
||||
|
||||
|
||||
@ -124,3 +125,24 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
|
||||
"w": 336
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_tensor_schema_with_list_of_symbolic_dim():
|
||||
flat_data = torch.stack([torch.randn(768) for _ in range(3)]) # (bn=3, fn)
|
||||
patches_per_image = [64, 64, 64] # len = bn = 3
|
||||
|
||||
FuyuImagePatchInputs(
|
||||
flat_data=flat_data,
|
||||
patches_per_image=patches_per_image,
|
||||
)
|
||||
|
||||
|
||||
def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
|
||||
flat_data = torch.stack([torch.randn(768) for _ in range(4)]) # (bn=4, fn)
|
||||
patches_per_image = [64, 64, 64] # len = 3 ≠ bn
|
||||
|
||||
with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
|
||||
FuyuImagePatchInputs(
|
||||
flat_data=flat_data,
|
||||
patches_per_image=patches_per_image,
|
||||
)
|
||||
@ -19,7 +19,7 @@
|
||||
""" PyTorch Fuyu model."""
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Literal, Optional, TypedDict
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
|
||||
@ -50,15 +51,21 @@ _IMAGE_TOKEN_ID = 71011
|
||||
_NEWLINE_TOKEN_ID = 71019
|
||||
|
||||
|
||||
class FuyuImagePatchInputs(TypedDict):
|
||||
type: Literal["image_patches"]
|
||||
flat_data: torch.Tensor
|
||||
class FuyuImagePatchInputs(TensorSchema):
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- fn: Num channels * patch_size_x * patch_size_y
|
||||
"""
|
||||
|
||||
patches_per_image: list[int]
|
||||
type: Literal["image_patches"] = "image_patches"
|
||||
|
||||
flat_data: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("bn", "fn"),
|
||||
]
|
||||
|
||||
patches_per_image: Annotated[list[int], TensorShape("bn")]
|
||||
"""
|
||||
The number of total patches for each image in the batch.
|
||||
|
||||
@ -297,42 +304,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
h = w = self.config.patch_size
|
||||
num_channels = self.config.num_channels
|
||||
expected_dims = num_channels * h * w
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = d.size(-1)
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = str(expected_dims)
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values per image per batch "
|
||||
f"per patch is {expected_expr}. "
|
||||
f"You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data.to(self.vision_embed_tokens.weight.dtype)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
|
||||
image_patches = kwargs.pop("image_patches", None)
|
||||
if image_patches is not None:
|
||||
if not isinstance(image_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image patches. "
|
||||
f"Got type: {type(image_patches)}")
|
||||
|
||||
image_patches_flat = flatten_bn(image_patches)
|
||||
|
||||
flat_data = flatten_bn(image_patches, concat=True).data.to(
|
||||
self.vision_embed_tokens.weight.dtype)
|
||||
return FuyuImagePatchInputs(
|
||||
type="image_patches",
|
||||
flat_data=self._validate_pixel_values(
|
||||
flatten_bn(image_patches_flat, concat=True)),
|
||||
flat_data=flat_data,
|
||||
patches_per_image=[x.size(0) for x in image_patches_flat],
|
||||
resolve_bindings={"fn": self.image_feature_size},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@ -86,9 +86,6 @@ class TensorSchema:
|
||||
expected_shape: tuple[Union[int, str], ...],
|
||||
dynamic_dims: set[str, ...]) -> tuple[int, ...]:
|
||||
"""Validate a list/tuple of tensors and return the actual shape."""
|
||||
if not value:
|
||||
raise ValueError(f"{field_name} is an empty list")
|
||||
|
||||
# Ensure all tensors in the list have the same
|
||||
# shape, besides dynamic dimensions
|
||||
first = value[0]
|
||||
@ -117,6 +114,7 @@ class TensorSchema:
|
||||
int],
|
||||
dynamic_dims: set[str, ...]) -> None:
|
||||
"""Validate that the actual tensor shape matches the expected shape."""
|
||||
|
||||
if len(actual_shape) != len(expected_shape):
|
||||
raise ValueError(f"{field_name} has rank {len(actual_shape)} "
|
||||
f"but expected {len(expected_shape)}")
|
||||
@ -160,12 +158,11 @@ class TensorSchema:
|
||||
# Skip validation when Union contains None
|
||||
if type(None) in args:
|
||||
continue
|
||||
# If not optional, raise error
|
||||
# Otherwise field is required, raise error
|
||||
raise ValueError(f"Required field '{field_name}' is missing")
|
||||
|
||||
# Field exists, proceed with validation
|
||||
value = getattr(self, field_name)
|
||||
|
||||
if get_origin(field_type) is not None:
|
||||
args = get_args(field_type)
|
||||
|
||||
@ -173,13 +170,23 @@ class TensorSchema:
|
||||
if isinstance(arg, TensorShape):
|
||||
expected_shape = arg.resolve(**self._resolve_bindings)
|
||||
if isinstance(value, (list, tuple)):
|
||||
actual_shape = self._validate_nested_tensors(
|
||||
value, field_name, expected_shape,
|
||||
arg.dynamic_dims)
|
||||
# list/tuple of Tensors → shape = (len(value), ...)
|
||||
if value and isinstance(value[0], torch.Tensor):
|
||||
actual_shape = self._validate_nested_tensors(
|
||||
value, field_name, expected_shape,
|
||||
arg.dynamic_dims)
|
||||
elif value:
|
||||
# list/tuple of scalars → shape = (len(value),)
|
||||
actual_shape = (len(value), )
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{field_name} is an empty list")
|
||||
|
||||
# Tensor → shape = tensor.shape
|
||||
elif isinstance(value, torch.Tensor):
|
||||
actual_shape = value.shape
|
||||
|
||||
# Otherwise, it's an unsupported type
|
||||
else:
|
||||
type_names = []
|
||||
for arg in args:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user