[Model] Support nested structures for TensorSchema (#26212)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent d3d649efec
commit 44ea85137a
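Context for the diff below: this commit lets a TensorSchema field hold nested sequences of tensors (e.g. list[list[torch.Tensor]]) and still be shape-checked recursively. A minimal sketch of the usage pattern, based only on the code in this diff; the class name VideoPixelInputs is a stand-in mirroring HCXVisionVideoPixelInputs and the new test_tensor_schema_double_nested_tensors test:

    # Sketch only: assumes the TensorSchema/TensorShape API shown in the diff.
    from typing import Annotated, Literal

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class VideoPixelInputs(TensorSchema):
        """
        Dimensions:
            - n: Number of videos
            - f: Number of frames (dynamic)
            - g: Number of grids (dynamic)
            - h/w: Height/width
        """
        type: Literal["pixel_values_videos"] = "pixel_values_videos"
        pixel_values_videos: Annotated[
            list[list[torch.Tensor]],
            TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]


    x = torch.rand(4, 3, 32, 32)
    y = torch.rand(2, 3, 32, 32)
    # Each inner list is one video; frame (f) and grid (g) counts may differ
    # per item, but the trailing (3, h, w) dims must stay consistent.
    VideoPixelInputs(pixel_values_videos=[[x, y, x], [y], [x, y]])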
@@ -6,37 +6,39 @@ import torch
 
 from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
 from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
+from vllm.model_executor.models.hyperclovax_vision import (
+    HCXVisionVideoPixelInputs)
 from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
 
 
 def test_tensor_schema_valid_tensor():
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 3, 32, 32),
+        pixel_values=torch.randn(16, 64, 3, 32, 32),
         image_sizes=torch.randint(0, 256, (16, 2)),
     )
 
 
 def test_tensor_schema_optional_fields():
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 3, 32, 32),
+        pixel_values=torch.randn(16, 64, 3, 32, 32),
         image_sizes=None,
     )
 
-    Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), )
+    Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32))
 
 
 def test_tensor_schema_constant_dim_failure():
     with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
         Phi3VImagePixelInputs(
-            data=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
+            pixel_values=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_invalid_types_in_list():
-    with pytest.raises(ValueError, match="is not a torch.Tensor"):
+    with pytest.raises(TypeError, match="is not one of the expected types"):
         Phi3VImagePixelInputs(
-            data=[
+            pixel_values=[
                 torch.randn(64, 3, 32, 32),
                 "not_a_tensor",
                 torch.randn(64, 3, 32, 32),
@@ -48,27 +50,28 @@ def test_tensor_schema_invalid_types_in_list():
 def test_tensor_schema_rank_mismatch():
     with pytest.raises(ValueError, match="has rank 3 but expected 5"):
         Phi3VImagePixelInputs(
-            data=torch.randn(16, 64, 3),
+            pixel_values=torch.randn(16, 64, 3),
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_missing_required_field():
-    with pytest.raises(ValueError, match="Required field 'data' is missing"):
+    with pytest.raises(ValueError,
+                       match="Required field 'pixel_values' is missing"):
         Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), )
 
 
 def test_tensor_schema_symbolic_dim_mismatch():
     with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
         Phi3VImagePixelInputs(
-            data=torch.randn(12, 64, 3, 32, 32),
+            pixel_values=torch.randn(12, 64, 3, 32, 32),
            image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_list_tensor_valid():
     Phi3VImagePixelInputs(
-        data=[torch.randn(64, 3, 32, 32) for _ in range(16)],
+        pixel_values=[torch.randn(64, 3, 32, 32) for _ in range(16)],
         image_sizes=torch.randint(0, 256, (16, 2)),
     )
 
@@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid():
 def test_tensor_schema_variable_patch_counts_valid():
     # Each image has a different number of patches (p)
     # Each tensor has shape (p, 3, 32, 32)
-    data = [
-        torch.randn(16, 3, 32, 32),  # p = 16
-        torch.randn(32, 3, 32, 32),  # p = 32
-        torch.randn(64, 3, 32, 32),  # p = 64
-    ]
-    image_sizes = torch.randint(0, 256, (3, 2))  # bn = 3
     Phi3VImagePixelInputs(
-        data=data,
-        image_sizes=image_sizes,
+        pixel_values=[
+            torch.randn(16, 3, 32, 32),  # p = 16
+            torch.randn(32, 3, 32, 32),  # p = 32
+            torch.randn(64, 3, 32, 32),  # p = 64
+        ],
+        image_sizes=torch.randint(0, 256, (3, 2)),  # bn = 3
     )
 
 
 def test_tensor_schema_tuple_tensor_valid():
     Phi3VImagePixelInputs(
-        data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
+        pixel_values=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
         image_sizes=torch.randint(0, 256, (16, 2)),
     )
 
 
+def test_tensor_schema_double_nested_tensors():
+    x = torch.rand(4, 3, 32, 32)
+    y = torch.rand(2, 3, 32, 32)
+
+    HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y]))
+
+
 def test_tensor_schema_inconsistent_shapes_in_list():
     with pytest.raises(ValueError, match="contains inconsistent shapes"):
         Phi3VImagePixelInputs(
-            data=[torch.randn(64, 3, 32, 32),
-                  torch.randn(64, 3, 16, 16)] +
-            [torch.randn(64, 3, 32, 32) for _ in range(14)],
+            pixel_values=[
+                torch.randn(64, 3, 32, 32),
+                torch.randn(64, 3, 16, 16),
+                *(torch.randn(64, 3, 32, 32) for _ in range(14)),
+            ],
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_empty_list():
-    with pytest.raises(ValueError, match="is an empty list"):
+    with pytest.raises(ValueError, match="is an empty sequence"):
         Phi3VImagePixelInputs(
-            data=[],
+            pixel_values=[],
             image_sizes=torch.randint(0, 256, (0, 2)),
         )
 
@@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check():
     # This should NOT raise, because validation is turned off
     # This would normally fail (dim[2] should be 3, not 4)
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 4, 32, 32),
+        pixel_values=torch.randn(16, 64, 4, 32, 32),
         image_sizes=torch.randint(0, 256, (16, 2)),
         validate=False,
     )
 
 
 def test_tensor_schema_with_valid_resolve_binding_dims():
-    data = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
+    pixel_values = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
     image_sizes = torch.randint(0, 256, (16, 2))
 
     Phi3VImagePixelInputs(
-        data=data,
+        pixel_values=pixel_values,
         image_sizes=image_sizes,
         resolve_bindings={
             "h": 336,
@@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims():
 
 
 def test_tensor_schema_with_invalid_resolve_binding_dims():
-    data = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
+    pixel_values = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
     image_sizes = torch.randint(0, 256, (16, 2))
 
     # Should raise because 'h' and 'w' don't match resolve bindings
     with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
         Phi3VImagePixelInputs(
-            data=data,
+            pixel_values=pixel_values,
             image_sizes=image_sizes,
             resolve_bindings={
                 "h": 336,
@@ -29,7 +29,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Annotated, Any, Callable, Literal, Optional, Union, override
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -1170,7 +1170,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
                     "video.height override (%d) exceeds model's "
                     "maximum height (%d), will be ignored",
                     overrides.height, height)
-                height = min(height, override.height)
+                height = min(height, overrides.height)
 
         video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         video_items = []
@@ -2,27 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # copied from : https://github.com/huggingface/transformers
 import ast
-import sys
 from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from itertools import chain
-from typing import Any, Literal, Optional, TypedDict, Union
+from itertools import accumulate
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import PIL
+from einops import rearrange
 from PIL import Image
-
-if sys.version_info >= (3, 11):
-    import typing
-    Unpack = typing.Unpack
-else:
-    import typing_extensions
-    Unpack = typing_extensions.Unpack
-
 import torch
 import torch.nn as nn
-from einops import rearrange
 from timm.layers import LayerNorm, LayerNorm2d
 from timm.models.regnet import RegStage
 from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
@@ -42,11 +31,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    maybe_prefix)
 from .vision import get_vision_encoder_info
 
 EOT = "<|endofturn|>"
@@ -69,28 +60,42 @@ def get_num_combined_frames(
     return num_canvases + (leftover_frames > 0)
 
 
-class HCXVisionMultimodalPixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values_images: list[torch.Tensor]
+class HCXVisionImagePixelInputs(TensorSchema):
     """
-    Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
-
-    Note that `height` or `width` may be different per batch and image,
-    in which case the data is passed as a list instead of a batched tensor.
+    Dimensions:
+        - n: Number of images
+        - g: Number of grids
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
     """
-    image_sizes_images: list[tuple[Union[int, float]]]
-    """
-    Shape: `[(height, width), ...]`
-    """
-    vision_query_lengths_images: list[Union[int, float]]
-    pixel_values_videos: list[tuple[Union[int, float]]]
-    """
-    Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
-    """
-    vision_query_lengths_videos: list[Union[int, float]]
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values_images: Annotated[
+        list[torch.Tensor],
+        TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})]
+    image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]
 
 
-HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs]
+HCXVisionImageInputs = HCXVisionImagePixelInputs
+
+
+class HCXVisionVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - n: Number of videos
+        - f: Number of frames
+        - g: Number of grids
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+    """
+    type: Literal["pixel_values_videos"] = "pixel_values_videos"
+    pixel_values_videos: Annotated[
+        list[list[torch.Tensor]],
+        TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]
+
+
+HCXVisionVideoInputs = HCXVisionVideoPixelInputs
 
 
 class HCXVisionProcessingInfo(BaseProcessingInfo):
@@ -191,27 +196,9 @@ class HCXVisionMultiModalProcessor(
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-
-        def replace_multimodal_token(
-            token_ids: torch.Tensor,
-            target_token: int,
-            repeats: list[int],
-        ):
-            output = list[int]()
-            _repeats_idx = 0
-            for token_id in token_ids:
-                if token_id == target_token:
-                    output += [token_id.item()] * repeats[_repeats_idx]
-                    _repeats_idx += 1
-                else:
-                    output += [token_id.item()]
-
-            return torch.tensor(output, device=token_ids.device)
-
         for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
-            if video_arr.dtype == np.uint8:
-                continue
-            mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
+            if video_arr.dtype != np.uint8:
+                mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
 
         processed_outputs = self.info.ctx.call_hf_processor(
             hf_processor=self.info.get_hf_processor(**mm_kwargs),
@@ -223,20 +210,16 @@ class HCXVisionMultiModalProcessor(
         )  # text-only
 
         if len(mm_data) > 0:
+            images = mm_data.get("images")
+            videos = mm_data.get("videos")
 
-            # batchify input as a single item
-            images = mm_data.get("images", None)
-            batched_images = None if images is None else [images]
-
-            # list of video in single conversation
-            videos = mm_data.get("videos", None)
-            batched_videos = None if videos is None else [videos]
-
             _processed_outputs = self.info.ctx.call_hf_processor(
                 hf_processor=self.info.get_hf_processor(**mm_kwargs),
                 data=dict(
                     text=None,
-                    images=batched_images,
-                    videos=batched_videos,
+                    images=None if images is None else [images],
+                    videos=None if videos is None else [videos],
                 ),
             )  # mm-only
 
@@ -246,51 +229,43 @@ class HCXVisionMultiModalProcessor(
                     _processed_outputs[k] = v[0]
 
             if images:
-                tokenizer = self.info.get_tokenizer()
-                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
-                processed_outputs["input_ids"] = torch.stack([
-                    replace_multimodal_token(
-                        token_ids=_input_ids,
-                        target_token=image_token_id,
-                        repeats=_processed_outputs[
-                            "vision_query_lengths_images"],
-                    ) for _input_ids in processed_outputs["input_ids"]
-                ],
-                                                             dim=0)
+                _processed_outputs["image_sizes_images"] = torch.tensor(
+                    _processed_outputs["image_sizes_images"])
+                _processed_outputs[
+                    "vision_query_lengths_images"] = torch.tensor(
+                        _processed_outputs["vision_query_lengths_images"])
 
             if videos:
-                _num_per_videos = [
-                    get_num_combined_frames(len(video)) for video in videos
+                _idx_per_video = [
+                    0, *accumulate(
+                        get_num_combined_frames(len(video))
+                        for video in videos)
                 ]
                 _processed_outputs["pixel_values_videos"] = [
                     _processed_outputs["pixel_values_videos"]
-                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(len(videos))
+                    [_idx_per_video[i]:_idx_per_video[i + 1]]
+                    for i in range(len(videos))
                 ]
                 _processed_outputs["vision_query_lengths_videos"] = [
-                    _processed_outputs["vision_query_lengths_videos"]
-                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(len(videos))
+                    torch.tensor(
+                        _processed_outputs["vision_query_lengths_videos"]
+                        [_idx_per_video[i]:_idx_per_video[i + 1]])
+                    for i in range(len(videos))
                 ]
 
-                tokenizer = self.info.get_tokenizer()
-                video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
-                processed_outputs["input_ids"] = torch.stack([
-                    replace_multimodal_token(
-                        token_ids=_input_ids,
-                        target_token=video_token_id,
-                        repeats=[
-                            sum(lens) for lens in
-                            _processed_outputs["vision_query_lengths_videos"]
-                        ],
-                    ) for _input_ids in processed_outputs["input_ids"]
-                ],
-                                                             dim=0)
-
             processed_outputs.update(_processed_outputs)
 
         return processed_outputs
 
+    def _hf_processor_applies_updates(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        tokenization_kwargs: Mapping[str, object],
+    ) -> bool:
+        return False
+
     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
@@ -311,11 +286,11 @@ class HCXVisionMultiModalProcessor(
             out_item = out_mm_kwargs[modality][item_idx]
 
             if modality == "image":
-                lens = out_item["vision_query_lengths_images"].data
+                lens = out_item["vision_query_lengths_images"].data.tolist()
                 num_tokens = self.info.get_num_image_tokens(
                     vision_query_length=lens)
             elif modality == "video":
-                lens = out_item["vision_query_lengths_videos"].data
+                lens = out_item["vision_query_lengths_videos"].data.tolist()
                 num_tokens = self.info.get_num_video_tokens(
                     vision_query_length=lens)
             else:
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
# image
|
||||
pixel_values_images=MultiModalFieldConfig.batched("image"),
|
||||
image_sizes_images=MultiModalFieldConfig.batched("image"),
|
||||
vision_query_lengths_images=MultiModalFieldConfig.batched("image"),
|
||||
num_queries_vis_abstractors_images=MultiModalFieldConfig.batched(
|
||||
"image"),
|
||||
num_queries_vis_abstractors_slow_images=MultiModalFieldConfig.
|
||||
batched("image"),
|
||||
first_last_frames_slows_images=MultiModalFieldConfig.batched(
|
||||
"image"),
|
||||
# video
|
||||
pixel_values_videos=MultiModalFieldConfig.batched("video"),
|
||||
image_sizes_videos=MultiModalFieldConfig.batched("video"),
|
||||
vision_query_lengths_videos=MultiModalFieldConfig.batched("video"),
|
||||
num_queries_vis_abstractors_videos=MultiModalFieldConfig.batched(
|
||||
"video"),
|
||||
num_queries_vis_abstractors_slow_videos=MultiModalFieldConfig.
|
||||
batched("video"),
|
||||
first_last_frames_slows_videos=MultiModalFieldConfig.batched(
|
||||
"video"),
|
||||
)
|
||||
|
||||
|
||||
@@ -617,6 +577,7 @@ class HCXVisionCAbstractor(nn.Module):
                                         info=_build_hcxvision_hf_info,
                                         dummy_inputs=HCXVisionDummyInputsBuilder)
 class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -692,55 +653,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
         raise ValueError("Only image or video modality is supported")
 
+    def _parse_and_validate_image_input(
+        self,
+        **kwargs: object,
+    ) -> Optional[HCXVisionImageInputs]:
+        pixel_values_images = kwargs.pop("pixel_values_images", None)
+
+        if pixel_values_images is None:
+            return None
+
+        image_sizes_images = kwargs.pop("image_sizes_images")
+
+        return HCXVisionImagePixelInputs(
+            pixel_values_images=pixel_values_images,
+            image_sizes_images=image_sizes_images,
+        )
+
+    def _parse_and_validate_video_input(
+        self,
+        **kwargs: object,
+    ) -> Optional[HCXVisionVideoInputs]:
+        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
+
+        if pixel_values_videos is None:
+            return None
+
+        return HCXVisionVideoPixelInputs(
+            pixel_values_videos=pixel_values_videos, )
+
+    def _process_image_input(
+        self,
+        image_input: HCXVisionImageInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        return self.forward_images(
+            pixel_values_images=image_input["pixel_values_images"],
+            image_sizes_images=image_input["image_sizes_images"],
+        )
+
+    def _process_video_input(
+        self,
+        video_input: HCXVisionVideoInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        return self.forward_videos(
+            pixel_values_videos=video_input["pixel_values_videos"], )
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = {}
+
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if (input_key == "pixel_values_images"
+                    and "images" not in modalities):
+                modalities["images"] = self._parse_and_validate_image_input(
+                    **kwargs)
+            if (input_key == "pixel_values_videos"
+                    and "videos" not in modalities):
+                modalities["videos"] = self._parse_and_validate_video_input(
+                    **kwargs)
+
+        return modalities
+
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
     def get_multimodal_embeddings(
         self,
-        **kwargs: Unpack[HCXVisionMultimodalInputs],
+        **kwargs: object,
     ) -> MultiModalEmbeddings:
+        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not modalities:
+            return []
 
-        multimodal_embeddings = list()
-        if kwargs.get("pixel_values_images") is not None:
-            for _pixel_values_images, _image_sizes_images in zip(
-                    kwargs["pixel_values_images"],
-                    kwargs["image_sizes_images"]):
-                _pixel_values_images = _pixel_values_images.unsqueeze(dim=0)
-                _image_sizes_images = _image_sizes_images.unsqueeze(dim=0)
-                _len_pixel_values_images = [
-                    len(pixel_value) for pixel_value in _pixel_values_images
-                ]
-                if isinstance(_image_sizes_images, torch.Tensor):
-                    _image_sizes_images = _image_sizes_images.detach().cpu(
-                    ).tolist()
-                _multimodal_embeddings_images = self.forward_images(
-                    pixel_values_images=_pixel_values_images,
-                    image_sizes_images=_image_sizes_images,
-                    len_pixel_values_images=_len_pixel_values_images,
-                )
-                _multimodal_embeddings_images = torch.cat(
-                    _multimodal_embeddings_images, dim=0)
-                multimodal_embeddings.append(_multimodal_embeddings_images)
+        # The result multimodal_embeddings is tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
+        for modality in modalities:
+            if modality == "images":
+                image_input = modalities["images"]
+                vision_embeddings = self._process_image_input(image_input)
+                multimodal_embeddings += vision_embeddings
+            if modality == "videos":
+                video_input = modalities["videos"]
+                video_embeddings = self._process_video_input(video_input)
+                multimodal_embeddings += video_embeddings
 
-        if kwargs.get("pixel_values_videos") is not None:
-            for _pixel_values_videos, _vision_query_lengths_videos in zip(
-                    kwargs["pixel_values_videos"],
-                    kwargs["vision_query_lengths_videos"]):
-                _len_pixel_values_videos = [
-                    len(_vision_query_lengths)
-                    for _vision_query_lengths in _vision_query_lengths_videos
-                ]
-                _c, _w, _h = _pixel_values_videos.shape[-3:]
-                _pixel_values_videos = _pixel_values_videos.reshape(
-                    sum(_len_pixel_values_videos), -1, _c, _w,
-                    _h).unsqueeze(dim=0)
-                _multimodal_embeddings_videos = self.forward_videos(
-                    pixel_values_videos=_pixel_values_videos,
-                    len_pixel_values_videos=_len_pixel_values_videos,
-                )
-                _multimodal_embeddings_videos = torch.cat(
-                    _multimodal_embeddings_videos, dim=0)
-                multimodal_embeddings.append(_multimodal_embeddings_videos)
         return multimodal_embeddings
 
     def forward(
@@ -762,28 +762,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def forward_images(
         self,
-        pixel_values_images: list[list[torch.FloatTensor]],
-        image_sizes_images: list[list[tuple[int, int]]],
-        len_pixel_values_images: list[int],
-    ) -> list[list[torch.Tensor]]:
-        if sum(len_pixel_values_images) == 0:
-            return None
-
-        concat_pixel_values_images = torch.cat(list(
-            chain(*pixel_values_images)),
-                                               dim=0)
+        pixel_values_images: list[torch.Tensor],
+        image_sizes_images: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)
 
         visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
         image_forward_outs = self.vision_model(
-            concat_pixel_values_images)[:, visual_token_idx:]
+            pixel_values_image_flat)[:, visual_token_idx:]
 
         image_forward_outs = image_forward_outs.to(
             dtype=self.mm_projector.dtype)
         image_forward_outs = self.mm_projector(image_forward_outs)  # b (h w) d
 
-        split_sizes = [
-            pixel_value.shape[0] for pixel_value in chain(*pixel_values_images)
-        ]
+        split_sizes = [len(item) for item in pixel_values_images]
         image_forward_outs = torch.split(image_forward_outs,
                                          split_sizes,
                                          dim=0)
@@ -791,10 +783,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # newline for anyres postprocessing
         image_features = anyres_postprocessing(
             image_forward_outs=image_forward_outs,
-            image_sizes=[
-                image_size for image_sizes in image_sizes_images
-                for image_size in image_sizes
-            ],
+            image_sizes=image_sizes_images.tolist(),
             num_queries_vis_abstractor=self.config.
             num_queries_vis_abstractor_image,
             unpad=self.config.unpad,
@@ -803,26 +792,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             image_newline=self.image_newline,
             possible_resolutions=self.config.possible_resolutions,
         )
-        return image_features
+
+        return tuple(image_features)
 
     def forward_videos(
         self,
-        pixel_values_videos: list[list[torch.FloatTensor]],
-        len_pixel_values_videos: list[int],
-    ) -> list[torch.Tensor]:
-
-        len_video_grids = sum(len_pixel_values_videos)
-        if len_video_grids == 0:
-            return None
-
-        # Run Vision Model
-        concat_pixel_values_videos = torch.cat(list(
-            chain(*pixel_values_videos)),
-                                               dim=0)
+        pixel_values_videos: list[list[torch.Tensor]],
+    ) -> tuple[torch.Tensor, ...]:
+        pixel_values_videos_flat = flatten_bn(
+            [frame for frames in pixel_values_videos for frame in frames],
+            concat=True,
+        )
 
         visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
         video_forward_outs = self.vision_model(
-            concat_pixel_values_videos)[:, visual_token_idx:]
+            pixel_values_videos_flat)[:, visual_token_idx:]
 
         video_forward_outs = video_forward_outs.to(
             dtype=self.mm_projector.dtype)
@@ -905,7 +889,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         ) == 0, f"target_features is not empty!! {target_features}"
         assert len(video_groups) == len(video_features)
 
-        return video_features
+        feats_per_video = [len(video) for video in pixel_values_videos]
+        idxs_per_video = [0, *accumulate(feats_per_video)]
+        return tuple(
+            torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]])
+            for i in range(len(feats_per_video)))
 
     def _prepare_multimodal_kwargs(self, **kwargs: object):
         output = defaultdict(list)
@@ -1111,15 +1099,15 @@ def reshape_and_unpad_image_features(
 
 
 def anyres_postprocessing(
-    image_forward_outs: list[torch.FloatTensor],
+    image_forward_outs: list[torch.Tensor],
    image_sizes: list[list[int]],
     possible_resolutions: list[tuple[int, int]],
     patch_size: int,
     grid_size: int,
-    image_newline: torch.FloatTensor,
+    image_newline: torch.Tensor,
     num_queries_vis_abstractor: int = -1,
     unpad: bool = False,
-) -> list[torch.FloatTensor]:
+) -> list[torch.Tensor]:
     height = width = grid_size // patch_size
 
     if num_queries_vis_abstractor > 0:
@@ -1147,26 +1135,5 @@ def anyres_postprocessing(
             (image_feature, image_newline[None].to(image_feature.device)),
             dim=0)
         new_image_features.append(image_feature)
-    image_features = new_image_features
-    return image_features
-
-
-def resize_image(
-    image: Union[np.ndarray, PIL.Image.Image],
-    max_side: int = 378,
-) -> np.ndarray:
-    image_arr = image
-    if isinstance(image, np.ndarray):
-        image = Image.fromarray(image)
-
-    width, height = image.size
-    cur_max_size = max(width, height)
-    if cur_max_size <= max_side:
-        return image_arr
-
-    scale = max_side / cur_max_size
-    width = int(width * scale)
-    height = int(height * scale)
-    image = image.resize((width, height), Image.LANCZOS)
-    image_arr = np.array(image)
-    return image_arr
+
+    return new_image_features
@@ -109,7 +109,7 @@ class Phi3VImagePixelInputs(TensorSchema):
     type: Literal["pixel_values", "image_embeds"] = "pixel_values"
 
     # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
-    data: Annotated[
+    pixel_values: Annotated[
         Union[torch.Tensor, list[torch.Tensor]],
         TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
                     ),  # 'p' may vary across items
@@ -594,7 +594,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
         if pixel_values is not None:
             return Phi3VImagePixelInputs(
                 type="pixel_values",
-                data=flatten_bn(pixel_values),
+                pixel_values=flatten_bn(pixel_values),
                 image_sizes=flatten_bn(image_sizes, concat=True),
                 resolve_bindings={
                     "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
@@ -628,7 +628,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
             )
 
         assert self.vision_embed_tokens is not None
-        image_embeds = self.vision_embed_tokens(image_input["data"],
+        image_embeds = self.vision_embed_tokens(image_input["pixel_values"],
                                                 image_input["image_sizes"])
 
         return image_embeds
@@ -94,34 +94,63 @@ class TensorSchema:
                 return False
         return True
 
-    def _validate_nested_tensors(
-        self,
-        value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
-        field_name: str,
-        expected_shape: tuple[Union[int, str], ...],
-        dynamic_dims: set[str],
+    def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
+        if not idxs:
+            return ""
+
+        return str(list(idxs))
+
+    def _validate_field(
+        self,
+        value: object,
+        field_name: str,
+        expected_shape: tuple[Union[int, str], ...],
+        dynamic_dims: set[str],
+        leading_idxs: tuple[int, ...] = (),
     ) -> tuple[int, ...]:
-        """Validate a list/tuple of tensors and return the actual shape."""
+        """Validate a field and return the actual shape."""
+        if isinstance(value, (int, float)):
+            return ()  # Scalar
+        if isinstance(value, torch.Tensor):
+            return value.shape
+
+        if not isinstance(value, (list, tuple)):
+            raise TypeError(
+                f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
+                f"one of the expected types: int, float, Tensor, list, tuple. "
+                f"Got: {type(value)}")
+
+        if len(value) == 0:
+            raise ValueError(f"{field_name}{self._fmt_indexer(leading_idxs)} "
+                             f"is an empty sequence")
+
         # Ensure all tensors in the list have the same
         # shape, besides dynamic dimensions
-        first = value[0]
         for i, v in enumerate(value):
-            if not isinstance(v, torch.Tensor):
-                raise ValueError(f"{field_name}[{i}] is not a "
-                                 f"torch.Tensor")
-            if not self._match_shape_with_dynamic(
-                    v.shape,
-                    first.shape,
+            shape = self._validate_field(
+                v,
+                field_name,
+                expected_shape[1:],
+                dynamic_dims,
+                leading_idxs=leading_idxs + (i, ),
+            )
+
+            if i == 0:
+                first_shape = shape
+            elif not self._match_shape_with_dynamic(
+                    shape,
+                    first_shape,
                     expected_shape,
                     dynamic_dims,
             ):
-                raise ValueError(f"{field_name} contains inconsistent "
-                                 f"shapes: {first.shape} vs {v.shape} "
-                                 f"at index {i}")
+                raise ValueError(
+                    f"{field_name}{self._fmt_indexer(leading_idxs)} "
+                    f"contains inconsistent shapes: {first_shape} "
+                    f"(index 0) vs {shape} (index {i})")
 
         # Treat the list as a stacked tensor:
         # shape = (len(list), *tensor.shape)
-        return (len(value), ) + first.shape
+        return (len(value), ) + first_shape
 
     def _validate_tensor_shape_expected(
         self,
@@ -187,36 +216,12 @@ class TensorSchema:
         for arg in args:
             if isinstance(arg, TensorShape):
                 expected_shape = arg.resolve(**self._resolve_bindings)
-                if isinstance(value, (list, tuple)):
-                    # list/tuple of Tensors → shape = (len(value), ...)
-                    if value and isinstance(value[0], torch.Tensor):
-                        actual_shape = self._validate_nested_tensors(
-                            value, field_name, expected_shape,
-                            arg.dynamic_dims)
-                    elif value:
-                        # list/tuple of scalars → shape = (len(value),)
-                        actual_shape = (len(value), )
-                    else:
-                        raise ValueError(
-                            f"{field_name} is an empty list")
-
-                # Tensor → shape = tensor.shape
-                elif isinstance(value, torch.Tensor):
-                    actual_shape = value.shape
-
-                # Otherwise, it's an unsupported type
-                else:
-                    type_names = []
-                    for arg in args:
-                        if hasattr(arg, "__name__"):
-                            type_names.append(str(arg.__name__))
-                        else:
-                            type_names.append(str(arg))
-
-                    expected_types = ", ".join(type_names)
-                    raise ValueError(
-                        f"{field_name} is not one of the expected "
-                        f"types: {expected_types}")
+                actual_shape = self._validate_field(
+                    value,
+                    field_name,
+                    expected_shape,
+                    arg.dynamic_dims,
+                )
 
                 self._validate_tensor_shape_expected(
                     actual_shape, expected_shape, field_name,