# vllm/image_processing.py
# Source listing metadata: Netanel Haber, commit 50ffea9826 ("update"),
# 2025-12-15 07:48:56 -08:00 — 1829 lines, 80 KiB, Python.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
from abc import ABC, abstractmethod
from dataclasses import dataclass
import math
from typing import Callable, Optional
import numpy as np
import random
from PIL import Image
import albumentations as A
import einops
import torch
from torchvision import transforms as T
from torchvision.transforms import Compose
from torchvision.transforms.functional import InterpolationMode
from data_loading.conversation_sample import (
ImageMedia,
VideoFrameMedia,
)
# Per-channel RGB normalization statistics for the supported vision backbones.
IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]
# Maps a vision model type identifier to its (mean, std) pixel statistics,
# used when building the normalization transforms below.
pixel_statistics = {
    "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
    "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
}
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
# Copyright (c) 2023 OpenGVLab.
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the grid from ``target_ratios`` whose aspect ratio best matches the image.

    Ties on aspect-ratio distance are broken in favor of the later candidate
    when the image area exceeds half the area that candidate grid would cover
    (i.e. large images prefer more tiles).
    """
    image_area = width * height
    smallest_diff = float("inf")
    chosen = (1, 1)
    for cols, rows in target_ratios:
        diff = abs(aspect_ratio - cols / rows)
        if diff < smallest_diff:
            smallest_diff = diff
            chosen = (cols, rows)
        elif diff == smallest_diff and image_area > 0.5 * image_size * image_size * cols * rows:
            chosen = (cols, rows)
    return chosen
def find_closest_area_weighted_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    width: int,
    height: int,
    image_size: int,
):
    """
    Find the best number of tiles based on the aspect ratio and the area covered by the tiles.

    Each candidate grid is scored by (covered-area factor, capped at 0.6) times
    (symmetric aspect-ratio similarity in (0, 1]); the highest score wins.
    """
    image_area = width * height
    best_score = float("-inf")
    best = (1, 1)
    for candidate in target_ratios:
        cols, rows = candidate
        grid_ratio = cols / rows
        # Area term: how much of the tiled canvas the image fills, capped so
        # very large grids cannot dominate the score.
        area_term = min((cols * rows * image_size * image_size) / image_area, 0.6)
        # Ratio term: 1.0 when the grid matches the image aspect ratio exactly.
        ratio_term = min(grid_ratio / aspect_ratio, aspect_ratio / grid_ratio)
        score = area_term * ratio_term
        if score > best_score:
            best_score = score
            best = candidate
    return best
# Mike's optimized ToTensor.
def _fast_to_tensor(pic) -> torch.Tensor:
np_img = np.array(pic, copy=False)
img = torch.from_numpy(np_img)
img = img.permute(2, 0, 1) # HWC to CHW
fp_img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
fp_img.div_(255)
return fp_img
@dataclass
class ImageTilingParams:
    """Result of a tiling computation for a single media item."""

    # The media item these parameters were computed for.
    media: ImageMedia | VideoFrameMedia
    # Number of tiles the media will be split into (including the thumbnail,
    # when a strategy adds one).
    num_tiles: int
    # Total number of vision embeddings the tiled media will produce.
    num_embeddings: int
class ImageTilingStrategy(ABC):
    """
    Base class for image tiling strategies.

    A tiling strategy is a function that takes a list of media and returns a list of image tiling parameters.
    These can then be used to apply the tiling to the media.

    Subclasses must implement the `compute_params` and `apply_params` methods.
    The `transform` method is a convenience method that computes the transformation parameters and applies the transformation to the media.
    """

    def transform(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        **kwargs,
    ) -> list[torch.Tensor]:
        """
        Transform the media and compute the transformation parameters.

        Extra keyword arguments (e.g. ``data_augment``) are forwarded to both
        ``compute_params`` and ``apply_params``.

        NOTE: the previous implementation referenced an undefined name
        ``kwargs`` inside the list comprehension, so any call raised
        NameError; the method now accepts and forwards ``**kwargs``.
        """
        transform_media_list = self.compute_params(
            media_list, num_tokens_available, **kwargs
        )
        return [
            self.apply_params(transform_media, **kwargs)
            for transform_media in transform_media_list
        ]

    @abstractmethod
    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int,
        max_num_tiles: int | None = None,
        **kwargs,
    ) -> list[ImageTilingParams]:
        """
        Compute the transformation parameters and the number of tokens to use for the media.

        Args:
            media_list: List of media to transform
            num_tokens_available: Number of tokens available for all media
            max_num_tiles: Maximum number of tiles allowed (optional, defaults to instance's max_num_tiles if not provided)

        Returns:
            list of transformation parameters with the media
        """
        ...

    @abstractmethod
    def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
        """
        Apply the transformation parameters to the media.

        Args:
            transform_media: The media to apply the transformation to

        Returns:
            list of transformed media tensors
        """
        ...

    @abstractmethod
    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
        """
        Stack the images into a single tensor.

        Args:
            images: List of images to stack

        Returns:
            tuple of (stacked media, image sizes, vision cu lengths, vision max lengths)
        """
        ...
class _FixedSizeStrategy(ImageTilingStrategy):
    """
    Base class for fixed size image tiling strategies.

    Builds and holds a transform that resizes every image/tile to a fixed
    (target_width, target_height) and normalizes it with the pixel statistics
    of the configured vision model type.
    """

    def __init__(
        self,
        vision_model_type: str,
        target_width: int,
        target_height: int,
        embeddings_per_image: int,
    ):
        # Backbone identifier; selects normalization stats and transform flavor.
        self._vision_model_type = vision_model_type
        self._target_width = target_width
        self._target_height = target_height
        # Fixed number of vision embeddings each transformed image produces.
        self._embeddings_per_image = embeddings_per_image
        self._transform = self._build_transform(
            (target_width, target_height), vision_model_type
        )

    # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
    # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
    @staticmethod
    def _build_transform(target_size: tuple[int, int], vision_model_type: str):
        """
        Build a transform for a given vision model type and target size.

        Args:
            target_size: (width, height) of the transform output, in pixels.
            vision_model_type: One of the supported backbone identifiers, or
                an "hf://..." HuggingFace model reference.

        Raises:
            NotImplementedError: for unsupported model types.
        """
        if vision_model_type in ("siglip", "internvit", "radio", "radio-g", "cradio-g"):
            pixel_mean, pixel_std = pixel_statistics[vision_model_type]
            transform = T.Compose(
                [
                    # Convert non-RGB inputs (grayscale, RGBA, ...) first.
                    T.Lambda(
                        lambda img: img.convert("RGB") if img.mode != "RGB" else img
                    ),
                    # target_size is (W, H) but T.Resize takes (H, W).
                    T.Resize(
                        (target_size[1], target_size[0]),
                        interpolation=InterpolationMode.BICUBIC,
                    ),
                    T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)),
                    T.Normalize(mean=pixel_mean, std=pixel_std),
                ]
            )
        # From the official CLIP repo.
        elif vision_model_type == "clip":
            pixel_mean, pixel_std = pixel_statistics[vision_model_type]
            transform = Compose(
                [
                    T.Resize(
                        (target_size[1], target_size[0]),
                        interpolation=InterpolationMode.BICUBIC,
                    ),
                    T.Lambda(
                        lambda img: img.convert("RGB") if img.mode != "RGB" else img
                    ),
                    T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)),
                    T.Normalize(mean=pixel_mean, std=pixel_std),
                ]
            )
        elif vision_model_type.startswith("hf://"):
            # Imported lazily so megatron/transformers are only required when
            # a HuggingFace backbone is actually used.
            from megatron.core.models.huggingface.module import get_hf_model_type
            model_type = get_hf_model_type(vision_model_type)
            if "siglip" in model_type:
                from transformers.models.siglip.image_processing_siglip import (
                    SiglipImageProcessor,
                )
                processor = SiglipImageProcessor(
                    size={"height": target_size[1], "width": target_size[0]}
                )
                def transform(x):
                    x = x.convert("RGB") if x.mode != "RGB" else x
                    x = processor(x, return_tensors="pt")
                    return x["pixel_values"][0]
            else:
                raise NotImplementedError(
                    f"image processing not defined for huggingface model {vision_model_type}"
                )
        else:
            raise NotImplementedError(
                f"image processing not defined for vision model {vision_model_type}"
            )
        return transform

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, None, None]:
        """
        Stack the images into a single (N, C, H, W) tensor.

        Returns:
            tuple of (stacked images or None when the list is empty,
            per-image (H, W) sizes as an int32 tensor or None,
            None, None) — the vision cu/max length slots are unused by
            fixed-size strategies.
        """
        return (
            torch.stack(images) if len(images) > 0 else None,
            torch.tensor(
                [(img.shape[1], img.shape[2]) for img in images], dtype=torch.int32
            ) if len(images) > 0 else None,
            None,
            None,
        )
class NoTilingStrategy(_FixedSizeStrategy):
    """
    A simple image transformation that resizes the image to the target width and height.

    Every media item produces exactly one tile and a fixed number of embeddings.
    """

    def __init__(
        self,
        vision_model_type: str,
        target_width: int,
        target_height: int,
        embeddings_per_image: int,
    ):
        super().__init__(
            vision_model_type=vision_model_type,
            target_width=target_width,
            target_height=target_height,
            embeddings_per_image=embeddings_per_image,
        )

    def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
        # No tiling: the media maps to a single transformed tensor.
        return [self._transform(transform_media.media.value)]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: Optional[int] = None,
        max_num_tiles: int | None = None,
        **kwargs,
    ) -> list[ImageTilingParams]:
        # Each item costs the same fixed embedding budget regardless of size,
        # so the token budget and tile cap are ignored here.
        params = []
        for media in media_list:
            params.append(
                ImageTilingParams(
                    media=media,
                    num_tiles=1,
                    num_embeddings=self._embeddings_per_image,
                )
            )
        return params

    def __str__(self):
        return f"SimpleImageTransform(vision_model_type={self._vision_model_type}, num_tokens_per_image={self._embeddings_per_image})"
@dataclass
class ImageTilingParamsV1(ImageTilingParams):
    """Tiling parameters extended with the chosen tile grid."""

    # (tiles_wide, tiles_high) grid the image is split into; the optional
    # thumbnail tile is not counted here (it is reflected in num_tiles).
    tiling: tuple[int, int]
class ImageTilingStrategyV1(_FixedSizeStrategy):
    """Tiling image transformation.

    This transformation splits the image into a grid of fixed-size square
    tiles (chosen to best match the image aspect ratio), applies the per-tile
    transform to each tile, and optionally appends a thumbnail of the whole
    image as an extra tile.
    """

    # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
    # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
    ):
        """
        Args:
            vision_model_type: Vision backbone type (selects pixel statistics).
            tile_size: Side length in pixels of each square tile.
            use_thumbnail: Append a full-image thumbnail tile when the image
                is split into more than one tile.
            min_num_tiles: Minimum number of tiles per image.
            max_num_tiles: Maximum number of tiles per image.
            embeddings_per_tile: Number of vision embeddings produced per tile.
            find_closest_aspect_ratio_fn: Policy for picking the tile grid.
        """
        super().__init__(
            vision_model_type=vision_model_type,
            target_width=tile_size,
            target_height=tile_size,
            embeddings_per_image=embeddings_per_tile,
        )
        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
        # Pre-compute, for every allowed tile budget, all (cols, rows) grids
        # whose total tile count lies within [min_num_tiles, budget], ordered
        # by total tile count.
        self.target_ratios = {
            max_num_tiles: sorted(
                set(
                    (x, y)
                    for n in range(self._min_num_tiles, max_num_tiles + 1)
                    for x in range(1, n + 1)
                    for y in range(1, n + 1)
                    if x * y <= max_num_tiles and x * y >= self._min_num_tiles
                ),
                key=lambda x: x[0] * x[1],
            )
            for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1)
        }
        # Photometric augmentation pipeline, applied only when data_augment=True.
        # NOTE(review): this instance attribute shadows the inherited
        # ImageTilingStrategy.transform() method on instances of this class —
        # confirm no caller relies on the base-class convenience method here.
        self.transform = A.Compose([
            A.OneOf([
                A.GaussNoise(var_limit=(5.0, 30.0)),
                A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5)),
            ], p=0.3),
            A.OneOf([
                A.MedianBlur(blur_limit=5),
                A.GaussianBlur(blur_limit=5),
            ], p=0.2),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.5),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=15, val_shift_limit=15, p=0.3),
            A.ImageCompression(quality_lower=70, quality_upper=100, p=0.3),
        ])

    def apply_params(self, transform_media: ImageTilingParams, data_augment: bool = False, **kwargs) -> list[torch.Tensor]:
        """Split the media into tiles per the computed tiling and transform each tile.

        Args:
            transform_media: Tiling parameters; must be an ImageTilingParamsV1.
            data_augment: Apply the photometric augmentation pipeline first.

        Returns:
            List of transformed tile tensors (thumbnail last, when enabled).
        """
        assert isinstance(transform_media, ImageTilingParamsV1)
        image = transform_media.media.value
        if data_augment:
            image = self.transform(image=np.asarray(image))["image"]
            image = Image.fromarray(image)
        # calculate the target width and height
        target_width = self._tile_size * transform_media.tiling[0]
        target_height = self._tile_size * transform_media.tiling[1]
        blocks = transform_media.tiling[0] * transform_media.tiling[1]
        # resize the image
        resized_img = image.resize((target_width, target_height))
        processed_images = []
        for i in range(blocks):
            # Row-major crop box of the i-th tile.
            box = (
                (i % (target_width // self._tile_size)) * self._tile_size,
                (i // (target_width // self._tile_size)) * self._tile_size,
                ((i % (target_width // self._tile_size)) + 1) * self._tile_size,
                ((i // (target_width // self._tile_size)) + 1) * self._tile_size,
            )
            # split the image
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
        assert len(processed_images) == blocks
        if self._use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((self._tile_size, self._tile_size))
            processed_images.append(thumbnail_img)
        return [self._transform(img) for img in processed_images]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: Optional[int] = None,
        max_num_tiles: int | None = None,
        data_augment: bool = False,
        tiling_augment_prob: float = 0.4,
        **kwargs,
    ) -> list[ImageTilingParamsV1]:
        """Choose a tile grid for each media item within the token budget.

        Args:
            media_list: Media items to tile.
            num_tokens_available: Token budget shared by all media; None means
                no budget constraint (the full tile budget is used).
            max_num_tiles: Optional per-call override of the tile budget.
            data_augment: Randomly perturb the chosen tiling for images.
            tiling_augment_prob: Probability of perturbing each image's tiling.
        """
        # Use provided max_num_tiles or fall back to instance's max_num_tiles.
        # Clamp to self._max_num_tiles since target_ratios are only
        # pre-computed up to that value.
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)
        if num_tokens_available is None:
            # No token budget given: allow the full tile budget. (Previously
            # this raised TypeError via `None // int` for the default call.)
            max_num_tiles_to_use = effective_max_num_tiles
        else:
            max_num_tiles_to_use = min(
                num_tokens_available // self._embeddings_per_image, effective_max_num_tiles
            )
        # Clamp into the pre-computed key range to avoid a KeyError when the
        # budget allows fewer than min_num_tiles tiles.
        max_num_tiles_to_use = max(max_num_tiles_to_use, self._min_num_tiles)
        # calculate the existing image aspect ratio
        target_ratios = self.target_ratios[max_num_tiles_to_use]
        params = []
        for media in media_list:
            if isinstance(media, ImageMedia):
                img_size = (media.width, media.height)
            elif isinstance(media, VideoFrameMedia):
                img_size = (media.video_width, media.video_height)
            else:
                raise ValueError(f"Unsupported media type: {type(media)}")
            aspect_ratio = img_size[0] / img_size[1]
            # find the closest aspect ratio to the target
            tiling = self._find_closest_aspect_ratio_fn(
                aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
            )
            if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob:
                tiling = self.augment_tiling(tiling)
            num_tiles = tiling[0] * tiling[1]
            # The thumbnail adds one extra tile whenever the image is split.
            if self._use_thumbnail and num_tiles != 1:
                num_tiles += 1
            params.append(
                ImageTilingParamsV1(
                    media=media,
                    num_tiles=num_tiles,
                    num_embeddings=num_tiles * self._embeddings_per_image,
                    tiling=tiling,
                )
            )
        return params

    def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]:
        """Randomly grow or shrink the tile grid by one tile along one axis.

        Shrinks with probability 0.65 (never below 1x1); otherwise grows, but
        only to grids that stay within the max tile budget.
        """
        def num_tiles(tiling: tuple[int, int]) -> int:
            return tiling[0] * tiling[1]

        def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]:
            if random.random() < minus_prob:
                # Minus one
                if tiling[0] == 1 and tiling[1] == 1:
                    return tiling
                elif tiling[0] == 1:
                    return (tiling[0], tiling[1] - 1)
                elif tiling[1] == 1:
                    return (tiling[0] - 1, tiling[1])
                else:
                    if random.random() < 0.5:
                        return (tiling[0] - 1, tiling[1])
                    else:
                        return (tiling[0], tiling[1] - 1)
            else:
                # Plus one
                if num_tiles(tiling) < self._max_num_tiles:
                    tiling0 = (tiling[0] + 1, tiling[1])
                    tiling1 = (tiling[0], tiling[1] + 1)
                    if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles:
                        return tiling
                    elif num_tiles(tiling0) > self._max_num_tiles:
                        return tiling1
                    elif num_tiles(tiling1) > self._max_num_tiles:
                        return tiling0
                    else:
                        if random.random() < 0.5:
                            return tiling0
                        else:
                            return tiling1
                return tiling

        new_tiling = plus_minus_one(tiling)
        return new_tiling

    def __str__(self):
        return f"TilingImageTransform(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, embeddings_per_tile={self._embeddings_per_image}, find_closest_aspect_ratio_fn={self._find_closest_aspect_ratio_fn})"
class TileDegradationStrategy(ImageTilingStrategy):
    """Strategy for tiling images and video frames, each with their own tiling strategy, while trying to match the
    number of tokens left in the sample by reducing the number of tiles if needed.
    """

    # Default degradation schedule: each key is a tile budget and its value is
    # the next (smaller) budget to try when the current one exceeds the token
    # budget.
    _DEFAULT_TILE_DEGRADATION_MAP = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}

    # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
    # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
    def __init__(
        self,
        image_strategy: ImageTilingStrategy,
        video_frame_strategy: ImageTilingStrategy,
        embeddings_per_tile: int,
        max_num_tiles: int,
        tile_degradation_map: dict[int, int] | None = None,
    ):
        """
        Args:
            image_strategy: Tiling strategy used for ImageMedia.
            video_frame_strategy: Tiling strategy used for VideoFrameMedia.
            embeddings_per_tile: Number of vision embeddings per tile.
            max_num_tiles: Initial per-image tile budget.
            tile_degradation_map: Budget-reduction schedule; None selects the
                default. (Was a mutable default argument, which shares one
                dict object across all instances.)
        """
        self._image_strategy = image_strategy
        self._video_frame_strategy = video_frame_strategy
        self._embeddings_per_tile = embeddings_per_tile
        self._max_num_tiles = max_num_tiles
        # Copy defensively so later caller-side mutation cannot leak in.
        self._tile_degradation_map = dict(
            tile_degradation_map
            if tile_degradation_map is not None
            else self._DEFAULT_TILE_DEGRADATION_MAP
        )

    def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
        """Dispatch to the image or video-frame strategy based on the media type."""
        if isinstance(transform_media.media, ImageMedia):
            return self._image_strategy.apply_params(transform_media, **kwargs)
        elif isinstance(transform_media.media, VideoFrameMedia):
            return self._video_frame_strategy.apply_params(transform_media, **kwargs)
        else:
            raise ValueError(f"Unsupported media type: {type(transform_media.media)}")

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        **kwargs,
    ) -> list[ImageTilingParams]:
        """Compute tiling params, repeatedly degrading the image tile budget
        until the total embedding count fits within num_tokens_available or the
        degradation schedule is exhausted."""
        # Use provided max_num_tiles or fall back to instance's max_num_tiles
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        max_num_tiles_to_use = effective_max_num_tiles
        degradation_map = self._tile_degradation_map
        while True:
            params = []
            img_num_tiles = []
            for media in media_list:
                if isinstance(media, ImageMedia):
                    media_params = self._image_strategy.compute_params(
                        [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs
                    )[0]
                elif isinstance(media, VideoFrameMedia):
                    # Video frames always get a single tile. Use a literal
                    # budget of 1 so this no longer overwrites
                    # max_num_tiles_to_use and silently forces all subsequent
                    # images in the same list down to one tile.
                    media_params = self._video_frame_strategy.compute_params(
                        [media], self._embeddings_per_tile, 1, **kwargs
                    )[0]
                else:
                    raise ValueError(f"Unsupported media type: {type(media)}")
                img_num_tiles.append(media_params.num_tiles)
                params.append(media_params)
            if max_num_tiles_to_use == 1 or num_tokens_available is None:
                break
            if sum(img_num_tiles) * self._embeddings_per_tile > num_tokens_available:
                if max_num_tiles_to_use in degradation_map:
                    max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
                else:
                    # End of degradation
                    break
            else:
                break
        return params

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
        """Delegate stacking to the image strategy."""
        return self._image_strategy.stack(images)

    def __str__(self):
        return f"TileDegradationImageTransform(max_num_tiles={self._max_num_tiles}, image_transform={self._image_strategy}, video_frame_transform={self._video_frame_strategy})"
@dataclass
class DynamicResolutionParams(ImageTilingParams):
    """Tiling parameters extended with the target patch grid."""

    # (patches_wide, patches_high) grid the image will be resized to; each
    # patch is `patch_size` pixels on a side.
    patch_size: tuple[int, int]
class DynamicResolutionImageTilingStrategy(ImageTilingStrategy):
"""Preprocess an image with dynamic resolution for vision transformers.
This function resizes an image to optimize the number of patches while respecting
constraints on minimum/maximum patches, minimum side length, and compatibility
with pixel shuffle or convolution merging operations.
The algorithm works by:
1. Computing the initial patch grid size based on the image dimensions and res_step
2. Scaling the patch grid to fit within the max_patches constraint
3. Ensuring the result has at least min_patches
4. Optionally enforcing a minimum side length constraint
5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility
6. Resizing the image to the computed target dimensions
Note:
The function preserves aspect ratio as much as possible while satisfying all constraints.
When constraints conflict (e.g., min_side vs max_patches), the function prioritizes
staying within max_patches while maximizing the image size.
Example:
>>> from PIL import Image
>>> img = Image.open("example.jpg") # 800x600 image
>>> strategy = DynamicResolutionImageTilingStrategy(vision_model_type="radio", min_patches=4, max_patches=64, res_step=14, get_num_embeddings=lambda x, y: x * y * 2)
>>> params = strategy.compute_params([img])
>>> img_tensor = strategy.apply_params(params[0])
>>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14
"""
def __init__(
    self,
    vision_model_type: str,
    min_num_patches: int,
    patch_size: int,
    get_num_embeddings: Callable[[int, int], int],
    factor_max: float = 1.0,
    pixel_shuffle: bool = False,
    min_side: int | None = None,
    conv_merging: bool = False,
    use_thumbnail: bool = False,
    thumbnail_size: int = 448,
    thumbnail_area_threshold: float = 0.8,
    max_num_patches: int = 0,
    apply_data_augment: bool = False,
):
    """
    Args:
        vision_model_type: Vision model type. Must contain "radio".
        min_num_patches: Minimum number of patches required.
        max_num_patches: Maximum number of patches allowed. Defaults to 0 (no maximum).
        patch_size: Resolution step size (patch dimension in pixels).
        get_num_embeddings: Function to get the number of embeddings from the image pixel size (width, height).
        factor_max: Maximum scaling factor to apply. Defaults to 1.0.
        pixel_shuffle: Whether to ensure compatibility with pixel shuffle operations by rounding to even patch
            dimensions. Defaults to False.
        min_side: Minimum side length in pixels. If specified, ensures at least one side meets this constraint.
            Defaults to None.
        conv_merging: Whether to ensure compatibility with convolution merging by rounding to even patch dimensions.
            Defaults to False.
        use_thumbnail: Whether to add a thumbnail image when processing. Defaults to False.
        thumbnail_size: Size of the thumbnail image (width and height). Defaults to 448.
        thumbnail_area_threshold: Maximum area percentage (0.0-1.0) of the resized image relative to thumbnail area
            for which to add a thumbnail. If the resized image area is larger than this threshold of the thumbnail
            area, no thumbnail will be added. Defaults to 0.8 (80%).
        apply_data_augment: Whether to apply data augmentation to the image. Defaults to False.

    Raises:
        AssertionError: if vision_model_type is not a radio variant.
    """
    assert "radio" in vision_model_type, (
        "Dynamic resolution is only supported for radio models"
    )
    self._vision_model_type = vision_model_type
    self._min_num_patches = min_num_patches
    # 0 means "no maximum"; normalized to +inf so min()/comparisons just work.
    self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
    self._patch_size = patch_size
    self._get_num_embeddings = get_num_embeddings
    self._factor_max = factor_max
    self._pixel_shuffle = pixel_shuffle
    self._min_side = min_side
    self._conv_merging = conv_merging
    self._use_thumbnail = use_thumbnail
    self._thumbnail_size = thumbnail_size
    self._thumbnail_area_threshold = thumbnail_area_threshold
    pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
    # No Resize step here: each media item is resized individually in
    # apply_params to its computed dynamic resolution.
    self._transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)),
            T.Normalize(mean=pixel_mean, std=pixel_std),
        ]
    )
    self._apply_data_augment = apply_data_augment
def apply_params(self, params: DynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
    """Resize the media to its target patch grid and normalize it.

    When thumbnails are enabled and the resized image covers less than the
    configured fraction of the thumbnail area, a square thumbnail of the
    full image is appended as a second tensor.
    """
    source = params.media.value
    grid_w, grid_h = params.patch_size
    # Target pixel dimensions are the patch grid scaled by the patch size.
    main_image = source.resize((grid_w * self._patch_size, grid_h * self._patch_size))
    outputs = [main_image]
    if self._use_thumbnail:
        main_area = main_image.size[0] * main_image.size[1]
        thumb_area = self._thumbnail_size * self._thumbnail_size
        # Only add thumbnail if resized image area is less than threshold % of thumbnail area
        if main_area / thumb_area < self._thumbnail_area_threshold:
            outputs.append(source.resize((self._thumbnail_size, self._thumbnail_size)))
    return [self._transform(image) for image in outputs]
def process_media(
    self,
    media: ImageMedia | VideoFrameMedia,
    num_tokens_available: int,
    data_augment: bool = False,
    tiling_augment_prob: float = 0.4,
) -> tuple[DynamicResolutionParams, int]:
    """Process a single media item and return its parameters.

    Args:
        media: The media item to process
        num_tokens_available: Number of patch tokens available for this media
        data_augment: Whether to apply data augmentation to the image. Defaults to False.
        tiling_augment_prob: Probability of randomly perturbing the resolution.

    Returns:
        Tuple of (DynamicResolutionParams for the media, number of patch
        tokens the chosen resolution consumes — thumbnail included).
    """
    current_num_tokens_available = num_tokens_available
    if isinstance(media, ImageMedia):
        orig_width, orig_height = media.width, media.height
    elif isinstance(media, VideoFrameMedia):
        orig_width, orig_height = media.video_width, media.video_height
        # current_num_tokens_available = 1024 #TEMP: hack for video
    else:
        raise ValueError(f"Unsupported media type: {type(media)}")
    # round(x + 0.5) acts as a ceiling on the patch-grid dimensions.
    closest_patch_height = round(orig_height / self._patch_size + 0.5)
    closest_patch_width = round(orig_width / self._patch_size + 0.5)
    patches = closest_patch_height * closest_patch_width
    # Scale the grid to fit the token budget, never upscaling past factor_max.
    factor = min(math.sqrt(current_num_tokens_available / patches), self._factor_max)
    target_patch_height = math.floor(factor * closest_patch_height)
    target_patch_width = math.floor(factor * closest_patch_width)
    # We only consider self._min_num_patches if it is greater than current_num_tokens_available.
    if current_num_tokens_available > self._min_num_patches and target_patch_height * target_patch_width < self._min_num_patches:
        up_factor = math.sqrt(
            self._min_num_patches / (target_patch_height * target_patch_width)
        )
        target_patch_height = math.ceil(up_factor * target_patch_height)
        target_patch_width = math.ceil(up_factor * target_patch_width)
    # Optionally enforce a minimum pixel length on the shorter side.
    if (
        self._min_side is not None
        and min(target_patch_width, target_patch_height) * self._patch_size
        < self._min_side
    ):
        if target_patch_width <= target_patch_height:
            # Width is the short side: scale both sides up so width reaches min_side.
            up_factor = self._min_side / (target_patch_width * self._patch_size)
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)
            if new_patch_height * new_patch_width > current_num_tokens_available:
                # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
                if (
                    max(current_num_tokens_available // new_patch_width, 1)
                    * self._patch_size
                    < self._min_side
                ):
                    up_factor = math.sqrt(
                        current_num_tokens_available
                        / (target_patch_height * target_patch_width)
                    )
                    target_patch_height = math.floor(
                        up_factor * target_patch_height
                    )
                    target_patch_width = math.floor(
                        up_factor * target_patch_width
                    )
                # NOTE(review): unlike the height branch below, these two
                # assignments also run after the sqrt fallback above and
                # overwrite it — confirm this asymmetry is intentional.
                target_patch_width = new_patch_width
                target_patch_height = max(
                    current_num_tokens_available // new_patch_width, 1
                )
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width
        else:
            # Height is the short side: scale both sides up so height reaches min_side.
            up_factor = self._min_side / (
                target_patch_height * self._patch_size
            )
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)
            if new_patch_height * new_patch_width > current_num_tokens_available:
                # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
                if (
                    max(current_num_tokens_available // new_patch_height, 1)
                    * self._patch_size
                    < self._min_side
                ):
                    up_factor = math.sqrt(
                        current_num_tokens_available
                        / (target_patch_height * target_patch_width)
                    )
                    target_patch_height = math.floor(
                        up_factor * target_patch_height
                    )
                    target_patch_width = math.floor(
                        up_factor * target_patch_width
                    )
                else:
                    target_patch_height = new_patch_height
                    target_patch_width = max(
                        current_num_tokens_available // new_patch_height, 1
                    )
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width
    # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
    # or by 4 when BOTH are enabled (two successive 2x reductions)
    if self._pixel_shuffle or self._conv_merging:
        required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
        rem_h = target_patch_height % required_divisor
        if rem_h != 0:
            inc_h = required_divisor - rem_h
            # Prefer rounding up when it still fits the budget, else round down.
            if (target_patch_height + inc_h) * target_patch_width <= current_num_tokens_available:
                target_patch_height += inc_h
            else:
                target_patch_height = max(required_divisor, target_patch_height - rem_h)
        rem_w = target_patch_width % required_divisor
        if rem_w != 0:
            inc_w = required_divisor - rem_w
            if target_patch_height * (target_patch_width + inc_w) <= current_num_tokens_available:
                target_patch_width += inc_w
            else:
                target_patch_width = max(required_divisor, target_patch_width - rem_w)
    if data_augment and self._apply_data_augment and random.random() < tiling_augment_prob:
        target_patch_width, target_patch_height = self.augment_resolution(target_patch_width, target_patch_height, current_num_tokens_available)
    #TEMP: hack for video
    if isinstance(media, VideoFrameMedia):
        target_patch_width = 32
        target_patch_height = 32
    # Calculate embeddings for the main dynamic resolution image
    num_embeddings = self._get_num_embeddings(
        target_patch_width * self._patch_size,
        target_patch_height * self._patch_size,
    )
    token_count = target_patch_width * target_patch_height
    # Add thumbnail embeddings if enabled and image area is below threshold
    num_tiles = 1  # Base dynamic resolution image
    if self._use_thumbnail:
        # Calculate areas
        resized_area = (target_patch_width * self._patch_size) * (target_patch_height * self._patch_size)
        thumbnail_area = self._thumbnail_size * self._thumbnail_size
        area_ratio = resized_area / thumbnail_area
        # Only add thumbnail if resized image area is less than threshold % of thumbnail area
        if area_ratio < self._thumbnail_area_threshold:
            num_tiles += 1  # Add 1 for thumbnail
            # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
            num_embeddings += self._get_num_embeddings(self._thumbnail_size, self._thumbnail_size)
            token_count += self._thumbnail_size // self._patch_size * self._thumbnail_size // self._patch_size
    return DynamicResolutionParams(
        media=media,
        num_tiles=num_tiles,
        num_embeddings=num_embeddings,
        patch_size=(target_patch_width, target_patch_height),
    ), token_count
def augment_resolution(self, target_patch_width: int, target_patch_height: int, current_num_tokens_available: int) -> tuple[int, int]:
    """Randomly grow or shrink one side of the patch grid by a fixed step.

    With probability 0.5 one side is reduced by 32 patches (never taking a
    side at or below the 32-patch floor lower); otherwise one side is grown
    by 32 patches when the current grid is still under the token budget.
    """
    step = 32  # floor for each side, and the augmentation step size
    w, h = target_patch_width, target_patch_height
    if random.random() < 0.5:
        # Shrink one side, respecting the per-side floor.
        if w <= step and h <= step:
            return w, h
        if w <= step:
            return w, h - step
        if h <= step:
            return w - step, h
        return (w - step, h) if random.random() < 0.5 else (w, h - step)
    # Grow one side only while the grid is under the token budget.
    if w * h < current_num_tokens_available:
        return (w + step, h) if random.random() < 0.5 else (w, h + step)
    return w, h
def compute_params(
self,
media_list: list[ImageMedia | VideoFrameMedia],
num_tokens_available: int | None = None,
max_num_tiles: int | None = None,
data_augment: bool = False,
**kwargs,
) -> list[ImageTilingParams]:
"""Compute parameters for all media with iterative token budgeting.
Args:
media_list: List of media items to process
num_tokens_available: Total number of tokens available across all media
max_num_tiles: Maximum number of tiles (unused in this implementation)
data_augment: Whether to apply data augmentation to the image. Defaults to False.
Returns:
List of ImageTilingParams for each media item
"""
num_tokens_available = num_tokens_available * (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1)
# When the number of available token is too small, allow self._min_num_patches per media and
# let the sample be truncated.
num_tokens_available = max(num_tokens_available, self._min_num_patches * len(media_list))
# Clip the number of tokens available per media to be between min and max patches.
num_tokens_available_per_media = [
max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
for _ in range(len(media_list))]
# In theory this could be a while True loop, but in case the process_media method slightly
# changes, I want to make sure we don't get stuck in an infinite loop.
for _ in range(10):
# Step 1: Process each media with current token budget
params = []
token_counts = []
for media, tokens_for_media in zip(media_list, num_tokens_available_per_media):
param, token_count = self.process_media(media, tokens_for_media, data_augment=data_augment)
params.append(param)
token_counts.append(token_count)
# Step 2: Check if total tokens is within budget
total_tokens = sum(token_counts)
if total_tokens <= num_tokens_available:
# We're within budget, return the params
return params
# Step 3: We're over budget, need to scale down
# Calculate scaling factor to get under budget
scaling_factor = num_tokens_available / total_tokens
# Recalculate token budgets for each media based on scaling
# Each media gets a proportional share of the total budget
scaled_down_num_tokens_available_per_media = [
max(self._min_num_patches, int(token_count * scaling_factor))
for token_count in token_counts
]
scaled_down = any([
scaled_down_num_tokens_available_per_media[i] < num_tokens_available_per_media[i]
for i in range(len(num_tokens_available_per_media))])
# If there was not scaling down, we're stuck just use min_num_patches per media, else
# try with the scaled down num_tokens_available_per_media.
if not scaled_down:
num_tokens_available_per_media = [self._min_num_patches] * len(media_list)
else:
num_tokens_available_per_media = scaled_down_num_tokens_available_per_media
return params
def stack(
self, images: list[torch.Tensor]
) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
imgs_sizes = torch.tensor(
[[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
)
def rearrange_img(x):
py = x.shape[-2] // self._patch_size
px = x.shape[-1] // self._patch_size
x = einops.rearrange(
x,
"c (py yy) (px xx) -> (py px) (c yy xx)",
py=py,
yy=self._patch_size,
px=px,
xx=self._patch_size,
)
return x
if len(images) > 0:
imgs = [rearrange_img(img) for img in images]
current_length = 0
max_length = 0
vision_cu_lengths = [0]
for img in imgs:
if max_length < img.shape[0]:
max_length = img.shape[0]
current_length += img.shape[0]
vision_cu_lengths.append(current_length)
vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
return (
torch.cat(imgs, dim=0).unsqueeze(0),
imgs_sizes,
vision_cu_lengths,
vision_max_lengths,
)
else:
return (
torch.tensor([[0]], dtype=torch.float32),
torch.tensor([[0,0]], dtype=torch.int32),
None,
None,
)
def __str__(self):
return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})"
@dataclass
class MatchTilingDynamicResolutionParams(ImageTilingParams):
    """Per-media parameters produced by MatchTilingDynamicResolutionStrategy."""

    # (tiles_x, tiles_y) grid chosen by the closest-aspect-ratio search.
    tiling: tuple[int, int]
class MatchTilingDynamicResolutionStrategy(ImageTilingStrategy):
    """
    Strategy that uses tiling logic to determine optimal image dimensions but processes
    the image as a single dynamic resolution image instead of splitting into tiles.

    This combines the aspect ratio optimization from ImageTilingStrategyV1 with the
    dynamic resolution processing from DynamicResolutionImageTilingStrategy.
    Also includes tile degradation logic similar to TileDegradationStrategy.
    """

    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        patch_size: int,
        get_num_embeddings: Callable[[int, int], int],
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
        pixel_shuffle: bool = False,
        conv_merging: bool = False,
        tile_degradation_map: dict[int, int] | None = None,
        video_frame_strategy: ImageTilingStrategy | None = None,
        enable_tile_degradation: bool = True,
    ):
        """
        Args:
            vision_model_type: Vision model type (should support dynamic resolution)
            tile_size: Size of each tile for tiling calculation
            use_thumbnail: Whether tiling logic should include thumbnail
            min_num_tiles: Minimum number of tiles for tiling calculation
            max_num_tiles: Maximum number of tiles for tiling calculation
            embeddings_per_tile: Embeddings per tile for tiling calculation
            patch_size: Patch size for dynamic resolution processing
            get_num_embeddings: Function to get number of embeddings from dimensions
            find_closest_aspect_ratio_fn: Function to find closest aspect ratio
            pixel_shuffle: Whether to ensure compatibility with pixel shuffle
            conv_merging: Whether to ensure compatibility with convolution merging
            tile_degradation_map: Map for degrading tiles when tokens are insufficient
            video_frame_strategy: Strategy for processing video frames
            enable_tile_degradation: Whether to enable tile degradation (default: True)
        """
        assert "radio" in vision_model_type, (
            "MatchTilingDynamicResolution is only supported for radio models"
        )
        self._vision_model_type = vision_model_type
        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._embeddings_per_tile = embeddings_per_tile
        self._patch_size = patch_size
        self._get_num_embeddings = get_num_embeddings
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
        self._pixel_shuffle = pixel_shuffle
        self._conv_merging = conv_merging
        self._enable_tile_degradation = enable_tile_degradation
        # Tile degradation schedule (similar to TileDegradationStrategy).
        if tile_degradation_map is None:
            self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
        else:
            self._tile_degradation_map = tile_degradation_map
        # Video frames are never tiled; by default resize each frame to one tile.
        if video_frame_strategy is None:
            self._video_frame_strategy = NoTilingStrategy(
                vision_model_type=vision_model_type,
                target_width=tile_size,
                target_height=tile_size,
                embeddings_per_image=embeddings_per_tile,
            )
        else:
            self._video_frame_strategy = video_frame_strategy
        # Pre-compute candidate aspect ratios for every possible tile budget
        # (borrowed from ImageTilingStrategyV1).
        self.target_ratios = {
            n: self._build_target_ratios(n)
            for n in range(self._min_num_tiles, self._max_num_tiles + 1)
        }
        # Normalization transform for dynamic resolution processing.
        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ]
        )

    def _build_target_ratios(self, max_num_tiles: int) -> list[tuple[int, int]]:
        """Enumerate (x, y) tilings with min_num_tiles <= x * y <= max_num_tiles, sorted by tile count."""
        return sorted(
            set(
                (x, y)
                for n in range(self._min_num_tiles, max_num_tiles + 1)
                for x in range(1, n + 1)
                for y in range(1, n + 1)
                if x * y <= max_num_tiles and x * y >= self._min_num_tiles
            ),
            key=lambda ratio: ratio[0] * ratio[1],
        )

    def apply_params(self, params: MatchTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
        """Resize the media per its computed tiling and return normalized tensors."""
        # Handle video frames using the video frame strategy.
        if isinstance(params.media, VideoFrameMedia):
            return self._video_frame_strategy.apply_params(params, **kwargs)
        # Handle images with dynamic resolution processing.
        image = params.media.value
        # Calculate the target width and height (same logic as ImageTilingStrategyV1).
        target_width = self._tile_size * params.tiling[0]
        target_height = self._tile_size * params.tiling[1]
        # Resize the image to the target dimensions (same as ImageTilingStrategyV1);
        # unlike V1, the resized image is kept whole instead of being cut into tiles.
        resized_img = image.resize((target_width, target_height))
        processed_images = [resized_img]
        # Add thumbnail if use_thumbnail=True and there's more than 1 tile
        # (same as ImageTilingStrategyV1).
        blocks = params.tiling[0] * params.tiling[1]
        if self._use_thumbnail and blocks != 1:
            thumbnail_img = image.resize((self._tile_size, self._tile_size))
            processed_images.append(thumbnail_img)
        return [self._transform(img) for img in processed_images]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        **kwargs,
    ) -> list[MatchTilingDynamicResolutionParams]:
        """Choose a tiling per media, degrading the tile budget until embeddings fit.

        Args:
            media_list: Media items to compute tiling parameters for.
            num_tokens_available: Embedding budget; None disables the budget check.
            max_num_tiles: Optional per-call tile cap (clamped to the instance cap,
                since target ratios are only pre-computed up to that value).

        Returns:
            One MatchTilingDynamicResolutionParams per media item.
        """
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)
        max_num_tiles_to_use = effective_max_num_tiles
        degradation_map = self._tile_degradation_map
        while True:
            params = []
            total_embeddings_needed = 0
            for media in media_list:
                if isinstance(media, ImageMedia):
                    # Use tiling logic for images.
                    img_size = (media.width, media.height)
                    aspect_ratio = img_size[0] / img_size[1]
                    # Find the closest aspect ratio to the target.
                    target_ratios = self.target_ratios[max_num_tiles_to_use]
                    tiling = self._find_closest_aspect_ratio_fn(
                        aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
                    )
                    # Calculate target dimensions for dynamic resolution processing.
                    target_width = self._tile_size * tiling[0]
                    target_height = self._tile_size * tiling[1]
                    num_embeddings = self._get_num_embeddings(target_width, target_height)
                    # Account for thumbnail (same logic as ImageTilingStrategyV1).
                    num_tiles = 1  # Base dynamic resolution image
                    blocks = tiling[0] * tiling[1]
                    if self._use_thumbnail and blocks != 1:
                        num_tiles += 1  # Add 1 for thumbnail
                        # Add embeddings for thumbnail (tile_size x tile_size).
                        num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size)
                    media_params = MatchTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=num_tiles,
                        num_embeddings=num_embeddings,
                        tiling=tiling,
                    )
                elif isinstance(media, VideoFrameMedia):
                    # Use video frame strategy for video frames (always 1 tile).
                    video_params = self._video_frame_strategy.compute_params(
                        [media], 1 * self._embeddings_per_tile
                    )[0]
                    media_params = MatchTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=video_params.num_tiles,
                        num_embeddings=video_params.num_embeddings,
                        tiling=(1, 1),  # Video frames always use 1x1 tiling
                    )
                else:
                    raise ValueError(f"Unsupported media type: {type(media)}")
                params.append(media_params)
                total_embeddings_needed += media_params.num_embeddings
            # Check if we need to degrade (only if degradation is enabled).
            if not self._enable_tile_degradation:
                break
            if max_num_tiles_to_use == 1 or num_tokens_available is None:
                break
            if total_embeddings_needed > num_tokens_available:
                if max_num_tiles_to_use in degradation_map:
                    max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
                    # Lazily compute target ratios for tile budgets outside the
                    # pre-computed range (e.g. degradation targets below min).
                    if max_num_tiles_to_use not in self.target_ratios:
                        self.target_ratios[max_num_tiles_to_use] = self._build_target_ratios(
                            max_num_tiles_to_use
                        )
                else:
                    # End of degradation schedule: accept the current params.
                    break
            else:
                break
        return params

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
        """Stack images using dynamic resolution approach with sequence packing."""
        imgs_sizes = torch.tensor(
            [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
        )

        def rearrange_img(x):
            # Flatten each patch_size x patch_size patch (with channels) into one token.
            py = x.shape[-2] // self._patch_size
            px = x.shape[-1] // self._patch_size
            x = einops.rearrange(
                x,
                "c (py yy) (px xx) -> (py px) (c yy xx)",
                py=py,
                yy=self._patch_size,
                px=px,
                xx=self._patch_size,
            )
            return x

        if len(images) > 0:
            imgs = [rearrange_img(img) for img in images]
            current_length = 0
            max_length = 0
            # Cumulative token offsets delimiting each image in the packed sequence.
            vision_cu_lengths = [0]
            for img in imgs:
                if max_length < img.shape[0]:
                    max_length = img.shape[0]
                current_length += img.shape[0]
                vision_cu_lengths.append(current_length)
            vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
            vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
            return (
                torch.cat(imgs, dim=0).unsqueeze(0),
                imgs_sizes,
                vision_cu_lengths,
                vision_max_lengths,
            )
        else:
            # Placeholder outputs for the no-image case.
            return (
                torch.tensor([[0]], dtype=torch.float32),
                torch.tensor([[0, 0]], dtype=torch.int32),
                None,
                None,
            )

    def __str__(self):
        return f"MatchTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})"
@dataclass
class MaskedTilingDynamicResolutionParams(ImageTilingParams):
    """Per-media parameters produced by MaskedTilingDynamicResolutionStrategy."""

    # (tiles_x, tiles_y) grid chosen by the closest-aspect-ratio search.
    tiling: tuple[int, int]
class MaskedTilingDynamicResolutionStrategy(ImageTilingStrategy):
    """
    Like MatchTilingDynamicResolutionStrategy, but ensures tiles are isolated in the
    vision encoder by emitting per-tile packed samples (block-diagonal attention across tiles).
    """

    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        patch_size: int,
        get_num_embeddings: Callable[[int, int], int],
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
        pixel_shuffle: bool = False,
        conv_merging: bool = False,
        tile_degradation_map: dict[int, int] | None = None,
        video_frame_strategy: ImageTilingStrategy | None = None,
        enable_tile_degradation: bool = True,
    ):
        """
        Args:
            vision_model_type: Vision model type (must be a radio variant)
            tile_size: Side length of each square tile in pixels
            use_thumbnail: Whether to append a thumbnail tile for multi-tile images
            min_num_tiles: Minimum number of tiles for the tiling search
            max_num_tiles: Maximum number of tiles for the tiling search
            embeddings_per_tile: Embeddings produced per tile
            patch_size: Patch size used when packing tiles into patch tokens
            get_num_embeddings: Function mapping (width, height) to embedding count
            find_closest_aspect_ratio_fn: Function to find the closest aspect ratio
            pixel_shuffle: Whether to ensure compatibility with pixel shuffle
            conv_merging: Whether to ensure compatibility with convolution merging
            tile_degradation_map: Map for degrading tiles when tokens are insufficient
            video_frame_strategy: Strategy for processing video frames
            enable_tile_degradation: Whether to enable tile degradation (default: True)
        """
        assert "radio" in vision_model_type, (
            "MaskedTilingDynamicResolution is only supported for radio models"
        )
        self._vision_model_type = vision_model_type
        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._embeddings_per_tile = embeddings_per_tile
        self._patch_size = patch_size
        self._get_num_embeddings = get_num_embeddings
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
        self._pixel_shuffle = pixel_shuffle
        self._conv_merging = conv_merging
        self._enable_tile_degradation = enable_tile_degradation
        # Tile degradation schedule (similar to TileDegradationStrategy).
        if tile_degradation_map is None:
            self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
        else:
            self._tile_degradation_map = tile_degradation_map
        # Video frames are never tiled; by default resize each frame to one tile.
        if video_frame_strategy is None:
            self._video_frame_strategy = NoTilingStrategy(
                vision_model_type=vision_model_type,
                target_width=tile_size,
                target_height=tile_size,
                embeddings_per_image=embeddings_per_tile,
            )
        else:
            self._video_frame_strategy = video_frame_strategy
        # Pre-compute candidate aspect ratios for every possible tile budget.
        self.target_ratios = {
            n: self._build_target_ratios(n)
            for n in range(self._min_num_tiles, self._max_num_tiles + 1)
        }
        # Normalization transform applied to every emitted tile.
        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ]
        )

    def _build_target_ratios(self, max_num_tiles: int) -> list[tuple[int, int]]:
        """Enumerate (x, y) tilings with min_num_tiles <= x * y <= max_num_tiles, sorted by tile count."""
        return sorted(
            set(
                (x, y)
                for n in range(self._min_num_tiles, max_num_tiles + 1)
                for x in range(1, n + 1)
                for y in range(1, n + 1)
                if x * y <= max_num_tiles and x * y >= self._min_num_tiles
            ),
            key=lambda ratio: ratio[0] * ratio[1],
        )

    def apply_params(self, params: MaskedTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
        """Resize per the computed tiling, crop into per-tile images, and normalize."""
        # Handle video frames using the video frame strategy.
        if isinstance(params.media, VideoFrameMedia):
            return self._video_frame_strategy.apply_params(params, **kwargs)
        image = params.media.value
        nx, ny = params.tiling
        target_width = self._tile_size * nx
        target_height = self._tile_size * ny
        resized_img = image.resize((target_width, target_height))
        processed_images = []
        # Emit per-tile images (each becomes an isolated packed sample later).
        for j in range(ny):
            for i in range(nx):
                box = (
                    i * self._tile_size,
                    j * self._tile_size,
                    (i + 1) * self._tile_size,
                    (j + 1) * self._tile_size,
                )
                tile_img = resized_img.crop(box)
                processed_images.append(tile_img)
        if self._use_thumbnail and (nx * ny) != 1:
            thumbnail_img = image.resize((self._tile_size, self._tile_size))
            processed_images.append(thumbnail_img)
        return [self._transform(img) for img in processed_images]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        data_augment: bool = False,
        tiling_augment_prob: float = 0.4,
        **kwargs,
    ) -> list[MaskedTilingDynamicResolutionParams]:
        """Choose a tiling per media, degrading the tile budget until embeddings fit.

        Args:
            media_list: Media items to compute tiling parameters for.
            num_tokens_available: Embedding budget; None disables the budget check.
            max_num_tiles: Optional per-call tile cap (clamped to the instance cap).
            data_augment: Whether to randomly jitter the chosen tiling.
            tiling_augment_prob: Probability of jittering a given image's tiling.

        Returns:
            One MaskedTilingDynamicResolutionParams per media item.
        """
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)
        max_num_tiles_to_use = effective_max_num_tiles
        degradation_map = self._tile_degradation_map
        while True:
            params = []
            total_embeddings_needed = 0
            for media in media_list:
                if isinstance(media, ImageMedia):
                    img_size = (media.width, media.height)
                    aspect_ratio = img_size[0] / img_size[1]
                    target_ratios = self.target_ratios[max_num_tiles_to_use]
                    tiling = self._find_closest_aspect_ratio_fn(
                        aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
                    )
                    # Apply tiling augmentation if enabled (we are already inside
                    # the ImageMedia branch, so no extra isinstance check needed).
                    if data_augment and random.random() < tiling_augment_prob:
                        tiling = self.augment_tiling(tiling)
                    blocks = tiling[0] * tiling[1]
                    # Each tile is tile_size x tile_size.
                    per_tile_emb = self._get_num_embeddings(self._tile_size, self._tile_size)
                    num_embeddings = blocks * per_tile_emb
                    num_tiles = blocks
                    if self._use_thumbnail and blocks != 1:
                        num_tiles += 1
                        # Thumbnail is also tile_size x tile_size.
                        num_embeddings += per_tile_emb
                    media_params = MaskedTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=num_tiles,
                        num_embeddings=num_embeddings,
                        tiling=tiling,
                    )
                elif isinstance(media, VideoFrameMedia):
                    video_params = self._video_frame_strategy.compute_params(
                        [media], 1 * self._embeddings_per_tile
                    )[0]
                    media_params = MaskedTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=video_params.num_tiles,
                        num_embeddings=video_params.num_embeddings,
                        tiling=(1, 1),
                    )
                else:
                    raise ValueError(f"Unsupported media type: {type(media)}")
                params.append(media_params)
                total_embeddings_needed += media_params.num_embeddings
            if not self._enable_tile_degradation:
                break
            if max_num_tiles_to_use == 1 or num_tokens_available is None:
                break
            if total_embeddings_needed > num_tokens_available:
                if max_num_tiles_to_use in degradation_map:
                    max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
                    # Lazily compute target ratios for tile budgets outside the
                    # pre-computed range (e.g. degradation targets below min).
                    if max_num_tiles_to_use not in self.target_ratios:
                        self.target_ratios[max_num_tiles_to_use] = self._build_target_ratios(
                            max_num_tiles_to_use
                        )
                else:
                    # End of degradation schedule: accept the current params.
                    break
            else:
                break
        return params

    def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]:
        """Randomly grow or shrink the tiling by one tile along one axis.

        Shrinking (probability 0.65) never reduces a side below 1; growing never
        exceeds self._max_num_tiles total tiles. Returns the tiling unchanged
        when no legal move exists.
        """
        def num_tiles(tiling: tuple[int, int]) -> int:
            return tiling[0] * tiling[1]

        def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]:
            if random.random() < minus_prob:
                # Minus one tile along a randomly chosen axis.
                if tiling[0] == 1 and tiling[1] == 1:
                    return tiling
                elif tiling[0] == 1:
                    return (tiling[0], tiling[1] - 1)
                elif tiling[1] == 1:
                    return (tiling[0] - 1, tiling[1])
                else:
                    if random.random() < 0.5:
                        return (tiling[0] - 1, tiling[1])
                    else:
                        return (tiling[0], tiling[1] - 1)
            else:
                # Plus one tile along whichever axis stays within the tile cap.
                if num_tiles(tiling) < self._max_num_tiles:
                    tiling0 = (tiling[0] + 1, tiling[1])
                    tiling1 = (tiling[0], tiling[1] + 1)
                    if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles:
                        return tiling
                    elif num_tiles(tiling0) > self._max_num_tiles:
                        return tiling1
                    elif num_tiles(tiling1) > self._max_num_tiles:
                        return tiling0
                    else:
                        if random.random() < 0.5:
                            return tiling0
                        else:
                            return tiling1
                return tiling

        new_tiling = plus_minus_one(tiling)
        return new_tiling

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
        # Identical to dynamic resolution packing; each tile is already an independent image sample
        imgs_sizes = torch.tensor(
            [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
        )

        def rearrange_img(x):
            # Flatten each patch_size x patch_size patch (with channels) into one token.
            py = x.shape[-2] // self._patch_size
            px = x.shape[-1] // self._patch_size
            x = einops.rearrange(
                x,
                "c (py yy) (px xx) -> (py px) (c yy xx)",
                py=py,
                yy=self._patch_size,
                px=px,
                xx=self._patch_size,
            )
            return x

        if len(images) > 0:
            imgs = [rearrange_img(img) for img in images]
            current_length = 0
            max_length = 0
            # Cumulative token offsets delimiting each tile in the packed sequence.
            vision_cu_lengths = [0]
            for img in imgs:
                if max_length < img.shape[0]:
                    max_length = img.shape[0]
                current_length += img.shape[0]
                vision_cu_lengths.append(current_length)
            vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
            vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
            return (
                torch.cat(imgs, dim=0).unsqueeze(0),
                imgs_sizes,
                vision_cu_lengths,
                vision_max_lengths,
            )
        else:
            # Placeholder outputs for the no-image case.
            return (
                torch.tensor([[0]], dtype=torch.float32),
                torch.tensor([[0, 0]], dtype=torch.int32),
                None,
                None,
            )

    def __str__(self):
        return f"MaskedTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})"
def create_image_tiling_strategy(args):
    """
    Create an image tiling strategy based on the provided arguments.

    This function encapsulates the logic for creating the appropriate image tiling strategy
    based on the training/evaluation configuration. It can be used by both training (task_encoder)
    and evaluation code outside of data_loading/.

    Args:
        args: Arguments object with the following relevant attributes:
            - img_h, img_w: Image height and width
            - patch_dim: Patch dimension
            - vision_model_type: Vision model type (e.g., 'radio', 'clip', 'siglip')
            - disable_vision_class_token: Whether to disable vision class token
            - pixel_shuffle: Whether to use pixel shuffle
            - use_tile_tags: Whether to use tile tags
            - max_num_tiles: Maximum number of tiles
            - tokenizer_prompt_format: Tokenizer prompt format
            - image_break_token: Image break token (optional)
            - conv_merging: Whether to use convolution merging
            - dynamic_resolution: Whether to use dynamic resolution
            - match_tiling_dynamic_resolution: Whether to match tiling with dynamic resolution
            - masked_tiling_dynamic_resolution: Whether to isolate tiles in the vision
              encoder (optional; read via getattr for backward compatibility)
            - use_area_weighted_aspect_ratio: Whether to use area-weighted aspect ratio
            - use_thumbnail: Whether to use thumbnail
            - dynamic_resolution_min_patches: Minimum number of patches for dynamic resolution
            - dynamic_resolution_max_patches: Maximum number of patches for dynamic resolution
            - dynamic_resolution_min_side: Minimum side length for dynamic resolution (optional)
            - thumbnail_area_threshold: Thumbnail area threshold (optional)
            - apply_data_augment: Whether to apply data augmentation (dynamic resolution only)
            - use_tiling: Whether to use tiling

    Returns:
        ImageTilingStrategy: The created image tiling strategy
    """
    from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings

    assert args.img_h == args.img_w, "img_h and img_w must be the same"
    match_tiling_dynamic_resolution = args.match_tiling_dynamic_resolution
    masked_tiling_dynamic_resolution = getattr(args, "masked_tiling_dynamic_resolution", False)
    dynamic_resolution = args.dynamic_resolution
    use_tiling = args.use_tiling
    use_area_weighted_aspect_ratio = args.use_area_weighted_aspect_ratio
    if match_tiling_dynamic_resolution:
        assert dynamic_resolution, "must enable --dynamic-resolution if using --match-tiling-dynamic-resolution"
        assert not use_tiling, "cannot use --use-tiling and --match-tiling-dynamic-resolution together"
    if masked_tiling_dynamic_resolution:
        assert dynamic_resolution, "must enable --dynamic-resolution if using --masked-tiling-dynamic-resolution"
        assert not use_tiling, "cannot use --use-tiling and --masked-tiling-dynamic-resolution together"
        assert not match_tiling_dynamic_resolution, "cannot combine --masked-tiling-dynamic-resolution with --match-tiling-dynamic-resolution"

    def _get_num_embeddings(width: int, height: int) -> int:
        # Shared embedding-count helper; note the underlying API takes
        # (img_h, img_w) while strategies pass (width, height).
        return get_num_image_embeddings(
            img_h=height,
            img_w=width,
            patch_dim=args.patch_dim,
            vision_model_type=args.vision_model_type,
            disable_vision_class_token=args.disable_vision_class_token,
            class_token_len=1,
            pixel_shuffle=args.pixel_shuffle,
            use_tile_tags=args.use_tile_tags,
            max_num_tiles=args.max_num_tiles,
            tokenizer_type=args.tokenizer_prompt_format,
            use_image_break_token=args.image_break_token is not None,
            conv_merging=args.conv_merging,
        )

    aspect_ratio_fn = (
        find_closest_area_weighted_aspect_ratio
        if use_area_weighted_aspect_ratio
        else find_closest_aspect_ratio
    )
    if dynamic_resolution:
        if masked_tiling_dynamic_resolution or match_tiling_dynamic_resolution:
            # The two tiling-dynamic-resolution strategies take identical
            # configuration; they differ only in how tiles are attended
            # in the vision encoder.
            strategy_cls = (
                MaskedTilingDynamicResolutionStrategy
                if masked_tiling_dynamic_resolution
                else MatchTilingDynamicResolutionStrategy
            )
            image_tiling_strategy = strategy_cls(
                vision_model_type=args.vision_model_type,
                tile_size=args.img_h,
                use_thumbnail=args.use_thumbnail,
                min_num_tiles=1,
                max_num_tiles=args.max_num_tiles,
                embeddings_per_tile=_get_num_embeddings(args.img_w, args.img_h),
                patch_size=args.patch_dim,
                get_num_embeddings=_get_num_embeddings,
                find_closest_aspect_ratio_fn=aspect_ratio_fn,
                pixel_shuffle=args.pixel_shuffle,
                conv_merging=args.conv_merging,
            )
        else:
            image_tiling_strategy = DynamicResolutionImageTilingStrategy(
                vision_model_type=args.vision_model_type,
                min_num_patches=args.dynamic_resolution_min_patches,
                patch_size=args.patch_dim,
                get_num_embeddings=_get_num_embeddings,
                pixel_shuffle=args.pixel_shuffle,
                min_side=args.dynamic_resolution_min_side,
                conv_merging=args.conv_merging,
                use_thumbnail=args.use_thumbnail,
                thumbnail_size=args.img_h,
                thumbnail_area_threshold=args.thumbnail_area_threshold,
                max_num_patches=args.dynamic_resolution_max_patches,
                apply_data_augment=args.apply_data_augment,
            )
    else:
        num_image_embeddings_per_tile = _get_num_embeddings(args.img_w, args.img_h)
        if use_tiling:
            image_strategy = ImageTilingStrategyV1(
                vision_model_type=args.vision_model_type,
                tile_size=args.img_h,
                use_thumbnail=args.use_thumbnail,
                min_num_tiles=1,
                max_num_tiles=args.max_num_tiles,
                embeddings_per_tile=num_image_embeddings_per_tile,
                find_closest_aspect_ratio_fn=aspect_ratio_fn,
            )
        else:
            image_strategy = NoTilingStrategy(
                vision_model_type=args.vision_model_type,
                embeddings_per_image=num_image_embeddings_per_tile,
                target_width=args.img_w,
                target_height=args.img_h,
            )
        # Wrap with tile degradation so over-budget samples fall back gracefully.
        image_tiling_strategy = TileDegradationStrategy(
            image_strategy=image_strategy,
            video_frame_strategy=NoTilingStrategy(
                vision_model_type=args.vision_model_type,
                embeddings_per_image=num_image_embeddings_per_tile,
                target_width=args.img_w,
                target_height=args.img_h,
            ),
            embeddings_per_tile=num_image_embeddings_per_tile,
            max_num_tiles=args.max_num_tiles,
        )
    return image_tiling_strategy