Migrate InternVLImageInputs and InternVLVideoInputs to TensorSchema (#21684)
Signed-off-by: Benji Beck <benjibeck@meta.com>
Parent: 12a223ef9b
Commit: f1e2c095ec
@@ -9,7 +9,7 @@
 # --------------------------------------------------------
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, TypeVar, Union
+from typing import Annotated, Any, Literal, Optional, TypeVar, Union

 import numpy.typing as npt
 import torch
@@ -37,6 +37,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -51,54 +52,60 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)


-class InternVLImagePixelInputs(TypedDict):
+class InternVLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height of each image patch
+        - w: Width of each image patch
+    """
     type: Literal["pixel_values"]
-    pixel_values_flat: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
-    """
-
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
-
-
-class InternVLImageEmbeddingInputs(TypedDict):
+    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+
+
+class InternVLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - n: Number of images
+        - f: Total image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
     type: Literal["image_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("n", "f", "h")]


 InternVLImageInputs = Union[InternVLImagePixelInputs,
                             InternVLImageEmbeddingInputs]


-class InternVLVideoPixelInputs(TypedDict):
+class InternVLVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bvf: Batch size * number of videos * num_frames
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each video frame
+        - w: Width of each video frame
+    """
     type: Literal["pixel_values_videos"]
-    pixel_values_flat: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_video * num_frames, num_channels, height, width)`
-    """
-
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
-
-
-class InternVLVideoEmbeddingInputs(TypedDict):
+    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+
+
+class InternVLVideoEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - n: Number of videos
+        - f: Total video feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
     type: Literal["video_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A tensor of shape `(num_videos, total_video_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_video_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("n", "f", "h")]


 InternVLVideoInputs = Union[InternVLVideoPixelInputs,
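The hunk above replaces free-form TypedDicts with TensorSchema subclasses: the expected rank and per-dimension sizes now live in `Annotated[..., TensorShape(...)]` metadata on each field, with the symbolic names ("bnp", "h", ...) documented once in the class docstring. As a rough mental model only (a simplified sketch, not vLLM's actual implementation in vllm/utils/tensor_schema.py; the `MiniSchema` and `Shape` names are invented for illustration), such a schema can be validated generically at construction time:

```python
from typing import Annotated, get_type_hints

import torch


class Shape:
    """Expected shape: symbolic dim names (str) or fixed sizes (int)."""

    def __init__(self, *dims):
        self.dims = dims


class MiniSchema:
    """Checks each Annotated tensor field against its Shape on init."""

    def __init__(self, resolve_bindings=None, **fields):
        bindings = dict(resolve_bindings or {})
        hints = get_type_hints(type(self), include_extras=True)
        for name, value in fields.items():
            hint = hints.get(name)
            if hint is not None and hasattr(hint, "__metadata__"):
                for meta in hint.__metadata__:
                    if isinstance(meta, Shape):
                        self._check(name, value, meta.dims, bindings)
            setattr(self, name, value)

    @staticmethod
    def _check(name, tensor, dims, bindings):
        if tensor.ndim != len(dims):
            raise ValueError(f"{name}: expected {len(dims)} dims, "
                             f"got {tensor.ndim}")
        for dim, actual in zip(dims, tensor.shape):
            # A str dim binds to the first size seen (or a preset binding);
            # an int dim must match exactly.
            expected = (bindings.setdefault(dim, actual)
                        if isinstance(dim, str) else dim)
            if actual != expected:
                raise ValueError(f"{name}: dim {dim!r} expected {expected}, "
                                 f"got {actual}")


class PixelInputs(MiniSchema):
    pixel_values_flat: Annotated[torch.Tensor, Shape("bnp", 3, "h", "w")]
    num_patches: Annotated[torch.Tensor, Shape("bn")]


# OK: "h"/"w" are pinned to 448 up front; "bnp"/"bn" bind to whatever arrives.
PixelInputs(pixel_values_flat=torch.zeros(7, 3, 448, 448),
            num_patches=torch.tensor([3, 4]),
            resolve_bindings={"h": 448, "w": 448})
```

The payoff is that every multimodal input class gets the same checking for free, instead of each model carrying a hand-written `_validate_pixel_values`-style helper.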
@@ -1151,26 +1158,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds

-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f" per patch is {expected_expr}. "
-                    f"You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[InternVLImageInputs]:
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
@@ -1205,12 +1192,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,

         pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
         image_num_patches = flatten_bn(image_num_patches, concat=True)
+        expected_h = expected_w = self.config.vision_config.image_size
+        resolve_bindings = {"h": expected_h, "w": expected_w}

         return InternVLImagePixelInputs(
             type="pixel_values",
-            pixel_values_flat=self._validate_pixel_values(
-                pixel_values_flat),
+            pixel_values_flat=pixel_values_flat,
             num_patches=image_num_patches,
+            resolve_bindings=resolve_bindings,
         )

         raise AssertionError("This line should be unreachable.")
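With the schema in place, the parser no longer routes tensors through the removed `_validate_pixel_values`; it passes `resolve_bindings` so the symbolic "h"/"w" dimensions are pinned to `vision_config.image_size` when the schema validates the field. A minimal usage sketch of the new construction path (the tensor sizes are made-up values, and the exact error raised on a mismatch depends on TensorSchema's internals):

```python
import torch

# Made-up sizes: a 448x448 vision tower; two images contributing
# 4 and 5 flat patches respectively (3 + 1 thumbnail, 4 + 1 thumbnail).
inputs = InternVLImagePixelInputs(
    type="pixel_values",
    pixel_values_flat=torch.zeros(9, 3, 448, 448),
    num_patches=torch.tensor([4, 5]),
    resolve_bindings={"h": 448, "w": 448},  # pins "h"/"w" to image_size
)

# Passing torch.zeros(9, 3, 336, 448) instead should now be rejected by
# the schema ("h" is bound to 448), replacing the check that the deleted
# _validate_pixel_values helper used to perform by hand.
```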
@@ -1225,11 +1214,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
             return None

         if video_embeds is not None:
-            if not isinstance(video_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of video embeddings. "
-                                 f"Got type: {type(video_embeds)}")
-
-            return InternVLImageEmbeddingInputs(
+            return InternVLVideoEmbeddingInputs(
                 type="video_embeds",
                 data=flatten_bn(video_embeds),
             )
@@ -1250,12 +1235,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
                                              concat=True)
         video_num_patches = flatten_bn(video_num_patches, concat=True)
+        expected_h = expected_w = self.config.vision_config.image_size
+        resolve_bindings = {"h": expected_h, "w": expected_w}

         return InternVLVideoPixelInputs(
             type="pixel_values_videos",
-            pixel_values_flat=self._validate_pixel_values(
-                pixel_values_flat_video),
+            pixel_values_flat=pixel_values_flat_video,
             num_patches=video_num_patches,
+            resolve_bindings=resolve_bindings,
         )

         raise AssertionError("This line should be unreachable.")
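The video path gets the same treatment as the image path. A quick sketch of the intended failure mode, again with made-up sizes and assuming a shape mismatch surfaces as a ValueError (the exact exception type comes from TensorSchema, so treat this as an assumption):

```python
import torch

# Made-up example: one 8-frame video at the expected 448x448 resolution.
InternVLVideoPixelInputs(
    type="pixel_values_videos",
    pixel_values_flat=torch.zeros(8, 3, 448, 448),
    num_patches=torch.tensor([8]),
    resolve_bindings={"h": 448, "w": 448},
)

# Frames at the wrong resolution should now fail validation up front:
try:
    InternVLVideoPixelInputs(
        type="pixel_values_videos",
        pixel_values_flat=torch.zeros(8, 3, 224, 224),
        num_patches=torch.tensor([8]),
        resolve_bindings={"h": 448, "w": 448},
    )
except ValueError:
    pass  # expected: "h"/"w" are bound to 448 but the frames are 224x224
```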