mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-29 02:24:47 +08:00
1829 lines
80 KiB
Python
1829 lines
80 KiB
Python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
import math
|
|
from typing import Callable, Optional
|
|
import numpy as np
|
|
import random
|
|
from PIL import Image
|
|
import albumentations as A
|
|
|
|
import einops
|
|
import torch
|
|
from torchvision import transforms as T
|
|
from torchvision.transforms import Compose
|
|
from torchvision.transforms.functional import InterpolationMode
|
|
|
|
from data_loading.conversation_sample import (
|
|
ImageMedia,
|
|
VideoFrameMedia,
|
|
)
|
|
|
|
# Per-channel normalization statistics, grouped as (mean, std) pairs per
# encoder family. The values are consumed by ``T.Normalize`` after the image
# has been converted to RGB and scaled to [0, 1].
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]

SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]

IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]

RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]


# Lookup table: vision model type -> (per-channel mean, per-channel std).
# Several model types intentionally share the same statistics lists.
pixel_statistics = {
    "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
    "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
}
|
|
|
|
|
|
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
|
|
# Copyright (c) 2023 OpenGVLab.
|
|
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Return the candidate grid whose aspect ratio is nearest to the image's.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Candidate ``(tiles_wide, tiles_high)`` grids, scanned in order.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The best candidate, or ``(1, 1)`` when ``target_ratios`` is empty.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float("inf")

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # On an exact tie, switch to the later candidate only when the image
        # covers more than half of the pixel area that candidate's grid offers.
        if (
            gap == smallest_gap
            and image_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]
        ):
            chosen = candidate

    return chosen
|
|
|
|
|
|
def find_closest_area_weighted_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    width: int,
    height: int,
    image_size: int,
):
    """Select a tile grid by jointly scoring area coverage and aspect-ratio match.

    Each candidate ``(tiles_wide, tiles_high)`` gets the score
    ``min(grid_pixels / image_pixels, 0.6) * min(r / a, a / r)`` where ``r`` is
    the candidate's aspect ratio and ``a`` the image's. The first candidate with
    the strictly highest score wins.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Candidate grids, scanned in order.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The highest-scoring candidate, or ``(1, 1)`` when there are none.
    """
    image_area = width * height
    winner = (1, 1)
    top_score = float("-inf")

    for candidate in target_ratios:
        candidate_aspect = candidate[0] / candidate[1]
        # Fraction of the image area the grid can represent, capped at 0.6.
        coverage = min(
            (candidate[0] * candidate[1] * image_size * image_size) / image_area, 0.6
        )
        # At most 1.0, reached when the two aspect ratios are equal.
        similarity = min(
            candidate_aspect / aspect_ratio, aspect_ratio / candidate_aspect
        )
        score = coverage * similarity
        if score > top_score:
            top_score, winner = score, candidate

    return winner
|
|
|
|
|
|
# Mike's optimized ToTensor.
|
|
def _fast_to_tensor(pic) -> torch.Tensor:
|
|
np_img = np.array(pic, copy=False)
|
|
img = torch.from_numpy(np_img)
|
|
img = img.permute(2, 0, 1) # HWC to CHW
|
|
fp_img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
|
fp_img.div_(255)
|
|
return fp_img
|
|
|
|
|
|
@dataclass
class ImageTilingParams:
    """Tiling decision for a single media item.

    Instances are returned by ``ImageTilingStrategy.compute_params`` and later
    passed to ``ImageTilingStrategy.apply_params``.
    """

    # The image or video frame these parameters were computed for.
    media: ImageMedia | VideoFrameMedia
    # How many tiles (output tensors) the media will be turned into.
    num_tiles: int
    # How many embeddings the media accounts for in total.
    num_embeddings: int
|
|
|
|
|
|
class ImageTilingStrategy(ABC):
    """
    Base class for image tiling strategies.
    A tiling strategy takes a list of media and returns a list of image tiling parameters.
    These can then be used to apply the tiling to the media.

    Subclasses must implement the `compute_params`, `apply_params` and `stack` methods.

    The `transform` method is a convenience method that computes the transformation
    parameters and applies the transformation to the media.
    """

    def transform(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        **kwargs,
    ) -> list[list[torch.Tensor]]:
        """
        Compute the transformation parameters and apply them to the media.

        Args:
            media_list: List of media to transform
            num_tokens_available: Number of tokens available for all media
            **kwargs: Extra options forwarded unchanged to both `compute_params`
                and `apply_params` (for example `data_augment`)

        Returns:
            One list of tensors per media item, in the order of `media_list`
        """
        # `kwargs` was previously referenced here without being declared in the
        # signature, which made every call raise NameError.
        transform_media_list = self.compute_params(
            media_list, num_tokens_available, **kwargs
        )
        return [
            self.apply_params(transform_media, **kwargs)
            for transform_media in transform_media_list
        ]

    @abstractmethod
    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None,
        max_num_tiles: int | None = None,
        **kwargs,
    ) -> list[ImageTilingParams]:
        """
        Compute the transformation parameters and the number of tokens to use for the media.

        Args:
            media_list: List of media to transform
            num_tokens_available: Number of tokens available for all media
            max_num_tiles: Maximum number of tiles allowed (optional, defaults to instance's max_num_tiles if not provided)

        Returns:
            list of transformation parameters with the media
        """
        ...

    @abstractmethod
    def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
        """
        Apply the transformation parameters to the media.

        Args:
            transform_media: The media to apply the transformation to

        Returns:
            list of transformed media tensors
        """
        ...

    @abstractmethod
    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
        """
        Stack the images into a single tensor.

        Args:
            images: List of image tensors to stack

        Returns:
            tuple of (stacked media, image sizes, vision cu lengths, vision max lengths)
        """
        ...
|
|
|
|
|
|
class _FixedSizeStrategy(ImageTilingStrategy):
    """
    Common base for strategies whose output tensors all share one pixel size.

    Stores the target size and embeddings-per-image count, and builds the
    per-image preprocessing callable kept in ``self._transform``.
    """

    def __init__(
        self,
        vision_model_type: str,
        target_width: int,
        target_height: int,
        embeddings_per_image: int,
    ):
        self._vision_model_type = vision_model_type
        self._target_width = target_width
        self._target_height = target_height
        self._embeddings_per_image = embeddings_per_image
        self._transform = self._build_transform(
            (target_width, target_height), vision_model_type
        )

    # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
    # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
    @staticmethod
    def _build_transform(target_size: tuple[int, int], vision_model_type: str):
        """
        Create the preprocessing callable for one vision model family.

        Args:
            target_size: ``(width, height)`` of the output in pixels.
            vision_model_type: Key selecting the pipeline and its normalization
                statistics, or an ``hf://`` identifier.

        Returns:
            A callable that takes an image and returns a normalized tensor.

        Raises:
            NotImplementedError: If no pipeline exists for ``vision_model_type``.
        """

        def to_rgb(img):
            return img.convert("RGB") if img.mode != "RGB" else img

        if vision_model_type in ("siglip", "internvit", "radio", "radio-g", "cradio-g"):
            mean, std = pixel_statistics[vision_model_type]
            # Convert to RGB first, then resize.
            return T.Compose(
                [
                    T.Lambda(to_rgb),
                    T.Resize(
                        (target_size[1], target_size[0]),
                        interpolation=InterpolationMode.BICUBIC,
                    ),
                    T.ToTensor(),
                    T.Normalize(mean=mean, std=std),
                ]
            )

        # From the official CLIP repo.
        if vision_model_type == "clip":
            mean, std = pixel_statistics[vision_model_type]
            # Here the resize happens before the RGB conversion.
            return Compose(
                [
                    T.Resize(
                        (target_size[1], target_size[0]),
                        interpolation=InterpolationMode.BICUBIC,
                    ),
                    T.Lambda(to_rgb),
                    T.ToTensor(),
                    T.Normalize(mean=mean, std=std),
                ]
            )

        if not vision_model_type.startswith("hf://"):
            raise NotImplementedError(
                f"image processing not defined for vision model {vision_model_type}"
            )

        from megatron.core.models.huggingface.module import get_hf_model_type

        model_type = get_hf_model_type(vision_model_type)
        if "siglip" not in model_type:
            raise NotImplementedError(
                f"image processing not defined for huggingface model {vision_model_type}"
            )

        from transformers.models.siglip.image_processing_siglip import (
            SiglipImageProcessor,
        )

        processor = SiglipImageProcessor(
            size={"height": target_size[1], "width": target_size[0]}
        )

        def hf_transform(x):
            x = x.convert("RGB") if x.mode != "RGB" else x
            x = processor(x, return_tensors="pt")
            return x["pixel_values"][0]

        return hf_transform

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
        """
        Batch equally sized image tensors.

        Returns:
            ``(stacked, sizes, None, None)`` where ``sizes`` is an int32 tensor
            holding ``(shape[1], shape[2])`` of every input; both leading
            entries are ``None`` when ``images`` is empty.
        """
        if not len(images) > 0:
            return None, None, None, None

        stacked = torch.stack(images)
        sizes = torch.tensor(
            [(img.shape[1], img.shape[2]) for img in images], dtype=torch.int32
        )
        return stacked, sizes, None, None
|
|
|
|
|
|
class NoTilingStrategy(_FixedSizeStrategy):
    """
    Strategy that never tiles: each media item is resized to the fixed target
    width and height and yields exactly one tensor.
    """

    def __init__(
        self,
        vision_model_type: str,
        target_width: int,
        target_height: int,
        embeddings_per_image: int,
    ):
        super().__init__(
            vision_model_type=vision_model_type,
            target_width=target_width,
            target_height=target_height,
            embeddings_per_image=embeddings_per_image,
        )

    def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
        """Run the fixed-size transform on the media and wrap the result in a list."""
        tensor = self._transform(transform_media.media.value)
        return [tensor]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: Optional[int] = None,
        max_num_tiles: int | None = None,
        **kwargs,
    ) -> list[ImageTilingParams]:
        """
        Assign one tile and the fixed embedding count to every media item.

        ``num_tokens_available`` and ``max_num_tiles`` are accepted for
        interface compatibility and are not used.
        """
        params = []
        for media in media_list:
            params.append(
                ImageTilingParams(
                    media=media,
                    num_tiles=1,
                    num_embeddings=self._embeddings_per_image,
                )
            )
        return params

    def __str__(self):
        return f"SimpleImageTransform(vision_model_type={self._vision_model_type}, num_tokens_per_image={self._embeddings_per_image})"
|
|
|
|
|
|
@dataclass
class ImageTilingParamsV1(ImageTilingParams):
    """Tiling parameters that additionally record the chosen tile grid."""

    # Grid as (tiles along the width, tiles along the height).
    tiling: tuple[int, int]
|
|
|
|
|
|
class ImageTilingStrategyV1(_FixedSizeStrategy):
    """Tiling image transformation.

    This transformation splits the image into a grid of tiles and applies the transformation to each tile.
    When ``use_thumbnail`` is set and the grid has more than one tile, a
    tile-sized thumbnail of the whole image is appended as an extra tile.
    """

    # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
    # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276

    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
    ):
        """
        Args:
            vision_model_type: Vision model type; selects the per-tile transform.
            tile_size: Side length of one square tile in pixels.
            use_thumbnail: Whether to append a thumbnail tile for multi-tile grids.
            min_num_tiles: Smallest number of grid tiles a candidate grid may have.
            max_num_tiles: Largest number of grid tiles a candidate grid may have.
            embeddings_per_tile: Number of embeddings produced per tile.
            find_closest_aspect_ratio_fn: Function used to choose the grid; called
                as ``fn(aspect_ratio, target_ratios, width, height, tile_size)``.
        """
        super().__init__(
            vision_model_type=vision_model_type,
            target_width=tile_size,
            target_height=tile_size,
            embeddings_per_image=embeddings_per_tile,
        )

        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn

        # Calculate all possible aspect ratios for each max_num_tiles.
        # Maps a tile limit to every (x, y) grid with
        # min_num_tiles <= x * y <= limit, ordered by tile count.
        self.target_ratios = {
            tile_limit: sorted(
                set(
                    (x, y)
                    for n in range(self._min_num_tiles, tile_limit + 1)
                    for x in range(1, n + 1)
                    for y in range(1, n + 1)
                    if x * y <= tile_limit and x * y >= self._min_num_tiles
                ),
                key=lambda grid: grid[0] * grid[1],
            )
            for tile_limit in range(self._min_num_tiles, self._max_num_tiles + 1)
        }

        # Pixel-level augmentation pipeline. It is stored under a private name:
        # assigning it to ``self.transform`` would shadow the inherited
        # ``ImageTilingStrategy.transform`` method on every instance.
        self._augment_transform = A.Compose([
            A.OneOf([
                A.GaussNoise(var_limit=(5.0, 30.0)),
                A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5)),
            ], p=0.3),
            A.OneOf([
                A.MedianBlur(blur_limit=5),
                A.GaussianBlur(blur_limit=5),
            ], p=0.2),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.5),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=15, val_shift_limit=15, p=0.3),
            A.ImageCompression(quality_lower=70, quality_upper=100, p=0.3),
        ])

    def apply_params(self, transform_media: ImageTilingParams, data_augment: bool = False, **kwargs) -> list[torch.Tensor]:
        """
        Resize the media to its tile grid, cut it into tiles and transform each tile.

        Args:
            transform_media: Parameters produced by ``compute_params``; must be
                an ``ImageTilingParamsV1``.
            data_augment: Whether to run the pixel-level augmentation pipeline
                on the image before tiling.

        Returns:
            One tensor per tile in row-major order, followed by the thumbnail
            tensor when one is added.
        """
        assert isinstance(transform_media, ImageTilingParamsV1)
        image = transform_media.media.value

        if data_augment:
            image = self._augment_transform(image=np.asarray(image))["image"]
            image = Image.fromarray(image)

        # calculate the target width and height
        grid_cols, grid_rows = transform_media.tiling[0], transform_media.tiling[1]
        target_width = self._tile_size * grid_cols
        target_height = self._tile_size * grid_rows
        blocks = grid_cols * grid_rows

        # resize the image
        resized_img = image.resize((target_width, target_height))
        tiles_per_row = target_width // self._tile_size
        processed_images = []
        for i in range(blocks):
            col, row = i % tiles_per_row, i // tiles_per_row
            box = (
                col * self._tile_size,
                row * self._tile_size,
                (col + 1) * self._tile_size,
                (row + 1) * self._tile_size,
            )
            # split the image
            processed_images.append(resized_img.crop(box))
        assert len(processed_images) == blocks
        if self._use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((self._tile_size, self._tile_size))
            processed_images.append(thumbnail_img)

        return [self._transform(img) for img in processed_images]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: Optional[int] = None,
        max_num_tiles: int | None = None,
        data_augment: bool = False,
        tiling_augment_prob: float = 0.4,
        **kwargs,
    ) -> list[ImageTilingParamsV1]:
        """
        Choose a tile grid for every media item.

        Args:
            media_list: Media to compute tilings for.
            num_tokens_available: Token budget used to cap the number of grid
                tiles; ``None`` means only the tile limits apply.
            max_num_tiles: Optional per-call tile limit; clamped to the
                instance's ``max_num_tiles``.
            data_augment: Whether images may have their tiling randomly
                perturbed via ``augment_tiling``.
            tiling_augment_prob: Probability of perturbing an image's tiling
                when ``data_augment`` is set.

        Returns:
            One ``ImageTilingParamsV1`` per media item.

        Raises:
            ValueError: If a media item is neither image nor video frame.
            KeyError: If the resulting tile limit is below ``min_num_tiles``.
        """
        # Use provided max_num_tiles or fall back to instance's max_num_tiles
        # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)

        if num_tokens_available is None:
            # No token budget given (the default): only the tile limit applies.
            # Previously this path raised TypeError on ``None // int``.
            max_num_tiles_to_use = effective_max_num_tiles
        else:
            max_num_tiles_to_use = min(
                num_tokens_available // self._embeddings_per_image, effective_max_num_tiles
            )

        target_ratios = self.target_ratios[max_num_tiles_to_use]

        params = []
        for media in media_list:
            if isinstance(media, ImageMedia):
                img_size = (media.width, media.height)
            elif isinstance(media, VideoFrameMedia):
                img_size = (media.video_width, media.video_height)
            else:
                raise ValueError(f"Unsupported media type: {type(media)}")

            # calculate the existing image aspect ratio
            aspect_ratio = img_size[0] / img_size[1]

            # find the closest aspect ratio to the target
            tiling = self._find_closest_aspect_ratio_fn(
                aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
            )
            if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob:
                tiling = self.augment_tiling(tiling)
            num_tiles = tiling[0] * tiling[1]
            if self._use_thumbnail and num_tiles != 1:
                # The thumbnail counts as one extra tile.
                num_tiles += 1

            params.append(
                ImageTilingParamsV1(
                    media=media,
                    num_tiles=num_tiles,
                    num_embeddings=num_tiles * self._embeddings_per_image,
                    tiling=tiling,
                )
            )

        return params

    def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]:
        """
        Randomly shrink or grow a tile grid by one along a single axis.

        With probability 0.65 one axis is decremented (never below 1); otherwise
        one axis is incremented as long as the grid stays within the instance's
        ``max_num_tiles``. The input is returned unchanged when no step is possible.
        """
        minus_prob = 0.65
        cols, rows = tiling[0], tiling[1]

        if random.random() < minus_prob:
            # Minus one
            if cols == 1 and rows == 1:
                return tiling
            if cols == 1:
                return (cols, rows - 1)
            if rows == 1:
                return (cols - 1, rows)
            if random.random() < 0.5:
                return (cols - 1, rows)
            return (cols, rows - 1)

        # Plus one
        if cols * rows >= self._max_num_tiles:
            return tiling
        wider = (cols + 1, rows)
        taller = (cols, rows + 1)
        wider_too_big = wider[0] * wider[1] > self._max_num_tiles
        taller_too_big = taller[0] * taller[1] > self._max_num_tiles
        if wider_too_big and taller_too_big:
            return tiling
        if wider_too_big:
            return taller
        if taller_too_big:
            return wider
        if random.random() < 0.5:
            return wider
        return taller

    def __str__(self):
        return f"TilingImageTransform(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, embeddings_per_tile={self._embeddings_per_image}, find_closest_aspect_ratio_fn={self._find_closest_aspect_ratio_fn})"
|
|
|
|
|
|
class TileDegradationStrategy(ImageTilingStrategy):
    """Strategy for tiling images and video frames, each with their own tiling strategy, while trying to match the
    number of tokens left in the sample by reducing the number of tiles if needed.
    """

    # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
    # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276

    def __init__(
        self,
        image_strategy: ImageTilingStrategy,
        video_frame_strategy: ImageTilingStrategy,
        embeddings_per_tile: int,
        max_num_tiles: int,
        tile_degradation_map: dict[int, int] | None = None,
    ):
        """
        Args:
            image_strategy: Strategy used for ``ImageMedia`` items.
            video_frame_strategy: Strategy used for ``VideoFrameMedia`` items.
            embeddings_per_tile: Number of embeddings one tile accounts for.
            max_num_tiles: Default per-image tile limit.
            tile_degradation_map: Maps a tile limit to the next lower limit to
                try when the token budget is exceeded. ``None`` selects
                ``{12: 8, 8: 6, 6: 4, 4: 2, 2: 1}``.
        """
        self._image_strategy = image_strategy
        self._video_frame_strategy = video_frame_strategy
        self._embeddings_per_tile = embeddings_per_tile
        self._max_num_tiles = max_num_tiles
        # Build the default per instance; a dict literal in the signature would
        # be one object shared (and mutable) across all instances.
        self._tile_degradation_map = (
            {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
            if tile_degradation_map is None
            else tile_degradation_map
        )

    def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
        """Dispatch to the image or video-frame strategy based on the media type."""
        if isinstance(transform_media.media, ImageMedia):
            return self._image_strategy.apply_params(transform_media, **kwargs)
        elif isinstance(transform_media.media, VideoFrameMedia):
            return self._video_frame_strategy.apply_params(transform_media, **kwargs)
        else:
            raise ValueError(f"Unsupported media type: {type(transform_media.media)}")

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        **kwargs,
    ) -> list[ImageTilingParams]:
        """
        Compute per-media tiling parameters, lowering the image tile limit
        until the total fits ``num_tokens_available`` or no lower limit exists.

        Args:
            media_list: Media to compute parameters for.
            num_tokens_available: Token budget for all media together; ``None``
                disables degradation.
            max_num_tiles: Optional per-call tile limit for images; defaults to
                the instance's ``max_num_tiles``.
            **kwargs: Forwarded to the wrapped strategies' ``compute_params``.

        Returns:
            One parameter object per media item, from the last attempted limit.

        Raises:
            ValueError: If a media item is neither image nor video frame.
        """
        # Use provided max_num_tiles or fall back to instance's max_num_tiles
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        max_num_tiles_to_use = effective_max_num_tiles
        degradation_map = self._tile_degradation_map
        # Video frames always get a single tile. This is kept in its own
        # variable so that it no longer overwrites the image tile limit, which
        # used to restrict later images to one tile and stop degradation early.
        video_max_num_tiles = 1

        while True:
            params = []
            total_tiles = 0
            saw_image = False
            for media in media_list:
                if isinstance(media, ImageMedia):
                    saw_image = True
                    media_params = self._image_strategy.compute_params(
                        [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs
                    )[0]
                elif isinstance(media, VideoFrameMedia):
                    media_params = self._video_frame_strategy.compute_params(
                        [media], video_max_num_tiles * self._embeddings_per_tile, video_max_num_tiles, **kwargs
                    )[0]
                else:
                    raise ValueError(f"Unsupported media type: {type(media)}")
                total_tiles += media_params.num_tiles
                params.append(media_params)

            if max_num_tiles_to_use == 1 or num_tokens_available is None:
                break
            if not saw_image:
                # Only images are affected by the tile limit; retrying with a
                # lower limit cannot change the result.
                break
            if total_tiles * self._embeddings_per_tile <= num_tokens_available:
                break
            if max_num_tiles_to_use not in degradation_map:
                # End of degradation
                break
            max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
        return params

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
        """Delegate stacking to the image strategy."""
        return self._image_strategy.stack(images)

    def __str__(self):
        return f"TileDegradationImageTransform(max_num_tiles={self._max_num_tiles}, image_transform={self._image_strategy}, video_frame_transform={self._video_frame_strategy})"
|
|
|
|
|
|
@dataclass
class DynamicResolutionParams(ImageTilingParams):
    """Tiling parameters for the dynamic-resolution strategy."""

    # Target grid as (patches along the width, patches along the height).
    patch_size: tuple[int, int]
|
|
|
|
|
|
class DynamicResolutionImageTilingStrategy(ImageTilingStrategy):
|
|
"""Preprocess an image with dynamic resolution for vision transformers.
|
|
|
|
This function resizes an image to optimize the number of patches while respecting
|
|
constraints on minimum/maximum patches, minimum side length, and compatibility
|
|
with pixel shuffle or convolution merging operations.
|
|
|
|
The algorithm works by:
|
|
1. Computing the initial patch grid size based on the image dimensions and patch_size
|
|
2. Scaling the patch grid to fit within the max_patches constraint
|
|
3. Ensuring the result has at least min_patches
|
|
4. Optionally enforcing a minimum side length constraint
|
|
5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility
|
|
6. Resizing the image to the computed target dimensions
|
|
|
|
Note:
|
|
The function preserves aspect ratio as much as possible while satisfying all constraints.
|
|
When constraints conflict (e.g., min_side vs max_patches), the function prioritizes
|
|
staying within max_patches while maximizing the image size.
|
|
|
|
Example:
>>> from PIL import Image
>>> img = Image.open("example.jpg") # 800x600 image
>>> strategy = DynamicResolutionImageTilingStrategy(vision_model_type="radio", min_num_patches=4, max_num_patches=64, patch_size=14, get_num_embeddings=lambda x, y: x * y * 2)
|
|
>>> params = strategy.compute_params([img])
|
|
>>> img_tensor = strategy.apply_params(params[0])
|
|
>>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14
|
|
"""
|
|
|
|
def __init__(
    self,
    vision_model_type: str,
    min_num_patches: int,
    patch_size: int,
    get_num_embeddings: Callable[[int, int], int],
    factor_max: float = 1.0,
    pixel_shuffle: bool = False,
    min_side: int | None = None,
    conv_merging: bool = False,
    use_thumbnail: bool = False,
    thumbnail_size: int = 448,
    thumbnail_area_threshold: float = 0.8,
    max_num_patches: int = 0,
    apply_data_augment: bool = False,
):
    """
    Args:
        vision_model_type: Vision model type; must contain "radio" and be a
            key of the pixel statistics table.
        min_num_patches: Minimum number of patches required.
        patch_size: Resolution step size (patch dimension) in pixels.
        get_num_embeddings: Function mapping a pixel size (width, height) to
            the number of embeddings.
        factor_max: Maximum scaling factor to apply. Defaults to 1.0.
        pixel_shuffle: Whether to round the patch grid for pixel shuffle
            compatibility. Defaults to False.
        min_side: Minimum side length in pixels, or None for no constraint.
            Defaults to None.
        conv_merging: Whether to round the patch grid for convolution merging
            compatibility. Defaults to False.
        use_thumbnail: Whether to add a thumbnail image when processing.
            Defaults to False.
        thumbnail_size: Side length of the square thumbnail. Defaults to 448.
        thumbnail_area_threshold: A thumbnail is added only when the resized
            image area is below this fraction of the thumbnail area.
            Defaults to 0.8 (80%).
        max_num_patches: Maximum number of patches allowed; values <= 0 mean
            no maximum. Defaults to 0.
        apply_data_augment: Whether to apply data augmentation to the image.
            Defaults to False.
    """
    assert "radio" in vision_model_type, (
        "Dynamic resolution is only supported for radio models"
    )

    def _to_rgb(img):
        return img.convert("RGB") if img.mode != "RGB" else img

    self._vision_model_type = vision_model_type

    # Patch-grid constraints.
    self._patch_size = patch_size
    self._min_num_patches = min_num_patches
    if max_num_patches > 0:
        self._max_num_patches = max_num_patches
    else:
        self._max_num_patches = float("inf")
    self._factor_max = factor_max
    self._min_side = min_side
    self._pixel_shuffle = pixel_shuffle
    self._conv_merging = conv_merging
    self._get_num_embeddings = get_num_embeddings

    # Thumbnail settings.
    self._use_thumbnail = use_thumbnail
    self._thumbnail_size = thumbnail_size
    self._thumbnail_area_threshold = thumbnail_area_threshold

    self._apply_data_augment = apply_data_augment

    # No resize step here: the image is resized to its target size before
    # this transform is applied.
    mean, std = pixel_statistics[self._vision_model_type]
    self._transform = T.Compose(
        [
            T.Lambda(_to_rgb),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]
    )
|
|
|
|
def apply_params(self, params: DynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
    """
    Resize the media to its computed patch grid and transform it to a tensor.

    Args:
        params: Parameters holding the media and its (width, height) patch grid.

    Returns:
        A list with the transformed resized image, followed by a transformed
        square thumbnail when thumbnails are enabled and the resized image
        area is below the configured fraction of the thumbnail area.
    """
    pixel_width = params.patch_size[0] * self._patch_size
    pixel_height = params.patch_size[1] * self._patch_size
    resized_img = params.media.value.resize((pixel_width, pixel_height))
    outputs = [resized_img]

    if self._use_thumbnail:
        thumb_side = self._thumbnail_size
        # Area of the resized image relative to the thumbnail's area.
        covered = (resized_img.size[0] * resized_img.size[1]) / (thumb_side * thumb_side)
        if covered < self._thumbnail_area_threshold:
            outputs.append(params.media.value.resize((thumb_side, thumb_side)))

    return [self._transform(img) for img in outputs]
|
|
|
|
def process_media(
    self,
    media: ImageMedia | VideoFrameMedia,
    num_tokens_available: int,
    data_augment: bool = False,
    tiling_augment_prob: float = 0.4,
) -> tuple[DynamicResolutionParams, int]:
    """Process a single media item and return its parameters.

    Args:
        media: The media item to process
        num_tokens_available: Number of tokens available for this media
        data_augment: Whether to apply data augmentation to the image. Defaults to False.
        tiling_augment_prob: Probability of perturbing the patch grid when both
            ``data_augment`` and the instance's ``apply_data_augment`` are set.

    Returns:
        A ``(params, token_count)`` tuple: the DynamicResolutionParams for the
        media and the number of patches it occupies (including the thumbnail's
        patches when a thumbnail is added).

    Raises:
        ValueError: If ``media`` is neither an image nor a video frame.
    """
    current_num_tokens_available = num_tokens_available
    if isinstance(media, ImageMedia):
        orig_width, orig_height = media.width, media.height
    elif isinstance(media, VideoFrameMedia):
        orig_width, orig_height = media.video_width, media.video_height
        # current_num_tokens_available = 1024 #TEMP: hack for video
    else:
        raise ValueError(f"Unsupported media type: {type(media)}")

    closest_patch_height = round(orig_height / self._patch_size + 0.5)
    closest_patch_width = round(orig_width / self._patch_size + 0.5)
    patches = closest_patch_height * closest_patch_width

    factor = min(math.sqrt(current_num_tokens_available / patches), self._factor_max)
    # Keep at least one patch per side. For very elongated images or tiny
    # budgets the floor can reach 0, which would make the divisions below
    # raise ZeroDivisionError or produce a zero-sized resize.
    target_patch_height = max(math.floor(factor * closest_patch_height), 1)
    target_patch_width = max(math.floor(factor * closest_patch_width), 1)

    # self._min_num_patches is only enforced when the token budget is larger
    # than it.
    if current_num_tokens_available > self._min_num_patches and target_patch_height * target_patch_width < self._min_num_patches:
        up_factor = math.sqrt(
            self._min_num_patches / (target_patch_height * target_patch_width)
        )
        target_patch_height = math.ceil(up_factor * target_patch_height)
        target_patch_width = math.ceil(up_factor * target_patch_width)

    if (
        self._min_side is not None
        and min(target_patch_width, target_patch_height) * self._patch_size
        < self._min_side
    ):
        if target_patch_width <= target_patch_height:
            up_factor = self._min_side / (target_patch_width * self._patch_size)
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)

            if new_patch_height * new_patch_width > current_num_tokens_available:
                # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
                if (
                    max(current_num_tokens_available // new_patch_width, 1)
                    * self._patch_size
                    < self._min_side
                ):
                    up_factor = math.sqrt(
                        current_num_tokens_available
                        / (target_patch_height * target_patch_width)
                    )
                    target_patch_height = math.floor(
                        up_factor * target_patch_height
                    )
                    target_patch_width = math.floor(
                        up_factor * target_patch_width
                    )
                    # NOTE(review): the two assignments below overwrite the
                    # rescaled values computed just above, whereas the
                    # symmetric height < width branch uses an `else:` here.
                    # Statement order is preserved from the original (whose
                    # indentation could not be verified) - confirm intent.
                    target_patch_width = new_patch_width
                    target_patch_height = max(
                        current_num_tokens_available // new_patch_width, 1
                    )
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width
        else:
            up_factor = self._min_side / (
                target_patch_height * self._patch_size
            )
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)

            if new_patch_height * new_patch_width > current_num_tokens_available:
                # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
                if (
                    max(current_num_tokens_available // new_patch_height, 1)
                    * self._patch_size
                    < self._min_side
                ):
                    up_factor = math.sqrt(
                        current_num_tokens_available
                        / (target_patch_height * target_patch_width)
                    )
                    target_patch_height = math.floor(
                        up_factor * target_patch_height
                    )
                    target_patch_width = math.floor(
                        up_factor * target_patch_width
                    )
                else:
                    target_patch_height = new_patch_height
                    target_patch_width = max(
                        current_num_tokens_available // new_patch_height, 1
                    )
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width

        # The rescaling above floors again; re-apply the one-patch minimum.
        target_patch_height = max(target_patch_height, 1)
        target_patch_width = max(target_patch_width, 1)

    # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
    # or by 4 when BOTH are enabled (two successive 2x reductions)
    if self._pixel_shuffle or self._conv_merging:
        required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2

        # Round up when the budget allows it, otherwise round down (but never
        # below one multiple of the divisor).
        rem_h = target_patch_height % required_divisor
        if rem_h != 0:
            inc_h = required_divisor - rem_h
            if (target_patch_height + inc_h) * target_patch_width <= current_num_tokens_available:
                target_patch_height += inc_h
            else:
                target_patch_height = max(required_divisor, target_patch_height - rem_h)

        rem_w = target_patch_width % required_divisor
        if rem_w != 0:
            inc_w = required_divisor - rem_w
            if target_patch_height * (target_patch_width + inc_w) <= current_num_tokens_available:
                target_patch_width += inc_w
            else:
                target_patch_width = max(required_divisor, target_patch_width - rem_w)

    if data_augment and self._apply_data_augment and random.random() < tiling_augment_prob:
        target_patch_width, target_patch_height = self.augment_resolution(target_patch_width, target_patch_height, current_num_tokens_available)

    #TEMP: hack for video
    # Video frames always use a fixed 32x32 patch grid, regardless of the
    # values computed above.
    if isinstance(media, VideoFrameMedia):
        target_patch_width = 32
        target_patch_height = 32

    # Calculate embeddings for the main dynamic resolution image
    num_embeddings = self._get_num_embeddings(
        target_patch_width * self._patch_size,
        target_patch_height * self._patch_size,
    )

    token_count = target_patch_width * target_patch_height

    # Add thumbnail embeddings if enabled and image area is below threshold
    num_tiles = 1  # Base dynamic resolution image
    if self._use_thumbnail:
        # Calculate areas
        resized_area = (target_patch_width * self._patch_size) * (target_patch_height * self._patch_size)
        thumbnail_area = self._thumbnail_size * self._thumbnail_size
        area_ratio = resized_area / thumbnail_area

        # Only add thumbnail if resized image area is less than threshold % of thumbnail area
        if area_ratio < self._thumbnail_area_threshold:
            num_tiles += 1  # Add 1 for thumbnail
            # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
            num_embeddings += self._get_num_embeddings(self._thumbnail_size, self._thumbnail_size)
            # Patches per thumbnail side, squared. Without the parentheses
            # `a // b * a // b` is evaluated as `((a // b) * a) // b`, which
            # is wrong whenever thumbnail_size is not a multiple of patch_size.
            thumbnail_patches_per_side = self._thumbnail_size // self._patch_size
            token_count += thumbnail_patches_per_side * thumbnail_patches_per_side

    return DynamicResolutionParams(
        media=media,
        num_tiles=num_tiles,
        num_embeddings=num_embeddings,
        patch_size=(target_patch_width, target_patch_height),
    ), token_count
|
|
|
|
def augment_resolution(self, target_patch_width: int, target_patch_height: int, current_num_tokens_available: int) -> tuple[int, int]:
    """Randomly shrink or grow one side of the patch grid by a fixed step.

    With probability 0.5 the grid is shrunk, otherwise it is grown. The step
    is a fixed 32 patches and is applied to exactly one side (or to none).

    Args:
        target_patch_width: Current grid width, in patches.
        target_patch_height: Current grid height, in patches.
        current_num_tokens_available: Patch budget; growing is only attempted
            while ``width * height`` is strictly below this value.

    Returns:
        The ``(width, height)`` pair in patches, possibly unchanged.

    NOTE(review): growing is gated on the *current* area only, so the grown
    grid may exceed the budget, and shrinking a side of e.g. 33 patches
    leaves 1 patch. Confirm that both are intended.
    """
    step = 32
    width, height = target_patch_width, target_patch_height

    if random.random() < 0.5:
        # Shrink: a side may only be reduced when it is larger than the step.
        width_at_floor = width <= step
        height_at_floor = height <= step
        if width_at_floor and height_at_floor:
            return width, height
        if width_at_floor:
            return width, height - step
        if height_at_floor:
            return width - step, height
        # Both sides can shrink: pick one of them with equal probability.
        if random.random() < 0.5:
            return width - step, height
        return width, height - step

    # Grow: only when the current grid does not already use up the budget.
    if width * height >= current_num_tokens_available:
        return width, height
    if random.random() < 0.5:
        return width + step, height
    return width, height + step
|
|
|
|
def compute_params(
    self,
    media_list: list[ImageMedia | VideoFrameMedia],
    num_tokens_available: int | None = None,
    max_num_tiles: int | None = None,
    data_augment: bool = False,
    **kwargs,
) -> list[ImageTilingParams]:
    """Compute parameters for all media with iterative token budgeting.

    Every media item is first processed with the same per-media budget. If
    the summed token counts exceed the total budget, the per-media budgets
    are scaled down proportionally and processing is retried, at most 10
    times in total.

    Args:
        media_list: List of media items to process.
        num_tokens_available: Total number of tokens available across all
            media. It is multiplied by 4 for each of pixel-shuffle and
            conv-merging that is enabled, to express it in patches.
            ``None`` means "no global budget": each media item is then only
            limited by ``self._max_num_patches``.
        max_num_tiles: Maximum number of tiles (unused in this implementation).
        data_augment: Whether to apply data augmentation to the image;
            forwarded to ``process_media``. Defaults to False.
        **kwargs: Accepted for interface compatibility; ignored here.

    Returns:
        List of ImageTilingParams for each media item. If no fitting
        assignment is found within 10 attempts, the params of the last
        attempt are returned, which may exceed the budget.
    """
    if num_tokens_available is None:
        # The default used to crash with a TypeError on the multiplication
        # below. Treat a missing budget as unlimited instead.
        total_budget = math.inf
    else:
        total_budget = (
            num_tokens_available
            * (4 if self._pixel_shuffle else 1)
            * (4 if self._conv_merging else 1)
        )
    # When the number of available tokens is too small, allow
    # self._min_num_patches per media and let the sample be truncated.
    total_budget = max(total_budget, self._min_num_patches * len(media_list))

    # Clip the per-media budget to be between min and max patches.
    initial_budget = max(min(total_budget, self._max_num_patches), self._min_num_patches)
    budgets = [initial_budget] * len(media_list)

    params = []
    # In theory this could be a `while True` loop; the hard cap guards against
    # an infinite loop should process_media ever stop converging.
    for _ in range(10):
        # Step 1: process each media item with its current budget.
        params = []
        token_counts = []
        for media, budget in zip(media_list, budgets):
            param, token_count = self.process_media(media, budget, data_augment=data_augment)
            params.append(param)
            token_counts.append(token_count)

        # Step 2: stop as soon as the total fits.
        total_tokens = sum(token_counts)
        if total_tokens <= total_budget:
            return params

        # Step 3: over budget, so give each media item a proportional share.
        scaling_factor = total_budget / total_tokens
        scaled_budgets = [
            max(self._min_num_patches, int(token_count * scaling_factor))
            for token_count in token_counts
        ]
        made_progress = any(
            scaled < current for scaled, current in zip(scaled_budgets, budgets)
        )
        # If nothing could be reduced we are stuck; fall back to the minimum
        # budget per media item for the next attempt.
        if made_progress:
            budgets = scaled_budgets
        else:
            budgets = [self._min_num_patches] * len(media_list)
    return params
|
|
|
|
def stack(
|
|
self, images: list[torch.Tensor]
|
|
) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
|
|
imgs_sizes = torch.tensor(
|
|
[[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
|
|
)
|
|
|
|
def rearrange_img(x):
|
|
py = x.shape[-2] // self._patch_size
|
|
px = x.shape[-1] // self._patch_size
|
|
x = einops.rearrange(
|
|
x,
|
|
"c (py yy) (px xx) -> (py px) (c yy xx)",
|
|
py=py,
|
|
yy=self._patch_size,
|
|
px=px,
|
|
xx=self._patch_size,
|
|
)
|
|
return x
|
|
|
|
if len(images) > 0:
|
|
imgs = [rearrange_img(img) for img in images]
|
|
|
|
current_length = 0
|
|
max_length = 0
|
|
vision_cu_lengths = [0]
|
|
for img in imgs:
|
|
if max_length < img.shape[0]:
|
|
max_length = img.shape[0]
|
|
current_length += img.shape[0]
|
|
vision_cu_lengths.append(current_length)
|
|
|
|
vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
|
|
vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
|
|
|
|
return (
|
|
torch.cat(imgs, dim=0).unsqueeze(0),
|
|
imgs_sizes,
|
|
vision_cu_lengths,
|
|
vision_max_lengths,
|
|
)
|
|
else:
|
|
return (
|
|
torch.tensor([[0]], dtype=torch.float32),
|
|
torch.tensor([[0,0]], dtype=torch.int32),
|
|
None,
|
|
None,
|
|
)
|
|
|
|
def __str__(self):
    """Return a one-line summary of this strategy's configuration.

    NOTE(review): the label in the string is "DynamicResolutionImageTransform";
    confirm whether it is meant to match the enclosing class name.
    """
    shown_fields = (
        "vision_model_type",
        "min_num_patches",
        "patch_size",
        "pixel_shuffle",
        "conv_merging",
        "use_thumbnail",
        "thumbnail_size",
        "thumbnail_area_threshold",
    )
    # Each public-looking name maps to the private attribute with a leading "_".
    rendered = ", ".join(
        f"{field}={getattr(self, '_' + field)}" for field in shown_fields
    )
    return f"DynamicResolutionImageTransform({rendered})"
|
|
|
|
|
|
@dataclass
class MatchTilingDynamicResolutionParams(ImageTilingParams):
    """Tiling parameters produced by MatchTilingDynamicResolutionStrategy."""

    # Tile grid as (tiles along width, tiles along height); the strategy
    # multiplies each entry by the tile size to obtain the resize target.
    tiling: tuple[int, int]
|
|
|
|
|
|
class MatchTilingDynamicResolutionStrategy(ImageTilingStrategy):
    """
    Strategy that uses tiling logic to determine optimal image dimensions but processes
    the image as a single dynamic resolution image instead of splitting into tiles.

    This combines the aspect ratio optimization from ImageTilingStrategyV1 with the
    dynamic resolution processing from DynamicResolutionImageTilingStrategy.

    Also includes tile degradation logic similar to TileDegradationStrategy.
    """

    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        patch_size: int,
        get_num_embeddings: Callable[[int, int], int],
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
        pixel_shuffle: bool = False,
        conv_merging: bool = False,
        tile_degradation_map: dict[int, int] | None = None,
        video_frame_strategy: ImageTilingStrategy | None = None,
        enable_tile_degradation: bool = True,
    ):
        """
        Args:
            vision_model_type: Vision model type (should support dynamic resolution)
            tile_size: Size of each tile for tiling calculation
            use_thumbnail: Whether tiling logic should include thumbnail
            min_num_tiles: Minimum number of tiles for tiling calculation
            max_num_tiles: Maximum number of tiles for tiling calculation
            embeddings_per_tile: Embeddings per tile for tiling calculation
            patch_size: Patch size for dynamic resolution processing
            get_num_embeddings: Function to get number of embeddings from dimensions
            find_closest_aspect_ratio_fn: Function to find closest aspect ratio
            pixel_shuffle: Whether to ensure compatibility with pixel shuffle
            conv_merging: Whether to ensure compatibility with convolution merging
            tile_degradation_map: Map for degrading tiles when tokens are insufficient;
                defaults to {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
            video_frame_strategy: Strategy for processing video frames; defaults to a
                NoTilingStrategy producing one tile_size x tile_size image per frame
            enable_tile_degradation: Whether to enable tile degradation (default: True)
        """
        assert "radio" in vision_model_type, (
            "MatchTilingDynamicResolution is only supported for radio models"
        )

        self._vision_model_type = vision_model_type
        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._embeddings_per_tile = embeddings_per_tile
        self._patch_size = patch_size
        self._get_num_embeddings = get_num_embeddings
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
        self._pixel_shuffle = pixel_shuffle
        self._conv_merging = conv_merging
        self._enable_tile_degradation = enable_tile_degradation

        # Tile degradation logic (similar to TileDegradationStrategy)
        if tile_degradation_map is None:
            self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
        else:
            self._tile_degradation_map = tile_degradation_map

        # Video frame strategy (similar to TileDegradationStrategy)
        if video_frame_strategy is None:
            self._video_frame_strategy = NoTilingStrategy(
                vision_model_type=vision_model_type,
                target_width=tile_size,
                target_height=tile_size,
                embeddings_per_image=embeddings_per_tile,
            )
        else:
            self._video_frame_strategy = video_frame_strategy

        # Pre-compute the candidate grids for every tile budget in
        # [min_num_tiles, max_num_tiles]; other budgets are added lazily.
        self.target_ratios = {
            tile_budget: self._enumerate_target_ratios(tile_budget)
            for tile_budget in range(self._min_num_tiles, self._max_num_tiles + 1)
        }

        # Set up transform for dynamic resolution processing
        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ]
        )

    def _enumerate_target_ratios(self, max_tiles: int) -> list[tuple[int, int]]:
        """Return all (x, y) grids with min_num_tiles <= x * y <= max_tiles, sorted by x * y."""
        return sorted(
            set(
                (x, y)
                for n in range(self._min_num_tiles, max_tiles + 1)
                for x in range(1, n + 1)
                for y in range(1, n + 1)
                if x * y <= max_tiles and x * y >= self._min_num_tiles
            ),
            key=lambda x: x[0] * x[1],
        )

    def _get_target_ratios(self, max_tiles: int) -> list[tuple[int, int]]:
        """Return the candidate grids for ``max_tiles``, computing and caching them if missing."""
        if max_tiles not in self.target_ratios:
            self.target_ratios[max_tiles] = self._enumerate_target_ratios(max_tiles)
        return self.target_ratios[max_tiles]

    def apply_params(self, params: MatchTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
        """Resize the media according to ``params`` and return normalized tensors.

        Video frames are delegated to the video frame strategy. Images are
        resized to ``tile_size * tiling`` as ONE image; when thumbnails are
        enabled and the grid has more than one tile, a ``tile_size`` square
        thumbnail of the original image is appended.
        """
        # Handle video frames using the video frame strategy
        if isinstance(params.media, VideoFrameMedia):
            return self._video_frame_strategy.apply_params(params, **kwargs)

        image = params.media.value
        # tiling is (tiles along width, tiles along height)
        target_width = self._tile_size * params.tiling[0]
        target_height = self._tile_size * params.tiling[1]

        # Process as a single dynamic resolution image
        processed_images = [image.resize((target_width, target_height))]

        # Add thumbnail if use_thumbnail=True and there's more than 1 tile
        blocks = params.tiling[0] * params.tiling[1]
        if self._use_thumbnail and blocks != 1:
            processed_images.append(image.resize((self._tile_size, self._tile_size)))

        return [self._transform(img) for img in processed_images]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        **kwargs,
    ) -> list[MatchTilingDynamicResolutionParams]:
        """Choose a tile grid per media item, degrading the tile budget if needed.

        Args:
            media_list: Images and/or video frames to plan for.
            num_tokens_available: Embedding budget across all media. ``None``
                disables degradation.
            max_num_tiles: Per-call tile cap; clamped to the instance maximum.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            One params object per media item, in input order.

        Raises:
            ValueError: If a media item is neither ImageMedia nor VideoFrameMedia.
        """
        # Use provided max_num_tiles or fall back to the instance's value,
        # never exceeding the instance's maximum.
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        max_num_tiles_to_use = min(effective_max_num_tiles, self._max_num_tiles)
        degradation_map = self._tile_degradation_map
        # Tile budgets already tried; guards against a cyclic degradation map.
        tried_tile_budgets = {max_num_tiles_to_use}

        while True:
            params = []
            total_embeddings_needed = 0
            # Looked up lazily so that a budget outside the pre-computed range
            # (e.g. a per-call cap below min_num_tiles) does not raise KeyError.
            target_ratios = self._get_target_ratios(max_num_tiles_to_use)

            for media in media_list:
                if isinstance(media, ImageMedia):
                    # Use tiling logic for images
                    img_size = (media.width, media.height)
                    aspect_ratio = img_size[0] / img_size[1]

                    # Find the closest aspect ratio to the target
                    tiling = self._find_closest_aspect_ratio_fn(
                        aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
                    )

                    # Calculate target dimensions for dynamic resolution processing
                    target_width = self._tile_size * tiling[0]
                    target_height = self._tile_size * tiling[1]
                    num_embeddings = self._get_num_embeddings(target_width, target_height)

                    # Account for thumbnail
                    num_tiles = 1  # Base dynamic resolution image
                    blocks = tiling[0] * tiling[1]
                    if self._use_thumbnail and blocks != 1:
                        num_tiles += 1  # Add 1 for thumbnail
                        # Add embeddings for thumbnail (tile_size x tile_size)
                        num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size)

                    media_params = MatchTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=num_tiles,
                        num_embeddings=num_embeddings,
                        tiling=tiling,
                    )
                elif isinstance(media, VideoFrameMedia):
                    # Use video frame strategy for video frames (always 1 tile)
                    video_params = self._video_frame_strategy.compute_params(
                        [media], 1 * self._embeddings_per_tile
                    )[0]
                    media_params = MatchTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=video_params.num_tiles,
                        num_embeddings=video_params.num_embeddings,
                        tiling=(1, 1),  # Video frames always use 1x1 tiling
                    )
                else:
                    raise ValueError(f"Unsupported media type: {type(media)}")

                params.append(media_params)
                total_embeddings_needed += media_params.num_embeddings

            # Check if we need to degrade (only if degradation is enabled)
            if not self._enable_tile_degradation:
                break
            if max_num_tiles_to_use == 1 or num_tokens_available is None:
                break
            if total_embeddings_needed <= num_tokens_available:
                break
            if max_num_tiles_to_use not in degradation_map:
                # End of degradation
                break
            next_max_num_tiles = degradation_map[max_num_tiles_to_use]
            if next_max_num_tiles in tried_tile_budgets:
                # The map loops back to a budget we already tried; stop here
                # instead of iterating forever.
                break
            tried_tile_budgets.add(next_max_num_tiles)
            max_num_tiles_to_use = next_max_num_tiles

        return params

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Stack images using dynamic resolution approach with sequence packing.

        Returns ``(packed, sizes, cu_lengths, max_length)``: the packed patch
        sequence of shape ``(1, total_patches, c * patch_size**2)``, the
        per-image ``[shape[1], shape[2]]`` sizes, the cumulative patch counts
        starting at 0, and the largest per-image patch count. For an empty
        input, placeholder tensors are returned with ``None`` for the two
        length entries.
        """
        imgs_sizes = torch.tensor(
            [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
        )

        def rearrange_img(x):
            # Cut a (c, height, width) image into flattened patches, one row per patch.
            py = x.shape[-2] // self._patch_size
            px = x.shape[-1] // self._patch_size
            x = einops.rearrange(
                x,
                "c (py yy) (px xx) -> (py px) (c yy xx)",
                py=py,
                yy=self._patch_size,
                px=px,
                xx=self._patch_size,
            )
            return x

        if len(images) > 0:
            imgs = [rearrange_img(img) for img in images]

            current_length = 0
            max_length = 0
            vision_cu_lengths = [0]
            for img in imgs:
                if max_length < img.shape[0]:
                    max_length = img.shape[0]
                current_length += img.shape[0]
                vision_cu_lengths.append(current_length)

            vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
            vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)

            return (
                torch.cat(imgs, dim=0).unsqueeze(0),
                imgs_sizes,
                vision_cu_lengths,
                vision_max_lengths,
            )
        else:
            return (
                torch.tensor([[0]], dtype=torch.float32),
                torch.tensor([[0,0]], dtype=torch.int32),
                None,
                None,
            )

    def __str__(self):
        """Return a one-line summary of this strategy's configuration."""
        return f"MatchTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})"
|
|
|
|
|
|
@dataclass
class MaskedTilingDynamicResolutionParams(ImageTilingParams):
    """Tiling parameters produced by MaskedTilingDynamicResolutionStrategy."""

    # Tile grid as (tiles along width, tiles along height); the strategy
    # crops one tile_size x tile_size image per grid cell.
    tiling: tuple[int, int]
|
|
|
|
|
|
class MaskedTilingDynamicResolutionStrategy(ImageTilingStrategy):
    """
    Like MatchTilingDynamicResolutionStrategy, but ensures tiles are isolated in the
    vision encoder by emitting per-tile packed samples (block-diagonal attention across tiles).
    """

    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        patch_size: int,
        get_num_embeddings: Callable[[int, int], int],
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
        pixel_shuffle: bool = False,
        conv_merging: bool = False,
        tile_degradation_map: dict[int, int] | None = None,
        video_frame_strategy: ImageTilingStrategy | None = None,
        enable_tile_degradation: bool = True,
    ):
        """
        Args:
            vision_model_type: Vision model type; must contain "radio"
            tile_size: Side length of each square tile
            use_thumbnail: Whether a thumbnail is added for multi-tile images
            min_num_tiles: Minimum number of tiles for tiling calculation
            max_num_tiles: Maximum number of tiles for tiling calculation
            embeddings_per_tile: Embedding budget passed to the video frame strategy
            patch_size: Patch size used when packing images in ``stack``
            get_num_embeddings: Function to get number of embeddings from dimensions
            find_closest_aspect_ratio_fn: Function to find closest aspect ratio
            pixel_shuffle: Stored and reported in ``__str__``
            conv_merging: Stored and reported in ``__str__``
            tile_degradation_map: Map for degrading tiles when tokens are insufficient;
                defaults to {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
            video_frame_strategy: Strategy for processing video frames; defaults to a
                NoTilingStrategy producing one tile_size x tile_size image per frame
            enable_tile_degradation: Whether to enable tile degradation (default: True)
        """
        assert "radio" in vision_model_type, (
            "MaskedTilingDynamicResolution is only supported for radio models"
        )

        self._vision_model_type = vision_model_type
        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._embeddings_per_tile = embeddings_per_tile
        self._patch_size = patch_size
        self._get_num_embeddings = get_num_embeddings
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
        self._pixel_shuffle = pixel_shuffle
        self._conv_merging = conv_merging
        self._enable_tile_degradation = enable_tile_degradation

        if tile_degradation_map is None:
            self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
        else:
            self._tile_degradation_map = tile_degradation_map

        if video_frame_strategy is None:
            self._video_frame_strategy = NoTilingStrategy(
                vision_model_type=vision_model_type,
                target_width=tile_size,
                target_height=tile_size,
                embeddings_per_image=embeddings_per_tile,
            )
        else:
            self._video_frame_strategy = video_frame_strategy

        # Pre-compute the candidate grids for every tile budget in
        # [min_num_tiles, max_num_tiles]; other budgets are added lazily.
        self.target_ratios = {
            tile_budget: self._enumerate_target_ratios(tile_budget)
            for tile_budget in range(self._min_num_tiles, self._max_num_tiles + 1)
        }

        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ]
        )

    def _enumerate_target_ratios(self, max_tiles: int) -> list[tuple[int, int]]:
        """Return all (x, y) grids with min_num_tiles <= x * y <= max_tiles, sorted by x * y."""
        return sorted(
            set(
                (x, y)
                for n in range(self._min_num_tiles, max_tiles + 1)
                for x in range(1, n + 1)
                for y in range(1, n + 1)
                if x * y <= max_tiles and x * y >= self._min_num_tiles
            ),
            key=lambda x: x[0] * x[1],
        )

    def _get_target_ratios(self, max_tiles: int) -> list[tuple[int, int]]:
        """Return the candidate grids for ``max_tiles``, computing and caching them if missing."""
        if max_tiles not in self.target_ratios:
            self.target_ratios[max_tiles] = self._enumerate_target_ratios(max_tiles)
        return self.target_ratios[max_tiles]

    def apply_params(self, params: MaskedTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
        """Resize the media, cut it into tiles and return normalized tensors.

        Video frames are delegated to the video frame strategy. Images are
        resized to ``tile_size * tiling`` and cropped into ``tile_size``
        squares in row-major order; when thumbnails are enabled and the grid
        has more than one tile, a ``tile_size`` square thumbnail of the
        original image is appended last.
        """
        # Handle video frames using the video frame strategy
        if isinstance(params.media, VideoFrameMedia):
            return self._video_frame_strategy.apply_params(params, **kwargs)

        image = params.media.value
        nx, ny = params.tiling
        tile = self._tile_size
        resized_img = image.resize((tile * nx, tile * ny))

        # Emit per-tile images (each becomes an isolated packed sample later).
        # Crop boxes are (left, upper, right, lower).
        processed_images = [
            resized_img.crop((i * tile, j * tile, (i + 1) * tile, (j + 1) * tile))
            for j in range(ny)
            for i in range(nx)
        ]

        if self._use_thumbnail and (nx * ny) != 1:
            processed_images.append(image.resize((tile, tile)))

        return [self._transform(img) for img in processed_images]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        data_augment: bool = False,
        tiling_augment_prob: float = 0.4,
        **kwargs,
    ) -> list[MaskedTilingDynamicResolutionParams]:
        """Choose a tile grid per media item, degrading the tile budget if needed.

        Args:
            media_list: Images and/or video frames to plan for.
            num_tokens_available: Embedding budget across all media. ``None``
                disables degradation.
            max_num_tiles: Per-call tile cap; clamped to the instance maximum.
            data_augment: Whether image grids may be randomly perturbed.
            tiling_augment_prob: Probability of perturbing an image's grid
                when ``data_augment`` is set.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            One params object per media item, in input order.

        Raises:
            ValueError: If a media item is neither ImageMedia nor VideoFrameMedia.

        NOTE(review): augmentation is bounded by the instance maximum, not by
        the degraded tile budget, so an augmented grid may exceed the budget
        currently being tried. Confirm this is intended.
        """
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        max_num_tiles_to_use = min(effective_max_num_tiles, self._max_num_tiles)
        degradation_map = self._tile_degradation_map
        # Tile budgets already tried; guards against a cyclic degradation map.
        tried_tile_budgets = {max_num_tiles_to_use}

        while True:
            params = []
            total_embeddings_needed = 0
            # Looked up lazily so that a budget outside the pre-computed range
            # (e.g. a per-call cap below min_num_tiles) does not raise KeyError.
            target_ratios = self._get_target_ratios(max_num_tiles_to_use)

            for media in media_list:
                if isinstance(media, ImageMedia):
                    img_size = (media.width, media.height)
                    aspect_ratio = img_size[0] / img_size[1]

                    tiling = self._find_closest_aspect_ratio_fn(
                        aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
                    )

                    # Apply tiling augmentation if enabled
                    if data_augment and random.random() < tiling_augment_prob:
                        tiling = self.augment_tiling(tiling)

                    blocks = tiling[0] * tiling[1]
                    # Each tile is tile_size x tile_size
                    per_tile_emb = self._get_num_embeddings(self._tile_size, self._tile_size)
                    num_embeddings = blocks * per_tile_emb

                    num_tiles = blocks
                    if self._use_thumbnail and blocks != 1:
                        # The thumbnail is one more tile_size x tile_size image.
                        num_tiles += 1
                        num_embeddings += per_tile_emb

                    media_params = MaskedTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=num_tiles,
                        num_embeddings=num_embeddings,
                        tiling=tiling,
                    )
                elif isinstance(media, VideoFrameMedia):
                    video_params = self._video_frame_strategy.compute_params(
                        [media], 1 * self._embeddings_per_tile
                    )[0]
                    media_params = MaskedTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=video_params.num_tiles,
                        num_embeddings=video_params.num_embeddings,
                        tiling=(1, 1),
                    )
                else:
                    raise ValueError(f"Unsupported media type: {type(media)}")

                params.append(media_params)
                total_embeddings_needed += media_params.num_embeddings

            if not self._enable_tile_degradation:
                break
            if max_num_tiles_to_use == 1 or num_tokens_available is None:
                break
            if total_embeddings_needed <= num_tokens_available:
                break
            if max_num_tiles_to_use not in degradation_map:
                break
            next_max_num_tiles = degradation_map[max_num_tiles_to_use]
            if next_max_num_tiles in tried_tile_budgets:
                # The map loops back to a budget we already tried; stop here
                # instead of iterating forever.
                break
            tried_tile_budgets.add(next_max_num_tiles)
            max_num_tiles_to_use = next_max_num_tiles

        return params

    def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]:
        """Randomly remove or add one row/column of tiles.

        With probability 0.65 one side is decremented (never below 1);
        otherwise one side is incremented, provided the resulting tile count
        does not exceed ``self._max_num_tiles``. The grid is returned
        unchanged when no such move is possible.
        """
        minus_prob = 0.65
        width, height = tiling[0], tiling[1]

        if random.random() < minus_prob:
            # Minus one: only a side larger than 1 can shrink.
            if width == 1 and height == 1:
                return tiling
            if width == 1:
                return (width, height - 1)
            if height == 1:
                return (width - 1, height)
            if random.random() < 0.5:
                return (width - 1, height)
            return (width, height - 1)

        # Plus one: only when there is room below the instance maximum.
        if width * height >= self._max_num_tiles:
            return tiling
        wider = (width + 1, height)
        taller = (width, height + 1)
        wider_fits = wider[0] * wider[1] <= self._max_num_tiles
        taller_fits = taller[0] * taller[1] <= self._max_num_tiles
        if not wider_fits and not taller_fits:
            return tiling
        if not wider_fits:
            return taller
        if not taller_fits:
            return wider
        if random.random() < 0.5:
            return wider
        return taller

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Pack images into one patch sequence; each tile is an independent sample.

        Returns ``(packed, sizes, cu_lengths, max_length)``: the packed patch
        sequence of shape ``(1, total_patches, c * patch_size**2)``, the
        per-image ``[shape[1], shape[2]]`` sizes, the cumulative patch counts
        starting at 0, and the largest per-image patch count. For an empty
        input, placeholder tensors are returned with ``None`` for the two
        length entries.
        """
        imgs_sizes = torch.tensor(
            [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
        )

        def rearrange_img(x):
            # Cut a (c, height, width) image into flattened patches, one row per patch.
            py = x.shape[-2] // self._patch_size
            px = x.shape[-1] // self._patch_size
            x = einops.rearrange(
                x,
                "c (py yy) (px xx) -> (py px) (c yy xx)",
                py=py,
                yy=self._patch_size,
                px=px,
                xx=self._patch_size,
            )
            return x

        if len(images) > 0:
            imgs = [rearrange_img(img) for img in images]

            current_length = 0
            max_length = 0
            vision_cu_lengths = [0]
            for img in imgs:
                if max_length < img.shape[0]:
                    max_length = img.shape[0]
                current_length += img.shape[0]
                vision_cu_lengths.append(current_length)

            vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
            vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)

            return (
                torch.cat(imgs, dim=0).unsqueeze(0),
                imgs_sizes,
                vision_cu_lengths,
                vision_max_lengths,
            )
        else:
            return (
                torch.tensor([[0]], dtype=torch.float32),
                torch.tensor([[0,0]], dtype=torch.int32),
                None,
                None,
            )

    def __str__(self):
        """Return a one-line summary of this strategy's configuration."""
        return f"MaskedTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})"
|
|
|
|
def create_image_tiling_strategy(args):
    """
    Create an image tiling strategy based on the provided arguments.

    This function encapsulates the logic for creating the appropriate image tiling strategy
    based on the training/evaluation configuration. It can be used by both training (task_encoder)
    and evaluation code outside of data_loading/.

    Strategy selection:
        - dynamic_resolution and masked_tiling_dynamic_resolution:
          MaskedTilingDynamicResolutionStrategy
        - dynamic_resolution and match_tiling_dynamic_resolution:
          MatchTilingDynamicResolutionStrategy
        - dynamic_resolution only: DynamicResolutionImageTilingStrategy
          (use_tiling is not consulted in this case)
        - otherwise: TileDegradationStrategy wrapping ImageTilingStrategyV1 (use_tiling)
          or NoTilingStrategy (no tiling) for images, and NoTilingStrategy for video frames

    Args:
        args: Arguments object with the following relevant attributes:
            - img_h, img_w: Image height and width (must be equal)
            - patch_dim: Patch dimension
            - vision_model_type: Vision model type (e.g., 'radio', 'clip', 'siglip')
            - disable_vision_class_token: Whether to disable vision class token
            - pixel_shuffle: Whether to use pixel shuffle
            - use_tile_tags: Whether to use tile tags
            - max_num_tiles: Maximum number of tiles
            - tokenizer_prompt_format: Tokenizer prompt format
            - image_break_token: Image break token (may be None)
            - conv_merging: Whether to use convolution merging
            - dynamic_resolution: Whether to use dynamic resolution
            - match_tiling_dynamic_resolution: Whether to match tiling with dynamic resolution
            - masked_tiling_dynamic_resolution: Whether to use masked tiling with dynamic
              resolution (attribute may be absent; treated as False in that case)
            - use_area_weighted_aspect_ratio: Whether to use area-weighted aspect ratio
            - use_thumbnail: Whether to use thumbnail
            - dynamic_resolution_min_patches: Minimum number of patches for dynamic resolution
            - dynamic_resolution_max_patches: Forwarded as max_num_patches to
              DynamicResolutionImageTilingStrategy (plain dynamic resolution only)
            - dynamic_resolution_min_side: Minimum side length for dynamic resolution (optional)
            - thumbnail_area_threshold: Thumbnail area threshold (optional)
            - apply_data_augment: Forwarded as apply_data_augment to
              DynamicResolutionImageTilingStrategy (plain dynamic resolution only)
            - use_tiling: Whether to use tiling

    Returns:
        ImageTilingStrategy: The created image tiling strategy

    Raises:
        AssertionError: If img_h != img_w, or if the tiling flags are combined in an
            unsupported way (see the assertion messages below).
    """
    # Imported lazily so that merely importing this module does not require megatron.
    from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings

    assert args.img_h == args.img_w, "img_h and img_w must be the same"

    match_tiling_dynamic_resolution = args.match_tiling_dynamic_resolution
    # Not every args object defines this flag, hence getattr with a False default.
    masked_tiling_dynamic_resolution = getattr(args, "masked_tiling_dynamic_resolution", False)
    dynamic_resolution = args.dynamic_resolution
    use_tiling = args.use_tiling
    use_area_weighted_aspect_ratio = args.use_area_weighted_aspect_ratio

    if match_tiling_dynamic_resolution:
        assert dynamic_resolution, "must enable --dynamic-resolution if using --match-tiling-dynamic-resolution"
        assert not use_tiling, "cannot use --use-tiling and --match-tiling-dynamic-resolution together"
    if masked_tiling_dynamic_resolution:
        assert dynamic_resolution, "must enable --dynamic-resolution if using --masked-tiling-dynamic-resolution"
        assert not use_tiling, "cannot use --use-tiling and --masked-tiling-dynamic-resolution together"
        assert not match_tiling_dynamic_resolution, "cannot combine --masked-tiling-dynamic-resolution with --match-tiling-dynamic-resolution"

    def _num_embeddings(width, height):
        # Single source of truth for the embedding count of a (width, height) image.
        # Parameter names/order are kept as (width, height) because this callable is
        # handed to the strategies as their `get_num_embeddings` hook.
        return get_num_image_embeddings(
            img_h=height,
            img_w=width,
            patch_dim=args.patch_dim,
            vision_model_type=args.vision_model_type,
            disable_vision_class_token=args.disable_vision_class_token,
            class_token_len=1,
            pixel_shuffle=args.pixel_shuffle,
            use_tile_tags=args.use_tile_tags,
            max_num_tiles=args.max_num_tiles,
            tokenizer_type=args.tokenizer_prompt_format,
            use_image_break_token=args.image_break_token is not None,
            conv_merging=args.conv_merging,
        )

    find_closest_aspect_ratio_fn = (
        find_closest_area_weighted_aspect_ratio
        if use_area_weighted_aspect_ratio
        else find_closest_aspect_ratio
    )

    if dynamic_resolution:
        if masked_tiling_dynamic_resolution or match_tiling_dynamic_resolution:
            # Both tile-matching variants take exactly the same constructor arguments;
            # masked takes precedence, mirroring the original if/elif order.
            strategy_cls = (
                MaskedTilingDynamicResolutionStrategy
                if masked_tiling_dynamic_resolution
                else MatchTilingDynamicResolutionStrategy
            )
            num_image_embeddings_per_tile = _num_embeddings(args.img_w, args.img_h)
            image_tiling_strategy = strategy_cls(
                vision_model_type=args.vision_model_type,
                tile_size=args.img_h,
                use_thumbnail=args.use_thumbnail,
                min_num_tiles=1,
                max_num_tiles=args.max_num_tiles,
                embeddings_per_tile=num_image_embeddings_per_tile,
                patch_size=args.patch_dim,
                get_num_embeddings=_num_embeddings,
                find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn,
                pixel_shuffle=args.pixel_shuffle,
                conv_merging=args.conv_merging,
            )
        else:
            # Plain dynamic resolution: no fixed tile grid, so no per-tile count is needed.
            image_tiling_strategy = DynamicResolutionImageTilingStrategy(
                vision_model_type=args.vision_model_type,
                min_num_patches=args.dynamic_resolution_min_patches,
                patch_size=args.patch_dim,
                get_num_embeddings=_num_embeddings,
                pixel_shuffle=args.pixel_shuffle,
                min_side=args.dynamic_resolution_min_side,
                conv_merging=args.conv_merging,
                use_thumbnail=args.use_thumbnail,
                thumbnail_size=args.img_h,
                thumbnail_area_threshold=args.thumbnail_area_threshold,
                max_num_patches=args.dynamic_resolution_max_patches,
                apply_data_augment=args.apply_data_augment,
            )
    else:
        num_image_embeddings_per_tile = _num_embeddings(args.img_w, args.img_h)

        def _no_tiling_strategy():
            # Builds a fresh instance on every call: the image and video-frame
            # strategies must remain distinct objects, as in the original code.
            return NoTilingStrategy(
                vision_model_type=args.vision_model_type,
                embeddings_per_image=num_image_embeddings_per_tile,
                target_width=args.img_w,
                target_height=args.img_h,
            )

        if use_tiling:
            image_strategy = ImageTilingStrategyV1(
                vision_model_type=args.vision_model_type,
                tile_size=args.img_h,
                use_thumbnail=args.use_thumbnail,
                min_num_tiles=1,
                max_num_tiles=args.max_num_tiles,
                embeddings_per_tile=num_image_embeddings_per_tile,
                find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn,
            )
        else:
            image_strategy = _no_tiling_strategy()

        # Video frames are never tiled, regardless of use_tiling.
        image_tiling_strategy = TileDegradationStrategy(
            image_strategy=image_strategy,
            video_frame_strategy=_no_tiling_strategy(),
            embeddings_per_tile=num_image_embeddings_per_tile,
            max_num_tiles=args.max_num_tiles,
        )

    return image_tiling_strategy
|