mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:34:57 +08:00
Migrate Florence2ImagePixelInputs to TensorSchema (#21663)
Signed-off-by: Benji Beck <benjibeck@meta.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
1cbf951ba2
commit
5f8c9a425e
@ -4,7 +4,7 @@
|
|||||||
import math
|
import math
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import Literal, Optional, TypedDict, Union
|
from typing import Annotated, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -29,16 +29,28 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
|
|||||||
PromptUpdate)
|
PromptUpdate)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
|
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
|
||||||
SupportsV0Only)
|
SupportsV0Only)
|
||||||
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
|
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
|
||||||
|
|
||||||
|
|
||||||
class Florence2ImagePixelInputs(TypedDict):
|
class Florence2ImagePixelInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- b: Batch size
|
||||||
|
- c: Number of channels (3)
|
||||||
|
- h: Height of the image
|
||||||
|
- w: Width of the image
|
||||||
|
"""
|
||||||
|
|
||||||
type: Literal["pixel_values"]
|
type: Literal["pixel_values"]
|
||||||
data: torch.Tensor
|
|
||||||
"""Shape: (batch_size, num_channel, height, width)"""
|
data: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("b", 3, "h", "w"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# ViT implementation are all copied from
|
# ViT implementation are all copied from
|
||||||
@ -931,28 +943,6 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'Florence2 only supports COSINE as temporal embedding.')
|
'Florence2 only supports COSINE as temporal embedding.')
|
||||||
|
|
||||||
def _validate_pixel_values(
|
|
||||||
self, data: Union[torch.Tensor, list[torch.Tensor]]
|
|
||||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
|
||||||
|
|
||||||
size = self.processor_config["size"]
|
|
||||||
h, w = size["height"], size["width"]
|
|
||||||
expected_dims = (3, h, w)
|
|
||||||
|
|
||||||
def _validate_shape(d: torch.Tensor):
|
|
||||||
actual_dims = tuple(d.shape)
|
|
||||||
|
|
||||||
if actual_dims != expected_dims:
|
|
||||||
expected_expr = tuple(*map(str, expected_dims))
|
|
||||||
raise ValueError(
|
|
||||||
"The expected shape of pixel values per batch "
|
|
||||||
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
|
||||||
|
|
||||||
for d in data:
|
|
||||||
_validate_shape(d)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _parse_and_validate_image_input(self, **kwargs: object):
|
def _parse_and_validate_image_input(self, **kwargs: object):
|
||||||
pixel_values: Optional[Union[list[list[torch.Tensor]],
|
pixel_values: Optional[Union[list[list[torch.Tensor]],
|
||||||
list[torch.Tensor],
|
list[torch.Tensor],
|
||||||
@ -971,10 +961,16 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
"Both pixel values and image embeds are provided.")
|
"Both pixel values and image embeds are provided.")
|
||||||
|
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
|
size = self.processor_config["size"]
|
||||||
|
expected_h, expected_w = size["height"], size["width"]
|
||||||
|
|
||||||
return Florence2ImagePixelInputs(
|
return Florence2ImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
data=self._validate_pixel_values(
|
data=flatten_bn(pixel_values, concat=True),
|
||||||
flatten_bn(pixel_values, concat=True)),
|
resolve_bindings={
|
||||||
|
"h": expected_h,
|
||||||
|
"w": expected_w
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if image_embeds is not None:
|
if image_embeds is not None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user