Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 18:06:03 +08:00)
[Model] Support nested structures for TensorSchema (#26212)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent d3d649efec
commit 44ea85137a
@@ -6,37 +6,39 @@ import torch
 
 from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
 from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
+from vllm.model_executor.models.hyperclovax_vision import (
+    HCXVisionVideoPixelInputs)
 from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
 
 
 def test_tensor_schema_valid_tensor():
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 3, 32, 32),
+        pixel_values=torch.randn(16, 64, 3, 32, 32),
         image_sizes=torch.randint(0, 256, (16, 2)),
     )
 
 
 def test_tensor_schema_optional_fields():
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 3, 32, 32),
+        pixel_values=torch.randn(16, 64, 3, 32, 32),
         image_sizes=None,
     )
 
-    Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), )
+    Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32))
 
 
 def test_tensor_schema_constant_dim_failure():
     with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
         Phi3VImagePixelInputs(
-            data=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
+            pixel_values=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_invalid_types_in_list():
-    with pytest.raises(ValueError, match="is not a torch.Tensor"):
+    with pytest.raises(TypeError, match="is not one of the expected types"):
         Phi3VImagePixelInputs(
-            data=[
+            pixel_values=[
                 torch.randn(64, 3, 32, 32),
                 "not_a_tensor",
                 torch.randn(64, 3, 32, 32),
@@ -48,27 +50,28 @@ def test_tensor_schema_invalid_types_in_list():
 def test_tensor_schema_rank_mismatch():
     with pytest.raises(ValueError, match="has rank 3 but expected 5"):
         Phi3VImagePixelInputs(
-            data=torch.randn(16, 64, 3),
+            pixel_values=torch.randn(16, 64, 3),
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_missing_required_field():
-    with pytest.raises(ValueError, match="Required field 'data' is missing"):
+    with pytest.raises(ValueError,
+                       match="Required field 'pixel_values' is missing"):
         Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), )
 
 
 def test_tensor_schema_symbolic_dim_mismatch():
     with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
         Phi3VImagePixelInputs(
-            data=torch.randn(12, 64, 3, 32, 32),
+            pixel_values=torch.randn(12, 64, 3, 32, 32),
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_list_tensor_valid():
     Phi3VImagePixelInputs(
-        data=[torch.randn(64, 3, 32, 32) for _ in range(16)],
+        pixel_values=[torch.randn(64, 3, 32, 32) for _ in range(16)],
         image_sizes=torch.randint(0, 256, (16, 2)),
     )
 
@@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid():
 def test_tensor_schema_variable_patch_counts_valid():
     # Each image has a different number of patches (p)
     # Each tensor has shape (p, 3, 32, 32)
-    data = [
-        torch.randn(16, 3, 32, 32),  # p = 16
-        torch.randn(32, 3, 32, 32),  # p = 32
-        torch.randn(64, 3, 32, 32),  # p = 64
-    ]
-    image_sizes = torch.randint(0, 256, (3, 2))  # bn = 3
     Phi3VImagePixelInputs(
-        data=data,
-        image_sizes=image_sizes,
+        pixel_values=[
+            torch.randn(16, 3, 32, 32),  # p = 16
+            torch.randn(32, 3, 32, 32),  # p = 32
+            torch.randn(64, 3, 32, 32),  # p = 64
+        ],
+        image_sizes=torch.randint(0, 256, (3, 2)),  # bn = 3
     )
 
 
 def test_tensor_schema_tuple_tensor_valid():
     Phi3VImagePixelInputs(
-        data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
+        pixel_values=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
         image_sizes=torch.randint(0, 256, (16, 2)),
     )
 
 
+def test_tensor_schema_double_nested_tensors():
+    x = torch.rand(4, 3, 32, 32)
+    y = torch.rand(2, 3, 32, 32)
+
+    HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y]))
+
+
 def test_tensor_schema_inconsistent_shapes_in_list():
     with pytest.raises(ValueError, match="contains inconsistent shapes"):
         Phi3VImagePixelInputs(
-            data=[torch.randn(64, 3, 32, 32),
-                  torch.randn(64, 3, 16, 16)] +
-            [torch.randn(64, 3, 32, 32) for _ in range(14)],
+            pixel_values=[
+                torch.randn(64, 3, 32, 32),
+                torch.randn(64, 3, 16, 16),
+                *(torch.randn(64, 3, 32, 32) for _ in range(14)),
+            ],
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_empty_list():
-    with pytest.raises(ValueError, match="is an empty list"):
+    with pytest.raises(ValueError, match="is an empty sequence"):
         Phi3VImagePixelInputs(
-            data=[],
+            pixel_values=[],
             image_sizes=torch.randint(0, 256, (0, 2)),
         )
 
@@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check():
     # This should NOT raise, because validation is turned off
     # This would normally fail (dim[2] should be 3, not 4)
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 4, 32, 32),
+        pixel_values=torch.randn(16, 64, 4, 32, 32),
         image_sizes=torch.randint(0, 256, (16, 2)),
         validate=False,
     )
 
 
 def test_tensor_schema_with_valid_resolve_binding_dims():
-    data = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
+    pixel_values = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
     image_sizes = torch.randint(0, 256, (16, 2))
 
     Phi3VImagePixelInputs(
-        data=data,
+        pixel_values=pixel_values,
         image_sizes=image_sizes,
         resolve_bindings={
             "h": 336,
@@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims():
 
 
 def test_tensor_schema_with_invalid_resolve_binding_dims():
-    data = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
+    pixel_values = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
     image_sizes = torch.randint(0, 256, (16, 2))
 
     # Should raise because 'h' and 'w' don't match resolve bindings
     with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
         Phi3VImagePixelInputs(
-            data=data,
+            pixel_values=pixel_values,
             image_sizes=image_sizes,
             resolve_bindings={
                 "h": 336,
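For reference, a rough sketch of how a doubly nested field is declared and validated under the new support, modeled on the HCXVisionVideoPixelInputs class and the test_tensor_schema_double_nested_tensors case in this diff; ExampleVideoInputs and its concrete shapes are illustrative only, the imported helpers are the ones the diff uses:

    from typing import Annotated, Literal

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class ExampleVideoInputs(TensorSchema):
        """
        Dimensions:
            - n: Number of videos
            - f: Number of frames (dynamic, may differ per video)
            - g: Number of grids (dynamic, may differ per frame)
            - h/w: Height / width
        """
        type: Literal["pixel_values_videos"] = "pixel_values_videos"
        pixel_values_videos: Annotated[
            list[list[torch.Tensor]],
            TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]


    x = torch.rand(4, 3, 32, 32)  # one frame: (g, c, h, w)
    y = torch.rand(2, 3, 32, 32)

    # The outer sequence supplies "n", each inner sequence supplies "f", and
    # each tensor supplies (g, 3, h, w); validation recurses level by level.
    ExampleVideoInputs(pixel_values_videos=[[x, y, x], [y], [x, y]])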
@@ -29,7 +29,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Annotated, Any, Callable, Literal, Optional, Union, override
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -1170,7 +1170,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
                 "video.height override (%d) exceeds model's "
                 "maximum height (%d), will be ignored",
                 overrides.height, height)
-            height = min(height, override.height)
+            height = min(height, overrides.height)
 
         video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         video_items = []
@@ -2,27 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # copied from : https://github.com/huggingface/transformers
 import ast
-import sys
 from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from itertools import chain
-from typing import Any, Literal, Optional, TypedDict, Union
+from itertools import accumulate
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
-import PIL
-from einops import rearrange
-from PIL import Image
-
-if sys.version_info >= (3, 11):
-    import typing
-    Unpack = typing.Unpack
-else:
-    import typing_extensions
-    Unpack = typing_extensions.Unpack
-
 import torch
 import torch.nn as nn
+from einops import rearrange
 from timm.layers import LayerNorm, LayerNorm2d
 from timm.models.regnet import RegStage
 from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
@@ -42,11 +31,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    maybe_prefix)
 from .vision import get_vision_encoder_info
 
 EOT = "<|endofturn|>"
@@ -69,28 +60,42 @@ def get_num_combined_frames(
     return num_canvases + (leftover_frames > 0)
 
 
-class HCXVisionMultimodalPixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values_images: list[torch.Tensor]
+class HCXVisionImagePixelInputs(TensorSchema):
     """
-    Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
-
-    Note that `height` or `width` may be different per batch and image,
-    in which case the data is passed as a list instead of a batched tensor.
-    """
-    image_sizes_images: list[tuple[Union[int, float]]]
-    """
-    Shape: `[(height, width), ...]`
+    Dimensions:
+        - n: Number of images
+        - g: Number of grids
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
     """
-    vision_query_lengths_images: list[Union[int, float]]
-    pixel_values_videos: list[tuple[Union[int, float]]]
-    """
-    Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
-    """
-    vision_query_lengths_videos: list[Union[int, float]]
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values_images: Annotated[
+        list[torch.Tensor],
+        TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})]
+    image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]
 
 
-HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs]
+HCXVisionImageInputs = HCXVisionImagePixelInputs
+
+
+class HCXVisionVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - n: Number of videos
+        - f: Number of frames
+        - g: Number of grids
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+    """
+    type: Literal["pixel_values_videos"] = "pixel_values_videos"
+    pixel_values_videos: Annotated[
+        list[list[torch.Tensor]],
+        TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]
+
+
+HCXVisionVideoInputs = HCXVisionVideoPixelInputs
 
 
 class HCXVisionProcessingInfo(BaseProcessingInfo):
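Under the image schema above, each image contributes one (g, 3, h, w) grid tensor, with g free to vary per image via dynamic_dims={"g"}, while image_sizes_images collapses to a single (n, 2) tensor. A rough construction sketch; the grid counts and sizes below are made up for illustration:

    import torch

    from vllm.model_executor.models.hyperclovax_vision import (
        HCXVisionImagePixelInputs)

    HCXVisionImagePixelInputs(
        pixel_values_images=[
            torch.rand(5, 3, 378, 378),  # image 0: g = 5 grids
            torch.rand(2, 3, 378, 378),  # image 1: g = 2 grids
        ],
        image_sizes_images=torch.tensor([[768, 1024], [512, 512]]),  # (n, 2)
    )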
@@ -191,27 +196,9 @@ class HCXVisionMultiModalProcessor(
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
     ) -> BatchFeature:
 
-        def replace_multimodal_token(
-            token_ids: torch.Tensor,
-            target_token: int,
-            repeats: list[int],
-        ):
-            output = list[int]()
-            _repeats_idx = 0
-            for token_id in token_ids:
-                if token_id == target_token:
-                    output += [token_id.item()] * repeats[_repeats_idx]
-                    _repeats_idx += 1
-                else:
-                    output += [token_id.item()]
-
-            return torch.tensor(output, device=token_ids.device)
-
         for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
-            if video_arr.dtype == np.uint8:
-                continue
-            mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
+            if video_arr.dtype != np.uint8:
+                mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
 
         processed_outputs = self.info.ctx.call_hf_processor(
             hf_processor=self.info.get_hf_processor(**mm_kwargs),
@@ -223,20 +210,16 @@ class HCXVisionMultiModalProcessor(
         )  # text-only
 
         if len(mm_data) > 0:
+            images = mm_data.get("images")
+            videos = mm_data.get("videos")
+
             # batchify input as a single item
-            images = mm_data.get("images", None)
-            batched_images = None if images is None else [images]
-
-            # list of video in single conversation
-            videos = mm_data.get("videos", None)
-            batched_videos = None if videos is None else [videos]
-
             _processed_outputs = self.info.ctx.call_hf_processor(
                 hf_processor=self.info.get_hf_processor(**mm_kwargs),
                 data=dict(
                     text=None,
-                    images=batched_images,
-                    videos=batched_videos,
+                    images=None if images is None else [images],
+                    videos=None if videos is None else [videos],
                 ),
             )  # mm-only
 
@@ -246,51 +229,43 @@ class HCXVisionMultiModalProcessor(
                 _processed_outputs[k] = v[0]
 
             if images:
-                tokenizer = self.info.get_tokenizer()
-                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
-                processed_outputs["input_ids"] = torch.stack([
-                    replace_multimodal_token(
-                        token_ids=_input_ids,
-                        target_token=image_token_id,
-                        repeats=_processed_outputs[
-                            "vision_query_lengths_images"],
-                    ) for _input_ids in processed_outputs["input_ids"]
-                ],
-                    dim=0)
+                _processed_outputs["image_sizes_images"] = torch.tensor(
+                    _processed_outputs["image_sizes_images"])
+                _processed_outputs[
+                    "vision_query_lengths_images"] = torch.tensor(
+                        _processed_outputs["vision_query_lengths_images"])
 
             if videos:
-                _num_per_videos = [
-                    get_num_combined_frames(len(video)) for video in videos
+                _idx_per_video = [
+                    0, *accumulate(
+                        get_num_combined_frames(len(video))
+                        for video in videos)
                 ]
                 _processed_outputs["pixel_values_videos"] = [
                     _processed_outputs["pixel_values_videos"]
-                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(len(videos))
+                    [_idx_per_video[i]:_idx_per_video[i + 1]]
+                    for i in range(len(videos))
                 ]
                 _processed_outputs["vision_query_lengths_videos"] = [
-                    _processed_outputs["vision_query_lengths_videos"]
-                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(len(videos))
+                    torch.tensor(
+                        _processed_outputs["vision_query_lengths_videos"]
+                        [_idx_per_video[i]:_idx_per_video[i + 1]])
+                    for i in range(len(videos))
                 ]
 
-                tokenizer = self.info.get_tokenizer()
-                video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
-                processed_outputs["input_ids"] = torch.stack([
-                    replace_multimodal_token(
-                        token_ids=_input_ids,
-                        target_token=video_token_id,
-                        repeats=[
-                            sum(lens) for lens in
-                            _processed_outputs["vision_query_lengths_videos"]
-                        ],
-                    ) for _input_ids in processed_outputs["input_ids"]
-                ],
-                    dim=0)
-
             processed_outputs.update(_processed_outputs)
 
         return processed_outputs
 
+    def _hf_processor_applies_updates(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        tokenization_kwargs: Mapping[str, object],
+    ) -> bool:
+        return False
+
     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
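The slicing rewrite above swaps the repeated sum(_num_per_videos[:i]) calls for a prefix-sum index built with itertools.accumulate. A standalone sketch of the pattern, with made-up counts and items:

    from itertools import accumulate

    counts = [2, 1, 3]                     # e.g. combined frames per video
    idx = [0, *accumulate(counts)]         # [0, 2, 3, 6]
    flat = ["a", "b", "c", "d", "e", "f"]  # flattened per-canvas outputs

    per_video = [flat[idx[i]:idx[i + 1]] for i in range(len(counts))]
    assert per_video == [["a", "b"], ["c"], ["d", "e", "f"]]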
@@ -311,11 +286,11 @@ class HCXVisionMultiModalProcessor(
             out_item = out_mm_kwargs[modality][item_idx]
 
             if modality == "image":
-                lens = out_item["vision_query_lengths_images"].data
+                lens = out_item["vision_query_lengths_images"].data.tolist()
                 num_tokens = self.info.get_num_image_tokens(
                     vision_query_length=lens)
             elif modality == "video":
-                lens = out_item["vision_query_lengths_videos"].data
+                lens = out_item["vision_query_lengths_videos"].data.tolist()
                 num_tokens = self.info.get_num_video_tokens(
                     vision_query_length=lens)
             else:
@@ -343,26 +318,11 @@ class HCXVisionMultiModalProcessor(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            # image
             pixel_values_images=MultiModalFieldConfig.batched("image"),
             image_sizes_images=MultiModalFieldConfig.batched("image"),
             vision_query_lengths_images=MultiModalFieldConfig.batched("image"),
-            num_queries_vis_abstractors_images=MultiModalFieldConfig.batched(
-                "image"),
-            num_queries_vis_abstractors_slow_images=MultiModalFieldConfig.
-            batched("image"),
-            first_last_frames_slows_images=MultiModalFieldConfig.batched(
-                "image"),
-            # video
             pixel_values_videos=MultiModalFieldConfig.batched("video"),
-            image_sizes_videos=MultiModalFieldConfig.batched("video"),
             vision_query_lengths_videos=MultiModalFieldConfig.batched("video"),
-            num_queries_vis_abstractors_videos=MultiModalFieldConfig.batched(
-                "video"),
-            num_queries_vis_abstractors_slow_videos=MultiModalFieldConfig.
-            batched("video"),
-            first_last_frames_slows_videos=MultiModalFieldConfig.batched(
-                "video"),
         )
 
 
@@ -617,6 +577,7 @@ class HCXVisionCAbstractor(nn.Module):
     info=_build_hcxvision_hf_info,
     dummy_inputs=HCXVisionDummyInputsBuilder)
 class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -692,55 +653,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
         raise ValueError("Only image or video modality is supported")
 
+    def _parse_and_validate_image_input(
+        self,
+        **kwargs: object,
+    ) -> Optional[HCXVisionImageInputs]:
+        pixel_values_images = kwargs.pop("pixel_values_images", None)
+
+        if pixel_values_images is None:
+            return None
+
+        image_sizes_images = kwargs.pop("image_sizes_images")
+
+        return HCXVisionImagePixelInputs(
+            pixel_values_images=pixel_values_images,
+            image_sizes_images=image_sizes_images,
+        )
+
+    def _parse_and_validate_video_input(
+        self,
+        **kwargs: object,
+    ) -> Optional[HCXVisionVideoInputs]:
+        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
+
+        if pixel_values_videos is None:
+            return None
+
+        return HCXVisionVideoPixelInputs(
+            pixel_values_videos=pixel_values_videos, )
+
+    def _process_image_input(
+        self,
+        image_input: HCXVisionImageInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        return self.forward_images(
+            pixel_values_images=image_input["pixel_values_images"],
+            image_sizes_images=image_input["image_sizes_images"],
+        )
+
+    def _process_video_input(
+        self,
+        video_input: HCXVisionVideoInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        return self.forward_videos(
+            pixel_values_videos=video_input["pixel_values_videos"], )
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = {}
+
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if (input_key == "pixel_values_images"
+                    and "images" not in modalities):
+                modalities["images"] = self._parse_and_validate_image_input(
+                    **kwargs)
+            if (input_key == "pixel_values_videos"
+                    and "videos" not in modalities):
+                modalities["videos"] = self._parse_and_validate_video_input(
+                    **kwargs)
+
+        return modalities
+
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
     def get_multimodal_embeddings(
         self,
-        **kwargs: Unpack[HCXVisionMultimodalInputs],
+        **kwargs: object,
     ) -> MultiModalEmbeddings:
+        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not modalities:
+            return []
 
-        multimodal_embeddings = list()
-        if kwargs.get("pixel_values_images") is not None:
-            for _pixel_values_images, _image_sizes_images in zip(
-                    kwargs["pixel_values_images"],
-                    kwargs["image_sizes_images"]):
-                _pixel_values_images = _pixel_values_images.unsqueeze(dim=0)
-                _image_sizes_images = _image_sizes_images.unsqueeze(dim=0)
-                _len_pixel_values_images = [
-                    len(pixel_value) for pixel_value in _pixel_values_images
-                ]
-                if isinstance(_image_sizes_images, torch.Tensor):
-                    _image_sizes_images = _image_sizes_images.detach().cpu(
-                    ).tolist()
-                _multimodal_embeddings_images = self.forward_images(
-                    pixel_values_images=_pixel_values_images,
-                    image_sizes_images=_image_sizes_images,
-                    len_pixel_values_images=_len_pixel_values_images,
-                )
-                _multimodal_embeddings_images = torch.cat(
-                    _multimodal_embeddings_images, dim=0)
-                multimodal_embeddings.append(_multimodal_embeddings_images)
-
-        if kwargs.get("pixel_values_videos") is not None:
-            for _pixel_values_videos, _vision_query_lengths_videos in zip(
-                    kwargs["pixel_values_videos"],
-                    kwargs["vision_query_lengths_videos"]):
-                _len_pixel_values_videos = [
-                    len(_vision_query_lengths)
-                    for _vision_query_lengths in _vision_query_lengths_videos
-                ]
-                _c, _w, _h = _pixel_values_videos.shape[-3:]
-                _pixel_values_videos = _pixel_values_videos.reshape(
-                    sum(_len_pixel_values_videos), -1, _c, _w,
-                    _h).unsqueeze(dim=0)
-                _multimodal_embeddings_videos = self.forward_videos(
-                    pixel_values_videos=_pixel_values_videos,
-                    len_pixel_values_videos=_len_pixel_values_videos,
-                )
-                _multimodal_embeddings_videos = torch.cat(
-                    _multimodal_embeddings_videos, dim=0)
-                multimodal_embeddings.append(_multimodal_embeddings_videos)
+        # The result multimodal_embeddings is tuple of tensors, with each
+        # tensor correspoending to a multimodal data item (image or video).
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
+        for modality in modalities:
+            if modality == "images":
+                image_input = modalities["images"]
+                vision_embeddings = self._process_image_input(image_input)
+                multimodal_embeddings += vision_embeddings
+            if modality == "videos":
+                video_input = modalities["videos"]
+                video_embeddings = self._process_video_input(video_input)
+                multimodal_embeddings += video_embeddings
+
         return multimodal_embeddings
 
     def forward(
@@ -762,28 +762,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def forward_images(
         self,
-        pixel_values_images: list[list[torch.FloatTensor]],
-        image_sizes_images: list[list[tuple[int, int]]],
-        len_pixel_values_images: list[int],
-    ) -> list[list[torch.Tensor]]:
-        if sum(len_pixel_values_images) == 0:
-            return None
-
-        concat_pixel_values_images = torch.cat(list(
-            chain(*pixel_values_images)),
-                                               dim=0)
+        pixel_values_images: list[torch.Tensor],
+        image_sizes_images: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)
 
         visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
         image_forward_outs = self.vision_model(
-            concat_pixel_values_images)[:, visual_token_idx:]
+            pixel_values_image_flat)[:, visual_token_idx:]
 
         image_forward_outs = image_forward_outs.to(
             dtype=self.mm_projector.dtype)
         image_forward_outs = self.mm_projector(image_forward_outs)  # b (h w) d
 
-        split_sizes = [
-            pixel_value.shape[0] for pixel_value in chain(*pixel_values_images)
-        ]
+        split_sizes = [len(item) for item in pixel_values_images]
         image_forward_outs = torch.split(image_forward_outs,
                                          split_sizes,
                                          dim=0)
@@ -791,10 +783,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # newline for anyres postprocessing
         image_features = anyres_postprocessing(
             image_forward_outs=image_forward_outs,
-            image_sizes=[
-                image_size for image_sizes in image_sizes_images
-                for image_size in image_sizes
-            ],
+            image_sizes=image_sizes_images.tolist(),
             num_queries_vis_abstractor=self.config.
             num_queries_vis_abstractor_image,
             unpad=self.config.unpad,
@@ -803,26 +792,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             image_newline=self.image_newline,
             possible_resolutions=self.config.possible_resolutions,
         )
-        return image_features
+
+        return tuple(image_features)
 
     def forward_videos(
         self,
-        pixel_values_videos: list[list[torch.FloatTensor]],
-        len_pixel_values_videos: list[int],
-    ) -> list[torch.Tensor]:
-
-        len_video_grids = sum(len_pixel_values_videos)
-        if len_video_grids == 0:
-            return None
-
-        # Run Vision Model
-        concat_pixel_values_videos = torch.cat(list(
-            chain(*pixel_values_videos)),
-                                               dim=0)
+        pixel_values_videos: list[list[torch.Tensor]],
+    ) -> tuple[torch.Tensor, ...]:
+        pixel_values_videos_flat = flatten_bn(
+            [frame for frames in pixel_values_videos for frame in frames],
+            concat=True,
+        )
 
         visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
         video_forward_outs = self.vision_model(
-            concat_pixel_values_videos)[:, visual_token_idx:]
+            pixel_values_videos_flat)[:, visual_token_idx:]
 
         video_forward_outs = video_forward_outs.to(
             dtype=self.mm_projector.dtype)
@@ -905,7 +889,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         ) == 0, f"target_features is not empty!! {target_features}"
         assert len(video_groups) == len(video_features)
 
-        return video_features
+        feats_per_video = [len(video) for video in pixel_values_videos]
+        idxs_per_video = [0, *accumulate(feats_per_video)]
+        return tuple(
+            torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]])
+            for i in range(len(feats_per_video)))
 
     def _prepare_multimodal_kwargs(self, **kwargs: object):
         output = defaultdict(list)
@@ -1111,15 +1099,15 @@ def reshape_and_unpad_image_features(
 
 
 def anyres_postprocessing(
-    image_forward_outs: list[torch.FloatTensor],
+    image_forward_outs: list[torch.Tensor],
     image_sizes: list[list[int]],
     possible_resolutions: list[tuple[int, int]],
     patch_size: int,
    grid_size: int,
-    image_newline: torch.FloatTensor,
+    image_newline: torch.Tensor,
    num_queries_vis_abstractor: int = -1,
    unpad: bool = False,
-) -> list[torch.FloatTensor]:
+) -> list[torch.Tensor]:
    height = width = grid_size // patch_size
 
    if num_queries_vis_abstractor > 0:
@@ -1147,26 +1135,5 @@ def anyres_postprocessing(
             (image_feature, image_newline[None].to(image_feature.device)),
             dim=0)
         new_image_features.append(image_feature)
-    image_features = new_image_features
-    return image_features
-
-
-def resize_image(
-    image: Union[np.ndarray, PIL.Image.Image],
-    max_side: int = 378,
-) -> np.ndarray:
-    image_arr = image
-    if isinstance(image, np.ndarray):
-        image = Image.fromarray(image)
-
-    width, height = image.size
-    cur_max_size = max(width, height)
-    if cur_max_size <= max_side:
-        return image_arr
-
-    scale = max_side / cur_max_size
-    width = int(width * scale)
-    height = int(height * scale)
-    image = image.resize((width, height), Image.LANCZOS)
-    image_arr = np.array(image)
-    return image_arr
+
+    return new_image_features
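In forward_images/forward_videos above, flatten_bn(..., concat=True) from .utils takes over what the manual torch.cat(list(chain(*...))) used to do, and torch.split(..., split_sizes) restores the per-item grouping after the vision tower and projector. A rough standalone sketch of that round trip in plain torch; the shapes and the flatten(1) stand-in for vision_model + mm_projector are illustrative, not the real model:

    import torch

    per_image = [torch.rand(5, 3, 378, 378), torch.rand(2, 3, 378, 378)]

    flat = torch.cat(per_image, dim=0)      # what the concat step produces here
    feats = flat.flatten(1)                 # stand-in for the vision tower + projector
    split_sizes = [len(item) for item in per_image]
    per_image_feats = torch.split(feats, split_sizes, dim=0)

    assert [f.shape[0] for f in per_image_feats] == [5, 2]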
@@ -109,7 +109,7 @@ class Phi3VImagePixelInputs(TensorSchema):
     type: Literal["pixel_values", "image_embeds"] = "pixel_values"
 
     # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
-    data: Annotated[
+    pixel_values: Annotated[
         Union[torch.Tensor, list[torch.Tensor]],
         TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
                     ),  # 'p' may vary across items
@@ -594,7 +594,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         if pixel_values is not None:
             return Phi3VImagePixelInputs(
                 type="pixel_values",
-                data=flatten_bn(pixel_values),
+                pixel_values=flatten_bn(pixel_values),
                 image_sizes=flatten_bn(image_sizes, concat=True),
                 resolve_bindings={
                     "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
@@ -628,7 +628,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         )
 
         assert self.vision_embed_tokens is not None
-        image_embeds = self.vision_embed_tokens(image_input["data"],
+        image_embeds = self.vision_embed_tokens(image_input["pixel_values"],
                                                 image_input["image_sizes"])
 
         return image_embeds
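resolve_bindings pins the symbolic dims of the ("bn", "p", 3, "h", "w") shape to concrete values before validation, so the renamed pixel_values field must carry 336x336 crops here. A rough sketch mirroring the call above, using the same binding values as the tests and made-up batch/patch counts:

    import torch

    from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs

    Phi3VImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.randn(2, 17, 3, 336, 336),  # bn=2, p=17
        image_sizes=torch.randint(0, 1344, (2, 2)),
        resolve_bindings={"h": 336, "w": 336},  # 224x224 would fail dim[3]/dim[4]
    )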
@@ -94,34 +94,63 @@ class TensorSchema:
                 return False
         return True
 
-    def _validate_nested_tensors(
-        self,
-        value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
-        field_name: str,
-        expected_shape: tuple[Union[int, str], ...],
-        dynamic_dims: set[str],
+    def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
+        if not idxs:
+            return ""
+
+        return str(list(idxs))
+
+    def _validate_field(
+        self,
+        value: object,
+        field_name: str,
+        expected_shape: tuple[Union[int, str], ...],
+        dynamic_dims: set[str],
+        leading_idxs: tuple[int, ...] = (),
     ) -> tuple[int, ...]:
-        """Validate a list/tuple of tensors and return the actual shape."""
+        """Validate a field and return the actual shape."""
+        if isinstance(value, (int, float)):
+            return ()  # Scalar
+        if isinstance(value, torch.Tensor):
+            return value.shape
+
+        if not isinstance(value, (list, tuple)):
+            raise TypeError(
+                f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
+                f"one of the expected types: int, float, Tensor, list, tuple. "
+                f"Got: {type(value)}")
+
+        if len(value) == 0:
+            raise ValueError(f"{field_name}{self._fmt_indexer(leading_idxs)} "
+                             f"is an empty sequence")
+
         # Ensure all tensors in the list have the same
         # shape, besides dynamic dimensions
-        first = value[0]
         for i, v in enumerate(value):
-            if not isinstance(v, torch.Tensor):
-                raise ValueError(f"{field_name}[{i}] is not a "
-                                 f"torch.Tensor")
-            if not self._match_shape_with_dynamic(
-                    v.shape,
-                    first.shape,
+            shape = self._validate_field(
+                v,
+                field_name,
+                expected_shape[1:],
+                dynamic_dims,
+                leading_idxs=leading_idxs + (i, ),
+            )
+
+            if i == 0:
+                first_shape = shape
+            elif not self._match_shape_with_dynamic(
+                    shape,
+                    first_shape,
                     expected_shape,
                     dynamic_dims,
             ):
-                raise ValueError(f"{field_name} contains inconsistent "
-                                 f"shapes: {first.shape} vs {v.shape} "
-                                 f"at index {i}")
+                raise ValueError(
+                    f"{field_name}{self._fmt_indexer(leading_idxs)} "
+                    f"contains inconsistent shapes: {first_shape} "
+                    f"(index 0) vs {shape} (index {i})")
 
         # Treat the list as a stacked tensor:
         # shape = (len(list), *tensor.shape)
-        return (len(value), ) + first.shape
+        return (len(value), ) + first_shape
 
     def _validate_tensor_shape_expected(
         self,
@@ -187,36 +216,12 @@ class TensorSchema:
         for arg in args:
             if isinstance(arg, TensorShape):
                 expected_shape = arg.resolve(**self._resolve_bindings)
-                if isinstance(value, (list, tuple)):
-                    # list/tuple of Tensors → shape = (len(value), ...)
-                    if value and isinstance(value[0], torch.Tensor):
-                        actual_shape = self._validate_nested_tensors(
-                            value, field_name, expected_shape,
-                            arg.dynamic_dims)
-                    elif value:
-                        # list/tuple of scalars → shape = (len(value),)
-                        actual_shape = (len(value), )
-                    else:
-                        raise ValueError(
-                            f"{field_name} is an empty list")
-
-                # Tensor → shape = tensor.shape
-                elif isinstance(value, torch.Tensor):
-                    actual_shape = value.shape
-
-                # Otherwise, it's an unsupported type
-                else:
-                    type_names = []
-                    for arg in args:
-                        if hasattr(arg, "__name__"):
-                            type_names.append(str(arg.__name__))
-                        else:
-                            type_names.append(str(arg))
-
-                    expected_types = ", ".join(type_names)
-                    raise ValueError(
-                        f"{field_name} is not one of the expected "
-                        f"types: {expected_types}")
+                actual_shape = self._validate_field(
+                    value,
+                    field_name,
+                    expected_shape,
+                    arg.dynamic_dims,
+                )
 
                 self._validate_tensor_shape_expected(
                     actual_shape, expected_shape, field_name,
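The net effect of _validate_field above is that an arbitrarily nested list/tuple is folded into a single shape, one leading dimension per nesting level, which is then checked against the TensorShape annotation. A much-simplified standalone model of that folding, ignoring symbolic names, dynamic_dims and the exact error messages:

    import torch


    def folded_shape(value) -> tuple[int, ...]:
        """Shape a nested list/tuple of tensors would stack to."""
        if isinstance(value, (int, float)):
            return ()
        if isinstance(value, torch.Tensor):
            return tuple(value.shape)
        if not isinstance(value, (list, tuple)) or len(value) == 0:
            raise ValueError("unsupported type or empty sequence")
        inner = [folded_shape(v) for v in value]
        if any(s != inner[0] for s in inner):
            raise ValueError("contains inconsistent shapes")
        return (len(value), *inner[0])


    frame = torch.rand(4, 3, 32, 32)
    assert folded_shape([[frame, frame], [frame, frame]]) == (2, 2, 4, 3, 32, 32)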