diff --git a/image_processing.py b/image_processing.py deleted file mode 100644 index 5d43e764acadd..0000000000000 --- a/image_processing.py +++ /dev/null @@ -1,1828 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. -from abc import ABC, abstractmethod -from dataclasses import dataclass -import math -from typing import Callable, Optional -import numpy as np -import random -from PIL import Image -import albumentations as A - -import einops -import torch -from torchvision import transforms as T -from torchvision.transforms import Compose -from torchvision.transforms.functional import InterpolationMode - -from data_loading.conversation_sample import ( - ImageMedia, - VideoFrameMedia, -) - -IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] -IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] -SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] -SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] -CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] -CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] -RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060] -RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250] - - -pixel_statistics = { - "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), - "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), - "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD), - "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), - "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), -} - - -# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685 -# Copyright (c) 2023 OpenGVLab. -def find_closest_aspect_ratio( - aspect_ratio: float, - target_ratios: list[tuple[int, int]], - width: int, - height: int, - image_size: int, -) -> tuple[int, int]: - best_ratio_diff = float("inf") - best_ratio = (1, 1) - area = width * height - for ratio in target_ratios: - target_aspect_ratio = ratio[0] / ratio[1] - ratio_diff = abs(aspect_ratio - target_aspect_ratio) - if ratio_diff < best_ratio_diff: - best_ratio_diff = ratio_diff - best_ratio = ratio - elif ratio_diff == best_ratio_diff: - if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: - best_ratio = ratio - return best_ratio - - -def find_closest_area_weighted_aspect_ratio( - aspect_ratio: float, - target_ratios: list[tuple[int, int]], - width: int, - height: int, - image_size: int, -): - """ - Find the best number of tiles based on the aspect ratio and the area covered by the tiles. - """ - best_factor = float("-inf") - best_ratio = (1, 1) - area = width * height - for ratio in target_ratios: - target_aspect_ratio = ratio[0] / ratio[1] - factor_based_on_area_n_ratio = min( - (ratio[0] * ratio[1] * image_size * image_size) / area, 0.6 - ) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio) - if factor_based_on_area_n_ratio > best_factor: - best_factor = factor_based_on_area_n_ratio - best_ratio = ratio - return best_ratio - - -# Mike's optimized ToTensor. -def _fast_to_tensor(pic) -> torch.Tensor: - np_img = np.array(pic, copy=False) - img = torch.from_numpy(np_img) - img = img.permute(2, 0, 1) # HWC to CHW - fp_img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format) - fp_img.div_(255) - return fp_img - - -@dataclass -class ImageTilingParams: - media: ImageMedia | VideoFrameMedia - num_tiles: int - num_embeddings: int - - -class ImageTilingStrategy(ABC): - """ - Base class for image tiling strategies. - A tiling strategy is a function that takes a list of media and returns a list of image tiling parameters. - These can then be used to apply the tiling to the media. - - Subclasses must implement the `compute_params` and `apply_params` methods. - - The `transform` method is a convenience method that computes the transformation parameters and applies the transformation to the media. - - """ - - def transform( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - ) -> list[torch.Tensor]: - """ - Transform the media and compute the transformation parameters. - """ - transform_media_list = self.compute_params(media_list, num_tokens_available) - return [ - self.apply_params(transform_media, **kwargs) - for transform_media in transform_media_list - ] - - @abstractmethod - def compute_params( - self, media_list: list[ImageMedia | VideoFrameMedia], num_tokens_available: int, max_num_tiles: int | None = None, **kwargs - ) -> list[ImageTilingParams]: - """ - Compute the transformation parameters and the number of tokens to use for the media. - - Args: - media_list: List of media to transform - num_tokens_available: Number of tokens available for all media - max_num_tiles: Maximum number of tiles allowed (optional, defaults to instance's max_num_tiles if not provided) - - Returns: - list of transformation parameters with the media - """ - ... - - @abstractmethod - def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]: - """ - Apply the transformation parameters to the media. - - Args: - transform_media: The media to apply the transformation to - - Returns: - list of transformed media tensors - """ - ... - - @abstractmethod - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]: - """ - Stack the images into a single tensor. - - Args: - media_list: List of images to stack - - Returns: - tuple of (stacked media, image sizes, vision cu lengths, vision max lengths) - """ - ... - - -class _FixedSizeStrategy(ImageTilingStrategy): - """ - Base class for fixed size image tiling strategies. - """ - - def __init__( - self, - vision_model_type: str, - target_width: int, - target_height: int, - embeddings_per_image: int, - ): - self._vision_model_type = vision_model_type - self._target_width = target_width - self._target_height = target_height - self._embeddings_per_image = embeddings_per_image - self._transform = self._build_transform( - (target_width, target_height), vision_model_type - ) - - # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 - # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 - @staticmethod - def _build_transform(target_size: tuple[int, int], vision_model_type: str): - """ - Build a transform for a given vision model type and target size. - """ - if vision_model_type in ("siglip", "internvit", "radio", "radio-g", "cradio-g"): - pixel_mean, pixel_std = pixel_statistics[vision_model_type] - - transform = T.Compose( - [ - T.Lambda( - lambda img: img.convert("RGB") if img.mode != "RGB" else img - ), - T.Resize( - (target_size[1], target_size[0]), - interpolation=InterpolationMode.BICUBIC, - ), - T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - # From the official CLIP repo. - elif vision_model_type == "clip": - pixel_mean, pixel_std = pixel_statistics[vision_model_type] - - transform = Compose( - [ - T.Resize( - (target_size[1], target_size[0]), - interpolation=InterpolationMode.BICUBIC, - ), - T.Lambda( - lambda img: img.convert("RGB") if img.mode != "RGB" else img - ), - T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - elif vision_model_type.startswith("hf://"): - from megatron.core.models.huggingface.module import get_hf_model_type - - model_type = get_hf_model_type(vision_model_type) - if "siglip" in model_type: - from transformers.models.siglip.image_processing_siglip import ( - SiglipImageProcessor, - ) - - processor = SiglipImageProcessor( - size={"height": target_size[1], "width": target_size[0]} - ) - - def transform(x): - x = x.convert("RGB") if x.mode != "RGB" else x - x = processor(x, return_tensors="pt") - return x["pixel_values"][0] - else: - raise NotImplementedError( - f"image processing not defined for huggingface model {vision_model_type}" - ) - else: - raise NotImplementedError( - f"image processing not defined for vision model {vision_model_type}" - ) - - return transform - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: - return ( - torch.stack(images) if len(images) > 0 else None, - torch.tensor( - [(img.shape[1], img.shape[2]) for img in images], dtype=torch.int32 - ) if len(images) > 0 else None, - None, - None, - ) - - -class NoTilingStrategy(_FixedSizeStrategy): - """ - A simple image transformation that resizes the image to the target width and height. - """ - - def __init__( - self, - vision_model_type: str, - target_width: int, - target_height: int, - embeddings_per_image: int, - ): - super().__init__( - vision_model_type=vision_model_type, - target_width=target_width, - target_height=target_height, - embeddings_per_image=embeddings_per_image, - ) - - def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]: - return [self._transform(transform_media.media.value)] - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: Optional[int] = None, - max_num_tiles: int | None = None, - **kwargs, - ) -> list[ImageTilingParams]: - return [ - ImageTilingParams( - media=media, num_tiles=1, num_embeddings=self._embeddings_per_image - ) - for media in media_list - ] - - def __str__(self): - return f"SimpleImageTransform(vision_model_type={self._vision_model_type}, num_tokens_per_image={self._embeddings_per_image})" - - -@dataclass -class ImageTilingParamsV1(ImageTilingParams): - tiling: tuple[int, int] - - -class ImageTilingStrategyV1(_FixedSizeStrategy): - """Tiling image transformation. - - This transformation splits the image into a grid of tiles and applies the transformation to each tile. - """ - - # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 - # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 - - def __init__( - self, - vision_model_type: str, - tile_size: int, - use_thumbnail: bool, - min_num_tiles: int, - max_num_tiles: int, - embeddings_per_tile: int, - find_closest_aspect_ratio_fn=find_closest_aspect_ratio, - ): - super().__init__( - vision_model_type=vision_model_type, - target_width=tile_size, - target_height=tile_size, - embeddings_per_image=embeddings_per_tile, - ) - - # print(f"Transformation params: {vision_model_type=}, {use_tiling=}, {tile_size=}, {use_thumbnail=}, {augment=}, {min_num_tiles=}, {max_num_tiles=}, {find_closest_aspect_ratio_fn=}") - self._tile_size = tile_size - self._use_thumbnail = use_thumbnail - self._min_num_tiles = min_num_tiles - self._max_num_tiles = max_num_tiles - self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn - - # Calculate all possible aspect ratios for each max_num_tiles. - self.target_ratios = { - max_num_tiles: sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1) - } - - self.transform = A.Compose([ - A.OneOf([ - A.GaussNoise(var_limit=(5.0, 30.0)), - A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5)), - ], p=0.3), - A.OneOf([ - A.MedianBlur(blur_limit=5), - A.GaussianBlur(blur_limit=5), - ], p=0.2), - A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.5), - A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=15, val_shift_limit=15, p=0.3), - A.ImageCompression(quality_lower=70, quality_upper=100, p=0.3), - ]) - - def apply_params(self, transform_media: ImageTilingParams, data_augment: bool = False, **kwargs) -> list[torch.Tensor]: - assert isinstance(transform_media, ImageTilingParamsV1) - image = transform_media.media.value - - if data_augment: - image = self.transform(image=np.asarray(image))["image"] - image = Image.fromarray(image) - - # calculate the target width and height - target_width = self._tile_size * transform_media.tiling[0] - target_height = self._tile_size * transform_media.tiling[1] - blocks = transform_media.tiling[0] * transform_media.tiling[1] - - # resize the image - resized_img = image.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // self._tile_size)) * self._tile_size, - (i // (target_width // self._tile_size)) * self._tile_size, - ((i % (target_width // self._tile_size)) + 1) * self._tile_size, - ((i // (target_width // self._tile_size)) + 1) * self._tile_size, - ) - # split the image - split_img = resized_img.crop(box) - processed_images.append(split_img) - assert len(processed_images) == blocks - if self._use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((self._tile_size, self._tile_size)) - processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: Optional[int] = None, - max_num_tiles: int | None = None, - data_augment: bool = False, - tiling_augment_prob: float = 0.4, - **kwargs, - ) -> list[ImageTilingParamsV1]: - # Use provided max_num_tiles or fall back to instance's max_num_tiles - # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value - effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles - effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles) - - max_num_tiles_to_use = min( - num_tokens_available // self._embeddings_per_image, effective_max_num_tiles - ) - - # calculate the existing image aspect ratio - target_ratios = self.target_ratios[max_num_tiles_to_use] - - params = [] - for media in media_list: - if isinstance(media, ImageMedia): - img_size = (media.width, media.height) - elif isinstance(media, VideoFrameMedia): - img_size = (media.video_width, media.video_height) - else: - raise ValueError(f"Unsupported media type: {type(media)}") - - aspect_ratio = img_size[0] / img_size[1] - - # find the closest aspect ratio to the target - tiling = self._find_closest_aspect_ratio_fn( - aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size - ) - if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob: - tiling = self.augment_tiling(tiling) - num_tiles = tiling[0] * tiling[1] - if self._use_thumbnail and num_tiles != 1: - num_tiles += 1 - - params.append( - ImageTilingParamsV1( - media=media, - num_tiles=num_tiles, - num_embeddings=num_tiles * self._embeddings_per_image, - tiling=tiling, - ) - ) - - return params - - def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]: - def num_tiles(tiling: tuple[int, int]) -> int: - return tiling[0] * tiling[1] - - def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]: - if random.random() < minus_prob: - # Minus one - if tiling[0] == 1 and tiling[1] == 1: - return tiling - elif tiling[0] == 1: - return (tiling[0], tiling[1] - 1) - elif tiling[1] == 1: - return (tiling[0] - 1, tiling[1]) - else: - if random.random() < 0.5: - return (tiling[0] - 1, tiling[1]) - else: - return (tiling[0], tiling[1] - 1) - else: - # Plus one - if num_tiles(tiling) < self._max_num_tiles: - tiling0 = (tiling[0] + 1, tiling[1]) - tiling1 = (tiling[0], tiling[1] + 1) - if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles: - return tiling - elif num_tiles(tiling0) > self._max_num_tiles: - return tiling1 - elif num_tiles(tiling1) > self._max_num_tiles: - return tiling0 - else: - if random.random() < 0.5: - return tiling0 - else: - return tiling1 - return tiling - - new_tiling = plus_minus_one(tiling) - return new_tiling - - def __str__(self): - return f"TilingImageTransform(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, embeddings_per_tile={self._embeddings_per_image}, find_closest_aspect_ratio_fn={self._find_closest_aspect_ratio_fn})" - - -class TileDegradationStrategy(ImageTilingStrategy): - """Strategy for tiling images and video frames, each with their own tiling strategy, while trying to match the - number of tokens left in the sample by reducing the number of tiles if needed. - """ - - # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 - # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 - - def __init__( - self, - image_strategy: ImageTilingStrategy, - video_frame_strategy: ImageTilingStrategy, - embeddings_per_tile: int, - max_num_tiles: int, - tile_degradation_map: dict[int, int] = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}, - ): - self._image_strategy = image_strategy - self._video_frame_strategy = video_frame_strategy - self._embeddings_per_tile = embeddings_per_tile - self._max_num_tiles = max_num_tiles - self._tile_degradation_map = tile_degradation_map - - def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]: - if isinstance(transform_media.media, ImageMedia): - return self._image_strategy.apply_params(transform_media, **kwargs) - elif isinstance(transform_media.media, VideoFrameMedia): - return self._video_frame_strategy.apply_params(transform_media, **kwargs) - else: - raise ValueError(f"Unsupported media type: {type(transform_media.media)}") - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - **kwargs, - ) -> list[ImageTilingParams]: - # Use provided max_num_tiles or fall back to instance's max_num_tiles - effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles - max_num_tiles_to_use = effective_max_num_tiles - degradation_map = self._tile_degradation_map - - while True: - params = [] - img_num_tiles = [] - for media in media_list: - if isinstance(media, ImageMedia): - media_params = self._image_strategy.compute_params( - [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs - )[0] - elif isinstance(media, VideoFrameMedia): - max_num_tiles_to_use = 1 - media_params = self._video_frame_strategy.compute_params( - [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs - )[0] - else: - raise ValueError(f"Unsupported media type: {type(media)}") - img_num_tiles.append(media_params.num_tiles) - params.append(media_params) - if max_num_tiles_to_use == 1 or num_tokens_available is None: - break - if sum(img_num_tiles) * self._embeddings_per_tile > num_tokens_available: - if max_num_tiles_to_use in degradation_map: - max_num_tiles_to_use = degradation_map[max_num_tiles_to_use] - else: - # End of degradation - break - else: - break - return params - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: - return self._image_strategy.stack(images) - - def __str__(self): - return f"TileDegradationImageTransform(max_num_tiles={self._max_num_tiles}, image_transform={self._image_strategy}, video_frame_transform={self._video_frame_strategy})" - - -@dataclass -class DynamicResolutionParams(ImageTilingParams): - patch_size: tuple[int, int] - - -class DynamicResolutionImageTilingStrategy(ImageTilingStrategy): - """Preprocess an image with dynamic resolution for vision transformers. - - This function resizes an image to optimize the number of patches while respecting - constraints on minimum/maximum patches, minimum side length, and compatibility - with pixel shuffle or convolution merging operations. - - The algorithm works by: - 1. Computing the initial patch grid size based on the image dimensions and res_step - 2. Scaling the patch grid to fit within the max_patches constraint - 3. Ensuring the result has at least min_patches - 4. Optionally enforcing a minimum side length constraint - 5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility - 6. Resizing the image to the computed target dimensions - - Note: - The function preserves aspect ratio as much as possible while satisfying all constraints. - When constraints conflict (e.g., min_side vs max_patches), the function prioritizes - staying within max_patches while maximizing the image size. - - Example: - >>> from PIL import Image - >>> img = Image.open("example.jpg") # 800x600 image - >>> strategy = DynamicResolutionImageTilingStrategy(vision_model_type="radio", min_patches=4, max_patches=64, res_step=14, get_num_embeddings=lambda x, y: x * y * 2) - >>> params = strategy.compute_params([img]) - >>> img_tensor = strategy.apply_params(params[0]) - >>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14 - """ - - def __init__( - self, - vision_model_type: str, - min_num_patches: int, - patch_size: int, - get_num_embeddings: Callable[[int, int], int], - factor_max: float = 1.0, - pixel_shuffle: bool = False, - min_side: int | None = None, - conv_merging: bool = False, - use_thumbnail: bool = False, - thumbnail_size: int = 448, - thumbnail_area_threshold: float = 0.8, - max_num_patches: int = 0, - apply_data_augment: bool = False, - ): - """ - Args: - vision_model_type: Vision model type. - min_num_patches: Minimum number of patches required. Defaults to 1. - max_num_patches: Maximum number of patches allowed. Defaults to 0 (no maximum). - patch_size: Resolution step size (patch dimension). Defaults to 16. - get_num_embeddings: Function to get the number of embeddings from the patch size (width, height). - factor_max: Maximum scaling factor to apply. Defaults to 1.0. - pixel_shuffle: Whether to ensure compatibility with pixel shuffle operations by rounding to even patch - dimensions. Defaults to False. - min_side: Minimum side length in pixels. If specified, ensures at least one side meets this constraint. - Defaults to None. - conv_merging: Whether to ensure compatibility with convolution merging by rounding to even patch dimensions. - Defaults to False. - use_thumbnail: Whether to add a thumbnail image when processing. Defaults to False. - thumbnail_size: Size of the thumbnail image (width and height). Defaults to 448. - thumbnail_area_threshold: Maximum area percentage (0.0-1.0) of the resized image relative to thumbnail area - for which to add a thumbnail. If the resized image area is larger than this threshold of the thumbnail - area, no thumbnail will be added. Defaults to 0.8 (80%). - apply_data_augment: Whether to apply data augmentation to the image. Defaults to False. - """ - assert "radio" in vision_model_type, ( - "Dynamic resolution is only supported for radio models" - ) - self._vision_model_type = vision_model_type - self._min_num_patches = min_num_patches - self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf") - self._patch_size = patch_size - self._get_num_embeddings = get_num_embeddings - self._factor_max = factor_max - self._pixel_shuffle = pixel_shuffle - self._min_side = min_side - self._conv_merging = conv_merging - self._use_thumbnail = use_thumbnail - self._thumbnail_size = thumbnail_size - self._thumbnail_area_threshold = thumbnail_area_threshold - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - self._apply_data_augment = apply_data_augment - - def apply_params(self, params: DynamicResolutionParams, **kwargs) -> list[torch.Tensor]: - # resize the image - resized_img = params.media.value.resize( - ( - params.patch_size[0] * self._patch_size, - params.patch_size[1] * self._patch_size, - ) - ) - processed_images = [resized_img] - - # Add thumbnail if enabled and image area is below threshold - if self._use_thumbnail: - # Calculate areas - resized_area = resized_img.size[0] * resized_img.size[1] - thumbnail_area = self._thumbnail_size * self._thumbnail_size - area_ratio = resized_area / thumbnail_area - - # Only add thumbnail if resized image area is less than threshold % of thumbnail area - if area_ratio < self._thumbnail_area_threshold: - thumbnail_img = params.media.value.resize((self._thumbnail_size, self._thumbnail_size)) - processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def process_media( - self, - media: ImageMedia | VideoFrameMedia, - num_tokens_available: int, - data_augment: bool = False, - tiling_augment_prob: float = 0.4, - ) -> DynamicResolutionParams: - """Process a single media item and return its parameters. - - Args: - media: The media item to process - num_tokens_available: Number of tokens available for this media - data_augment: Whether to apply data augmentation to the image. Defaults to False. - Returns: - DynamicResolutionParams for the media - """ - current_num_tokens_available = num_tokens_available - if isinstance(media, ImageMedia): - orig_width, orig_height = media.width, media.height - elif isinstance(media, VideoFrameMedia): - orig_width, orig_height = media.video_width, media.video_height - # current_num_tokens_available = 1024 #TEMP: hack for video - else: - raise ValueError(f"Unsupported media type: {type(media)}") - - closest_patch_height = round(orig_height / self._patch_size + 0.5) - closest_patch_width = round(orig_width / self._patch_size + 0.5) - patches = closest_patch_height * closest_patch_width - - factor = min(math.sqrt(current_num_tokens_available / patches), self._factor_max) - target_patch_height = math.floor(factor * closest_patch_height) - target_patch_width = math.floor(factor * closest_patch_width) - - # We only consider self._min_num_patches if it is greater than current_num_tokens_available. - if current_num_tokens_available > self._min_num_patches and target_patch_height * target_patch_width < self._min_num_patches: - up_factor = math.sqrt( - self._min_num_patches / (target_patch_height * target_patch_width) - ) - target_patch_height = math.ceil(up_factor * target_patch_height) - target_patch_width = math.ceil(up_factor * target_patch_width) - - if ( - self._min_side is not None - and min(target_patch_width, target_patch_height) * self._patch_size - < self._min_side - ): - if target_patch_width <= target_patch_height: - up_factor = self._min_side / (target_patch_width * self._patch_size) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) - - if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches - if ( - max(current_num_tokens_available // new_patch_width, 1) - * self._patch_size - < self._min_side - ): - up_factor = math.sqrt( - current_num_tokens_available - / (target_patch_height * target_patch_width) - ) - target_patch_height = math.floor( - up_factor * target_patch_height - ) - target_patch_width = math.floor( - up_factor * target_patch_width - ) - target_patch_width = new_patch_width - target_patch_height = max( - current_num_tokens_available // new_patch_width, 1 - ) - else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width - else: - up_factor = self._min_side / ( - target_patch_height * self._patch_size - ) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) - - if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches - if ( - max(current_num_tokens_available // new_patch_height, 1) - * self._patch_size - < self._min_side - ): - up_factor = math.sqrt( - current_num_tokens_available - / (target_patch_height * target_patch_width) - ) - target_patch_height = math.floor( - up_factor * target_patch_height - ) - target_patch_width = math.floor( - up_factor * target_patch_width - ) - else: - target_patch_height = new_patch_height - target_patch_width = max( - current_num_tokens_available // new_patch_height, 1 - ) - else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width - - # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) - # or by 4 when BOTH are enabled (two successive 2x reductions) - if self._pixel_shuffle or self._conv_merging: - required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2 - - rem_h = target_patch_height % required_divisor - if rem_h != 0: - inc_h = required_divisor - rem_h - if (target_patch_height + inc_h) * target_patch_width <= current_num_tokens_available: - target_patch_height += inc_h - else: - target_patch_height = max(required_divisor, target_patch_height - rem_h) - - rem_w = target_patch_width % required_divisor - if rem_w != 0: - inc_w = required_divisor - rem_w - if target_patch_height * (target_patch_width + inc_w) <= current_num_tokens_available: - target_patch_width += inc_w - else: - target_patch_width = max(required_divisor, target_patch_width - rem_w) - - if data_augment and self._apply_data_augment and random.random() < tiling_augment_prob: - target_patch_width, target_patch_height = self.augment_resolution(target_patch_width, target_patch_height, current_num_tokens_available) - - #TEMP: hack for video - if isinstance(media, VideoFrameMedia): - target_patch_width = 32 - target_patch_height = 32 - - # Calculate embeddings for the main dynamic resolution image - num_embeddings = self._get_num_embeddings( - target_patch_width * self._patch_size, - target_patch_height * self._patch_size, - ) - - token_count = target_patch_width * target_patch_height - - # Add thumbnail embeddings if enabled and image area is below threshold - num_tiles = 1 # Base dynamic resolution image - if self._use_thumbnail: - # Calculate areas - resized_area = (target_patch_width * self._patch_size) * (target_patch_height * self._patch_size) - thumbnail_area = self._thumbnail_size * self._thumbnail_size - area_ratio = resized_area / thumbnail_area - - # Only add thumbnail if resized image area is less than threshold % of thumbnail area - if area_ratio < self._thumbnail_area_threshold: - num_tiles += 1 # Add 1 for thumbnail - # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) - num_embeddings += self._get_num_embeddings(self._thumbnail_size, self._thumbnail_size) - token_count += self._thumbnail_size // self._patch_size * self._thumbnail_size // self._patch_size - - return DynamicResolutionParams( - media=media, - num_tiles=num_tiles, - num_embeddings=num_embeddings, - patch_size=(target_patch_width, target_patch_height), - ), token_count - - def augment_resolution(self, target_patch_width: int, target_patch_height: int, current_num_tokens_available: int) -> tuple[int, int]: - - min_num_patch_one_side = 32 - - if random.random() < 0.5: - # Minus one - if target_patch_width <= min_num_patch_one_side and target_patch_height <= min_num_patch_one_side: - return target_patch_width, target_patch_height - elif target_patch_width <= min_num_patch_one_side: - return target_patch_width, target_patch_height - min_num_patch_one_side - elif target_patch_height <= min_num_patch_one_side: - return target_patch_width - min_num_patch_one_side, target_patch_height - else: - if random.random() < 0.5: - return target_patch_width - min_num_patch_one_side, target_patch_height - else: - return target_patch_width, target_patch_height - min_num_patch_one_side - else: - # Plus one - if target_patch_width * target_patch_height < current_num_tokens_available: - if random.random() < 0.5: - return target_patch_width + min_num_patch_one_side, target_patch_height - else: - return target_patch_width, target_patch_height + min_num_patch_one_side - return target_patch_width, target_patch_height - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - data_augment: bool = False, - **kwargs, - ) -> list[ImageTilingParams]: - """Compute parameters for all media with iterative token budgeting. - - Args: - media_list: List of media items to process - num_tokens_available: Total number of tokens available across all media - max_num_tiles: Maximum number of tiles (unused in this implementation) - data_augment: Whether to apply data augmentation to the image. Defaults to False. - Returns: - List of ImageTilingParams for each media item - """ - num_tokens_available = num_tokens_available * (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1) - # When the number of available token is too small, allow self._min_num_patches per media and - # let the sample be truncated. - num_tokens_available = max(num_tokens_available, self._min_num_patches * len(media_list)) - - # Clip the number of tokens available per media to be between min and max patches. - num_tokens_available_per_media = [ - max(min(num_tokens_available, self._max_num_patches), self._min_num_patches) - for _ in range(len(media_list))] - - # In theory this could be a while True loop, but in case the process_media method slightly - # changes, I want to make sure we don't get stuck in an infinite loop. - for _ in range(10): - # Step 1: Process each media with current token budget - params = [] - token_counts = [] - - for media, tokens_for_media in zip(media_list, num_tokens_available_per_media): - param, token_count = self.process_media(media, tokens_for_media, data_augment=data_augment) - params.append(param) - token_counts.append(token_count) - - # Step 2: Check if total tokens is within budget - total_tokens = sum(token_counts) - - if total_tokens <= num_tokens_available: - # We're within budget, return the params - return params - - # Step 3: We're over budget, need to scale down - # Calculate scaling factor to get under budget - scaling_factor = num_tokens_available / total_tokens - - # Recalculate token budgets for each media based on scaling - # Each media gets a proportional share of the total budget - scaled_down_num_tokens_available_per_media = [ - max(self._min_num_patches, int(token_count * scaling_factor)) - for token_count in token_counts - ] - scaled_down = any([ - scaled_down_num_tokens_available_per_media[i] < num_tokens_available_per_media[i] - for i in range(len(num_tokens_available_per_media))]) - # If there was not scaling down, we're stuck just use min_num_patches per media, else - # try with the scaled down num_tokens_available_per_media. - if not scaled_down: - num_tokens_available_per_media = [self._min_num_patches] * len(media_list) - else: - num_tokens_available_per_media = scaled_down_num_tokens_available_per_media - return params - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: - imgs_sizes = torch.tensor( - [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 - ) - - def rearrange_img(x): - py = x.shape[-2] // self._patch_size - px = x.shape[-1] // self._patch_size - x = einops.rearrange( - x, - "c (py yy) (px xx) -> (py px) (c yy xx)", - py=py, - yy=self._patch_size, - px=px, - xx=self._patch_size, - ) - return x - - if len(images) > 0: - imgs = [rearrange_img(img) for img in images] - - current_length = 0 - max_length = 0 - vision_cu_lengths = [0] - for img in imgs: - if max_length < img.shape[0]: - max_length = img.shape[0] - current_length += img.shape[0] - vision_cu_lengths.append(current_length) - - vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) - vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) - - return ( - torch.cat(imgs, dim=0).unsqueeze(0), - imgs_sizes, - vision_cu_lengths, - vision_max_lengths, - ) - else: - return ( - torch.tensor([[0]], dtype=torch.float32), - torch.tensor([[0,0]], dtype=torch.int32), - None, - None, - ) - - def __str__(self): - return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})" - - -@dataclass -class MatchTilingDynamicResolutionParams(ImageTilingParams): - tiling: tuple[int, int] - - -class MatchTilingDynamicResolutionStrategy(ImageTilingStrategy): - """ - Strategy that uses tiling logic to determine optimal image dimensions but processes - the image as a single dynamic resolution image instead of splitting into tiles. - - This combines the aspect ratio optimization from ImageTilingStrategyV1 with the - dynamic resolution processing from DynamicResolutionImageTilingStrategy. - - Also includes tile degradation logic similar to TileDegradationStrategy. - """ - - def __init__( - self, - vision_model_type: str, - tile_size: int, - use_thumbnail: bool, - min_num_tiles: int, - max_num_tiles: int, - embeddings_per_tile: int, - patch_size: int, - get_num_embeddings: Callable[[int, int], int], - find_closest_aspect_ratio_fn=find_closest_aspect_ratio, - pixel_shuffle: bool = False, - conv_merging: bool = False, - tile_degradation_map: dict[int, int] = None, - video_frame_strategy: ImageTilingStrategy = None, - enable_tile_degradation: bool = True, - ): - """ - Args: - vision_model_type: Vision model type (should support dynamic resolution) - tile_size: Size of each tile for tiling calculation - use_thumbnail: Whether tiling logic should include thumbnail - min_num_tiles: Minimum number of tiles for tiling calculation - max_num_tiles: Maximum number of tiles for tiling calculation - embeddings_per_tile: Embeddings per tile for tiling calculation - patch_size: Patch size for dynamic resolution processing - get_num_embeddings: Function to get number of embeddings from dimensions - find_closest_aspect_ratio_fn: Function to find closest aspect ratio - pixel_shuffle: Whether to ensure compatibility with pixel shuffle - conv_merging: Whether to ensure compatibility with convolution merging - tile_degradation_map: Map for degrading tiles when tokens are insufficient - video_frame_strategy: Strategy for processing video frames - enable_tile_degradation: Whether to enable tile degradation (default: True) - """ - assert "radio" in vision_model_type, ( - "MatchTilingDynamicResolution is only supported for radio models" - ) - - self._vision_model_type = vision_model_type - self._tile_size = tile_size - self._use_thumbnail = use_thumbnail - self._min_num_tiles = min_num_tiles - self._max_num_tiles = max_num_tiles - self._embeddings_per_tile = embeddings_per_tile - self._patch_size = patch_size - self._get_num_embeddings = get_num_embeddings - self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn - self._pixel_shuffle = pixel_shuffle - self._conv_merging = conv_merging - self._enable_tile_degradation = enable_tile_degradation - - # Tile degradation logic (similar to TileDegradationStrategy) - if tile_degradation_map is None: - self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1} - else: - self._tile_degradation_map = tile_degradation_map - - # Video frame strategy (similar to TileDegradationStrategy) - if video_frame_strategy is None: - self._video_frame_strategy = NoTilingStrategy( - vision_model_type=vision_model_type, - target_width=tile_size, - target_height=tile_size, - embeddings_per_image=embeddings_per_tile, - ) - else: - self._video_frame_strategy = video_frame_strategy - - # Calculate all possible aspect ratios for each max_num_tiles (borrowed from ImageTilingStrategyV1) - self.target_ratios = { - max_num_tiles: sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1) - } - - # Set up transform for dynamic resolution processing - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - - def apply_params(self, params: MatchTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]: - # Handle video frames using the video frame strategy - if isinstance(params.media, VideoFrameMedia): - return self._video_frame_strategy.apply_params(params, **kwargs) - - # Handle images with dynamic resolution processing - image = params.media.value - # Calculate the target width and height (same logic as ImageTilingStrategyV1) - target_width = self._tile_size * params.tiling[0] - target_height = self._tile_size * params.tiling[1] - - # Resize the image to the target dimensions (same as ImageTilingStrategyV1) - resized_img = image.resize((target_width, target_height)) - - # Process as single dynamic resolution image - processed_images = [resized_img] - - # Add thumbnail if use_thumbnail=True and there's more than 1 tile (same as ImageTilingStrategyV1) - blocks = params.tiling[0] * params.tiling[1] - if self._use_thumbnail and blocks != 1: - thumbnail_img = image.resize((self._tile_size, self._tile_size)) - processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - **kwargs, - ) -> list[MatchTilingDynamicResolutionParams]: - # Implement tile degradation logic similar to TileDegradationStrategy - # Use provided max_num_tiles or fall back to instance's max_num_tiles - # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value - effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles - effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles) - max_num_tiles_to_use = effective_max_num_tiles - degradation_map = self._tile_degradation_map - - while True: - params = [] - total_embeddings_needed = 0 - - for media in media_list: - if isinstance(media, ImageMedia): - # Use tiling logic for images - img_size = (media.width, media.height) - aspect_ratio = img_size[0] / img_size[1] - - # Find the closest aspect ratio to the target - target_ratios = self.target_ratios[max_num_tiles_to_use] - tiling = self._find_closest_aspect_ratio_fn( - aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size - ) - - # Calculate target dimensions for dynamic resolution processing - target_width = self._tile_size * tiling[0] - target_height = self._tile_size * tiling[1] - num_embeddings = self._get_num_embeddings(target_width, target_height) - - # Account for thumbnail (same logic as ImageTilingStrategyV1) - num_tiles = 1 # Base dynamic resolution image - blocks = tiling[0] * tiling[1] - if self._use_thumbnail and blocks != 1: - num_tiles += 1 # Add 1 for thumbnail - # Add embeddings for thumbnail (tile_size x tile_size) - num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size) - - media_params = MatchTilingDynamicResolutionParams( - media=media, - num_tiles=num_tiles, - num_embeddings=num_embeddings, - tiling=tiling, - ) - elif isinstance(media, VideoFrameMedia): - # Use video frame strategy for video frames (always 1 tile) - video_params = self._video_frame_strategy.compute_params( - [media], 1 * self._embeddings_per_tile - )[0] - media_params = MatchTilingDynamicResolutionParams( - media=media, - num_tiles=video_params.num_tiles, - num_embeddings=video_params.num_embeddings, - tiling=(1, 1), # Video frames always use 1x1 tiling - ) - else: - raise ValueError(f"Unsupported media type: {type(media)}") - - params.append(media_params) - total_embeddings_needed += media_params.num_embeddings - - # Check if we need to degrade (only if degradation is enabled) - if not self._enable_tile_degradation: - break - if max_num_tiles_to_use == 1 or num_tokens_available is None: - break - if total_embeddings_needed > num_tokens_available: - if max_num_tiles_to_use in degradation_map: - max_num_tiles_to_use = degradation_map[max_num_tiles_to_use] - # Recalculate target ratios for the new max_num_tiles_to_use - if max_num_tiles_to_use not in self.target_ratios: - self.target_ratios[max_num_tiles_to_use] = sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles_to_use + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles_to_use and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - else: - # End of degradation - break - else: - break - - return params - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]: - """Stack images using dynamic resolution approach with sequence packing""" - imgs_sizes = torch.tensor( - [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 - ) - - def rearrange_img(x): - py = x.shape[-2] // self._patch_size - px = x.shape[-1] // self._patch_size - x = einops.rearrange( - x, - "c (py yy) (px xx) -> (py px) (c yy xx)", - py=py, - yy=self._patch_size, - px=px, - xx=self._patch_size, - ) - return x - - if len(images) > 0: - imgs = [rearrange_img(img) for img in images] - - current_length = 0 - max_length = 0 - vision_cu_lengths = [0] - for img in imgs: - if max_length < img.shape[0]: - max_length = img.shape[0] - current_length += img.shape[0] - vision_cu_lengths.append(current_length) - - vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) - vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) - - return ( - torch.cat(imgs, dim=0).unsqueeze(0), - imgs_sizes, - vision_cu_lengths, - vision_max_lengths, - ) - else: - return ( - torch.tensor([[0]], dtype=torch.float32), - torch.tensor([[0,0]], dtype=torch.int32), - None, - None, - ) - - def __str__(self): - return f"MatchTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})" - - -@dataclass -class MaskedTilingDynamicResolutionParams(ImageTilingParams): - tiling: tuple[int, int] - - -class MaskedTilingDynamicResolutionStrategy(ImageTilingStrategy): - """ - Like MatchTilingDynamicResolutionStrategy, but ensures tiles are isolated in the - vision encoder by emitting per-tile packed samples (block-diagonal attention across tiles). - """ - - def __init__( - self, - vision_model_type: str, - tile_size: int, - use_thumbnail: bool, - min_num_tiles: int, - max_num_tiles: int, - embeddings_per_tile: int, - patch_size: int, - get_num_embeddings: Callable[[int, int], int], - find_closest_aspect_ratio_fn=find_closest_aspect_ratio, - pixel_shuffle: bool = False, - conv_merging: bool = False, - tile_degradation_map: dict[int, int] = None, - video_frame_strategy: ImageTilingStrategy = None, - enable_tile_degradation: bool = True, - ): - assert "radio" in vision_model_type, ( - "MaskedTilingDynamicResolution is only supported for radio models" - ) - - self._vision_model_type = vision_model_type - self._tile_size = tile_size - self._use_thumbnail = use_thumbnail - self._min_num_tiles = min_num_tiles - self._max_num_tiles = max_num_tiles - self._embeddings_per_tile = embeddings_per_tile - self._patch_size = patch_size - self._get_num_embeddings = get_num_embeddings - self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn - self._pixel_shuffle = pixel_shuffle - self._conv_merging = conv_merging - self._enable_tile_degradation = enable_tile_degradation - - if tile_degradation_map is None: - self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1} - else: - self._tile_degradation_map = tile_degradation_map - - if video_frame_strategy is None: - self._video_frame_strategy = NoTilingStrategy( - vision_model_type=vision_model_type, - target_width=tile_size, - target_height=tile_size, - embeddings_per_image=embeddings_per_tile, - ) - else: - self._video_frame_strategy = video_frame_strategy - - self.target_ratios = { - max_num_tiles: sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1) - } - - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - - def apply_params(self, params: MaskedTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]: - # Handle video frames using the video frame strategy - if isinstance(params.media, VideoFrameMedia): - return self._video_frame_strategy.apply_params(params, **kwargs) - - image = params.media.value - nx, ny = params.tiling - target_width = self._tile_size * nx - target_height = self._tile_size * ny - - resized_img = image.resize((target_width, target_height)) - - processed_images = [] - # Emit per-tile images (each becomes an isolated packed sample later) - for j in range(ny): - for i in range(nx): - box = ( - i * self._tile_size, - j * self._tile_size, - (i + 1) * self._tile_size, - (j + 1) * self._tile_size, - ) - tile_img = resized_img.crop(box) - processed_images.append(tile_img) - - if self._use_thumbnail and (nx * ny) != 1: - thumbnail_img = image.resize((self._tile_size, self._tile_size)) - processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - data_augment: bool = False, - tiling_augment_prob: float = 0.4, - **kwargs, - ) -> list[MaskedTilingDynamicResolutionParams]: - effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles - effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles) - max_num_tiles_to_use = effective_max_num_tiles - degradation_map = self._tile_degradation_map - - while True: - params = [] - total_embeddings_needed = 0 - - for media in media_list: - if isinstance(media, ImageMedia): - img_size = (media.width, media.height) - aspect_ratio = img_size[0] / img_size[1] - - target_ratios = self.target_ratios[max_num_tiles_to_use] - tiling = self._find_closest_aspect_ratio_fn( - aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size - ) - - # Apply tiling augmentation if enabled - if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob: - tiling = self.augment_tiling(tiling) - - blocks = tiling[0] * tiling[1] - # Each tile is tile_size x tile_size - per_tile_emb = self._get_num_embeddings(self._tile_size, self._tile_size) - num_embeddings = blocks * per_tile_emb - - num_tiles = blocks - if self._use_thumbnail and blocks != 1: - num_tiles += 1 - num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size) - - media_params = MaskedTilingDynamicResolutionParams( - media=media, - num_tiles=num_tiles, - num_embeddings=num_embeddings, - tiling=tiling, - ) - elif isinstance(media, VideoFrameMedia): - video_params = self._video_frame_strategy.compute_params( - [media], 1 * self._embeddings_per_tile - )[0] - media_params = MaskedTilingDynamicResolutionParams( - media=media, - num_tiles=video_params.num_tiles, - num_embeddings=video_params.num_embeddings, - tiling=(1, 1), - ) - else: - raise ValueError(f"Unsupported media type: {type(media)}") - - params.append(media_params) - total_embeddings_needed += media_params.num_embeddings - - if not self._enable_tile_degradation: - break - if max_num_tiles_to_use == 1 or num_tokens_available is None: - break - if total_embeddings_needed > num_tokens_available: - if max_num_tiles_to_use in degradation_map: - max_num_tiles_to_use = degradation_map[max_num_tiles_to_use] - if max_num_tiles_to_use not in self.target_ratios: - self.target_ratios[max_num_tiles_to_use] = sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles_to_use + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles_to_use and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - else: - break - else: - break - - return params - - def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]: - def num_tiles(tiling: tuple[int, int]) -> int: - return tiling[0] * tiling[1] - - def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]: - if random.random() < minus_prob: - # Minus one - if tiling[0] == 1 and tiling[1] == 1: - return tiling - elif tiling[0] == 1: - return (tiling[0], tiling[1] - 1) - elif tiling[1] == 1: - return (tiling[0] - 1, tiling[1]) - else: - if random.random() < 0.5: - return (tiling[0] - 1, tiling[1]) - else: - return (tiling[0], tiling[1] - 1) - else: - # Plus one - if num_tiles(tiling) < self._max_num_tiles: - tiling0 = (tiling[0] + 1, tiling[1]) - tiling1 = (tiling[0], tiling[1] + 1) - if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles: - return tiling - elif num_tiles(tiling0) > self._max_num_tiles: - return tiling1 - elif num_tiles(tiling1) > self._max_num_tiles: - return tiling0 - else: - if random.random() < 0.5: - return tiling0 - else: - return tiling1 - return tiling - - new_tiling = plus_minus_one(tiling) - return new_tiling - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]: - # Identical to dynamic resolution packing; each tile is already an independent image sample - imgs_sizes = torch.tensor( - [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 - ) - - def rearrange_img(x): - py = x.shape[-2] // self._patch_size - px = x.shape[-1] // self._patch_size - x = einops.rearrange( - x, - "c (py yy) (px xx) -> (py px) (c yy xx)", - py=py, - yy=self._patch_size, - px=px, - xx=self._patch_size, - ) - return x - - if len(images) > 0: - imgs = [rearrange_img(img) for img in images] - - current_length = 0 - max_length = 0 - vision_cu_lengths = [0] - for img in imgs: - if max_length < img.shape[0]: - max_length = img.shape[0] - current_length += img.shape[0] - vision_cu_lengths.append(current_length) - - vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) - vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) - - return ( - torch.cat(imgs, dim=0).unsqueeze(0), - imgs_sizes, - vision_cu_lengths, - vision_max_lengths, - ) - else: - return ( - torch.tensor([[0]], dtype=torch.float32), - torch.tensor([[0,0]], dtype=torch.int32), - None, - None, - ) - - def __str__(self): - return f"MaskedTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})" - -def create_image_tiling_strategy(args): - """ - Create an image tiling strategy based on the provided arguments. - - This function encapsulates the logic for creating the appropriate image tiling strategy - based on the training/evaluation configuration. It can be used by both training (task_encoder) - and evaluation code outside of data_loading/. - - Args: - args: Arguments object with the following relevant attributes: - - img_h, img_w: Image height and width - - patch_dim: Patch dimension - - vision_model_type: Vision model type (e.g., 'radio', 'clip', 'siglip') - - disable_vision_class_token: Whether to disable vision class token - - pixel_shuffle: Whether to use pixel shuffle - - use_tile_tags: Whether to use tile tags - - max_num_tiles: Maximum number of tiles - - tokenizer_prompt_format: Tokenizer prompt format - - image_break_token: Image break token (optional) - - conv_merging: Whether to use convolution merging - - dynamic_resolution: Whether to use dynamic resolution - - match_tiling_dynamic_resolution: Whether to match tiling with dynamic resolution - - use_area_weighted_aspect_ratio: Whether to use area-weighted aspect ratio - - use_thumbnail: Whether to use thumbnail - - dynamic_resolution_min_patches: Minimum number of patches for dynamic resolution - - dynamic_resolution_min_side: Minimum side length for dynamic resolution (optional) - - thumbnail_area_threshold: Thumbnail area threshold (optional) - - use_tiling: Whether to use tiling - - Returns: - ImageTilingStrategy: The created image tiling strategy - """ - from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings - - assert args.img_h == args.img_w, "img_h and img_w must be the same" - - match_tiling_dynamic_resolution = args.match_tiling_dynamic_resolution - masked_tiling_dynamic_resolution = getattr(args, "masked_tiling_dynamic_resolution", False) - dynamic_resolution = args.dynamic_resolution - use_tiling = args.use_tiling - use_area_weighted_aspect_ratio = args.use_area_weighted_aspect_ratio - - if match_tiling_dynamic_resolution: - assert dynamic_resolution, "must enable --dynamic-resolution if using --match-tiling-dynamic-resolution" - assert not use_tiling, "cannot use --use-tiling and --match-tiling-dynamic-resolution together" - if masked_tiling_dynamic_resolution: - assert dynamic_resolution, "must enable --dynamic-resolution if using --masked-tiling-dynamic-resolution" - assert not use_tiling, "cannot use --use-tiling and --masked-tiling-dynamic-resolution together" - assert not match_tiling_dynamic_resolution, "cannot combine --masked-tiling-dynamic-resolution with --match-tiling-dynamic-resolution" - - if dynamic_resolution: - if masked_tiling_dynamic_resolution: - num_image_embeddings_per_tile = get_num_image_embeddings( - img_h=args.img_h, - img_w=args.img_w, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ) - image_tiling_strategy = MaskedTilingDynamicResolutionStrategy( - vision_model_type=args.vision_model_type, - tile_size=args.img_h, - use_thumbnail=args.use_thumbnail, - min_num_tiles=1, - max_num_tiles=args.max_num_tiles, - embeddings_per_tile=num_image_embeddings_per_tile, - patch_size=args.patch_dim, - get_num_embeddings=lambda width, height: get_num_image_embeddings( - img_h=height, - img_w=width, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ), - find_closest_aspect_ratio_fn=( - find_closest_area_weighted_aspect_ratio - if use_area_weighted_aspect_ratio - else find_closest_aspect_ratio - ), - pixel_shuffle=args.pixel_shuffle, - conv_merging=args.conv_merging, - ) - elif match_tiling_dynamic_resolution: - num_image_embeddings_per_tile = get_num_image_embeddings( - img_h=args.img_h, - img_w=args.img_w, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ) - image_tiling_strategy = MatchTilingDynamicResolutionStrategy( - vision_model_type=args.vision_model_type, - tile_size=args.img_h, - use_thumbnail=args.use_thumbnail, - min_num_tiles=1, - max_num_tiles=args.max_num_tiles, - embeddings_per_tile=num_image_embeddings_per_tile, - patch_size=args.patch_dim, - get_num_embeddings=lambda width, height: get_num_image_embeddings( - img_h=height, - img_w=width, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ), - find_closest_aspect_ratio_fn=( - find_closest_area_weighted_aspect_ratio - if use_area_weighted_aspect_ratio - else find_closest_aspect_ratio - ), - pixel_shuffle=args.pixel_shuffle, - conv_merging=args.conv_merging, - ) - else: - image_tiling_strategy = DynamicResolutionImageTilingStrategy( - vision_model_type=args.vision_model_type, - min_num_patches=args.dynamic_resolution_min_patches, - patch_size=args.patch_dim, - get_num_embeddings=lambda width, height: get_num_image_embeddings( - img_h=height, - img_w=width, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ), - pixel_shuffle=args.pixel_shuffle, - min_side=args.dynamic_resolution_min_side, - conv_merging=args.conv_merging, - use_thumbnail=args.use_thumbnail, - thumbnail_size=args.img_h, - thumbnail_area_threshold=args.thumbnail_area_threshold, - max_num_patches=args.dynamic_resolution_max_patches, - apply_data_augment=args.apply_data_augment, - ) - else: - num_image_embeddings_per_tile = get_num_image_embeddings( - img_h=args.img_h, - img_w=args.img_w, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ) - if use_tiling: - image_strategy = ImageTilingStrategyV1( - vision_model_type=args.vision_model_type, - tile_size=args.img_h, - use_thumbnail=args.use_thumbnail, - min_num_tiles=1, - max_num_tiles=args.max_num_tiles, - embeddings_per_tile=num_image_embeddings_per_tile, - find_closest_aspect_ratio_fn=( - find_closest_area_weighted_aspect_ratio - if use_area_weighted_aspect_ratio - else find_closest_aspect_ratio - ), - ) - else: - image_strategy = NoTilingStrategy( - vision_model_type=args.vision_model_type, - embeddings_per_image=num_image_embeddings_per_tile, - target_width=args.img_w, - target_height=args.img_h, - ) - image_tiling_strategy = TileDegradationStrategy( - image_strategy=image_strategy, - video_frame_strategy=NoTilingStrategy( - vision_model_type=args.vision_model_type, - embeddings_per_image=num_image_embeddings_per_tile, - target_width=args.img_w, - target_height=args.img_h, - ), - embeddings_per_tile=num_image_embeddings_per_tile, - max_num_tiles=args.max_num_tiles, - ) - - return image_tiling_strategy diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 82423891c11c7..3b8a3841cf938 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -113,7 +113,7 @@ pixel_statistics = { @dataclass class DynamicResolutionParams: - image: Image.Image + media: Image.Image num_tiles: int num_embeddings: int patch_size: tuple[int, int] @@ -165,7 +165,7 @@ class DynamicResolutionImageTilingStrategy: self, params: DynamicResolutionParams, **kwargs ) -> list[torch.Tensor]: # resize the image - resized_img = params.image.resize( + resized_img = params.media.resize( ( params.patch_size[0] * self._patch_size, params.patch_size[1] * self._patch_size, @@ -183,7 +183,7 @@ class DynamicResolutionImageTilingStrategy: # Only add thumbnail if resized image area is less than threshold % of # thumbnail area if area_ratio < self._thumbnail_area_threshold: - thumbnail_img = params.image.resize( + thumbnail_img = params.media.resize( (self._thumbnail_size, self._thumbnail_size) ) processed_images.append(thumbnail_img) @@ -192,7 +192,7 @@ class DynamicResolutionImageTilingStrategy: def process_media( self, - image: Image.Image, + media: Image.Image, num_tokens_available: int, data_augment: bool = False, tiling_augment_prob: float = 0.4, @@ -207,10 +207,10 @@ class DynamicResolutionImageTilingStrategy: DynamicResolutionParams for the media """ current_num_tokens_available = num_tokens_available - assert isinstance(image, Image.Image), ( + assert isinstance(media, Image.Image), ( "Dynamic resolution is only supported for image media" ) - orig_width, orig_height = image.width, image.height + orig_width, orig_height = media.width, media.height closest_patch_height = round(orig_height / self._patch_size + 0.5) closest_patch_width = round(orig_width / self._patch_size + 0.5) @@ -336,7 +336,7 @@ class DynamicResolutionImageTilingStrategy: target_patch_width, target_patch_height, current_num_tokens_available ) - assert isinstance(image, Image.Image), ( + assert isinstance(media, Image.Image), ( "Dynamic resolution is only supported for image media" ) @@ -374,7 +374,7 @@ class DynamicResolutionImageTilingStrategy: ) return DynamicResolutionParams( - image=image, + media=media, num_tiles=num_tiles, num_embeddings=num_embeddings, patch_size=(target_patch_width, target_patch_height),