From 6979edb5755c320eac3e1432036c1c8242085012 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Thu, 11 Dec 2025 13:49:27 +0200 Subject: [PATCH 01/10] import --- image_processing.py | 508 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 508 insertions(+) create mode 100644 image_processing.py diff --git a/image_processing.py b/image_processing.py new file mode 100644 index 0000000000000..979ff681edc87 --- /dev/null +++ b/image_processing.py @@ -0,0 +1,508 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. +import math + +import torch +from einops import rearrange +from torchvision import transforms as T +from torchvision.transforms import Compose +from torchvision.transforms.functional import InterpolationMode + + +IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] +IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] +SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] +SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] +CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] +CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] +RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060] +RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250] + + +pixel_statistics = { + "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), + "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), + "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD), + "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "internvit300M": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), + "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), +} + + +# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685 +# Copyright (c) 2023 OpenGVLab. +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + """ + Find the best number of tiles based on the aspect ratio and the area covered by the tiles. + """ + best_factor = float('-inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + factor_based_on_area_n_ratio = ( + min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6) * + min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)) + if factor_based_on_area_n_ratio > best_factor: + best_factor = factor_based_on_area_n_ratio + best_ratio = ratio + return best_ratio + + +def process_images(sample_imgs, patch_dim, dynamic_resolution, batch_mode=False): + """Process a batch of images for multimodal training or evaluation. + + This function handles image preprocessing with support for both static and dynamic + resolution processing. For dynamic resolution, it rearranges images into patches + and computes cumulative sequence lengths for efficient batching. + + Args: + sample_imgs (List[torch.Tensor]): List of image tensors with shape (C, H, W). + patch_dim (int): Dimension of each patch (e.g., 14 for 14x14 patches). + dynamic_resolution (bool): Whether to use dynamic resolution processing. + If True, images are rearranged into patches with variable sequence lengths. + If False, images are simply stacked into a batch tensor. + batch_mode (bool, optional): Whether this is being called from training batch processing. + If True, wraps tensors in additional list dimension for consistency with batch format. + If False, returns tensors directly as used in evaluation. Defaults to False. + + Returns: + tuple: A 4-tuple containing: + - images (torch.Tensor): Processed image tensor. + For dynamic resolution: shape (1, total_patches, patch_features) if batch_mode=False, + or shape (1, total_patches, patch_features) if batch_mode=True + For static resolution: shape (batch_size, C, H, W) + - imgs_sizes (torch.Tensor): Image sizes tensor with shape (N, 2) + containing [width, height] for each image, or [[0,0]] if no images. + - vision_cu_lengths (torch.Tensor or None): Cumulative sequence lengths + for dynamic resolution. Shape (batch_size + 1,) for evaluation mode, + or shape (1, batch_size + 1) for batch mode. None for static resolution. + - vision_max_lengths (torch.Tensor or None): Maximum sequence length + among all images for dynamic resolution. Scalar tensor for evaluation mode, + or shape (1,) for batch mode. None for static resolution. + + Note: + This function is designed for processing one microbatch at a time for dynamic resolution. + """ + vision_cu_lengths = None + vision_max_lengths = None + + if len(sample_imgs) > 0: + imgs_sizes = torch.tensor([[img.shape[1], img.shape[2]] for img in sample_imgs], dtype=torch.int32) + if dynamic_resolution: + def rearrange_img(x): + py = x.shape[-2] // patch_dim + px = x.shape[-1] // patch_dim + x = rearrange(x, 'c (py yy) (px xx) -> (py px) (c yy xx)', + py=py, yy=patch_dim, + px=px, xx=patch_dim, + ) + return x + imgs = [rearrange_img(img) for img in sample_imgs] + + current_length = 0 + max_length = 0 + vision_cu_lengths = [0] + for img in imgs: + if max_length < img.shape[0]: + max_length = img.shape[0] + current_length += img.shape[0] + vision_cu_lengths.append(current_length) + + vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) + vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) + + # For batch mode, wrap in additional dimension for consistency + if batch_mode: + vision_cu_lengths = vision_cu_lengths.unsqueeze(0) # Shape: (1, batch_size + 1) + vision_max_lengths = vision_max_lengths.unsqueeze(0) # Shape: (1,) + + imgs = torch.cat(imgs, dim=0) + images = imgs.unsqueeze(0) + else: + images = torch.stack(sample_imgs) + else: + imgs_sizes = torch.tensor([[0,0]], dtype=torch.int32) + if len(sample_imgs) == 0 and batch_mode: + # For batch mode when no images, use appropriate dummy tensor + images = torch.tensor([[0]], dtype=torch.float32) + else: + images = torch.stack(sample_imgs) + + return images, imgs_sizes, vision_cu_lengths, vision_max_lengths + + +class ImageTransform: + """Image transformation.""" + + def __init__(self, input_size, vision_model_type, *, dynamic_resolution=False, res_step=16, min_num_patches=1, max_num_patches=128, pixel_shuffle=False, min_side=None, conv_merging=False, match_tiling_dynamic_resolution=False, masked_tiling_dynamic_resolution=False, thumbnail_area_threshold=0.8): + self._transform = _build_transform(input_size, vision_model_type) + self._vision_model_type = vision_model_type + self._dynamic_resolution = dynamic_resolution + self._res_step = res_step + self._min_num_patches = min_num_patches + self._max_num_patches = max_num_patches + self._pixel_shuffle = pixel_shuffle + self._min_side = min_side + self._conv_merging = conv_merging + self._match_tiling_dynamic_resolution = match_tiling_dynamic_resolution + self._masked_tiling_dynamic_resolution = masked_tiling_dynamic_resolution + self._thumbnail_area_threshold = thumbnail_area_threshold + + def __call__(self, img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, find_closest_aspect_ratio_fn=find_closest_aspect_ratio, is_video=False): + assert not augment, "Image augmentation not implemented." + if use_tiling: + assert img_h == img_w, "dynamic tiling expects equal tile height and width" + imgs = dynamic_preprocess( + img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail, + find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn) + imgs = [self._transform(img) for img in imgs] + elif self._masked_tiling_dynamic_resolution: + assert img_h == img_w, "masked tiling dynamic resolution expects equal tile height and width" + assert "radio" in self._vision_model_type, "Masked tiling dynamic resolution is only supported for radio models" + + # Use tiling logic to determine tile grid (nx, ny) + orig_width, orig_height = img.size + aspect_ratio = orig_width / orig_height + + target_ratios = set( + (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) + if i * j <= max_num_tiles and i * j >= 1 + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + tiling = find_closest_aspect_ratio_fn( + aspect_ratio, target_ratios, orig_width, orig_height, img_h + ) + + # Resize and split into tiles of size (img_h x img_h) + target_width = img_h * tiling[0] + target_height = img_w * tiling[1] + blocks = tiling[0] * tiling[1] + + resized_img = img.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // img_h)) * img_h, + (i // (target_width // img_h)) * img_h, + ((i % (target_width // img_h)) + 1) * img_h, + ((i // (target_width // img_h)) + 1) * img_h, + ) + tile_img = resized_img.crop(box) + processed_images.append(tile_img) + assert len(processed_images) == blocks + + # Optional thumbnail + if use_thumbnail and blocks != 1: + thumbnail_img = img.resize((img_h, img_h)) + processed_images.append(thumbnail_img) + + pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.ToTensor(), + T.Normalize(mean=pixel_mean, std=pixel_std), + ]) + imgs = [transform(im) for im in processed_images] + elif self._match_tiling_dynamic_resolution: + assert img_h == img_w, "match tiling dynamic resolution expects equal tile height and width" + assert "radio" in self._vision_model_type, "Match tiling dynamic resolution is only supported for radio models" + + # Use tiling logic to determine optimal dimensions + orig_width, orig_height = img.size + aspect_ratio = orig_width / orig_height + + # Calculate target ratios (same logic as tiling) + target_ratios = set( + (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num_tiles and i * j >= 1) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # Find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio_fn( + aspect_ratio, target_ratios, orig_width, orig_height, img_h) + + # Calculate the target width and height using tiling logic + target_width = img_h * target_aspect_ratio[0] + target_height = img_w * target_aspect_ratio[1] + + # Resize image to target dimensions (same as tiling, but don't split) + resized_img = img.resize((target_width, target_height)) + + # Process as single dynamic resolution image + pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.ToTensor(), + T.Normalize(mean=pixel_mean, std=pixel_std), + ]) + processed_images = [resized_img] + + # Add thumbnail if use_thumbnail=True and there's more than 1 tile + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + if use_thumbnail and blocks != 1: + thumbnail_img = img.resize((img_h, img_h)) + processed_images.append(thumbnail_img) + + imgs = [transform(img) for img in processed_images] + elif self._dynamic_resolution: + pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.ToTensor(), + T.Normalize(mean=pixel_mean, std=pixel_std), + ]) + processed_img = dynamic_res_preprocess(img, min_patches=self._min_num_patches, max_patches=self._max_num_patches, res_step=self._res_step, pixel_shuffle=self._pixel_shuffle, min_side=self._min_side, conv_merging=self._conv_merging, is_video=is_video) + processed_images = [processed_img] + + # Add thumbnail if enabled and image area is below threshold + if use_thumbnail: + # Calculate areas + processed_width, processed_height = processed_img.size + resized_area = processed_width * processed_height + thumbnail_area = img_h * img_h # img_h should be square thumbnail size + area_ratio = resized_area / thumbnail_area + + # Only add thumbnail if resized image area is less than threshold % of thumbnail area + if area_ratio < self._thumbnail_area_threshold: + thumbnail_img = img.resize((img_h, img_h)) # Use square thumbnail with img_h size + processed_images.append(thumbnail_img) + + imgs = [transform(img) for img in processed_images] + else: + imgs = [self._transform(img)] + + return imgs + + +# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702 +# Copyright (c) 2023 OpenGVLab. +def dynamic_preprocess( + image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, + find_closest_aspect_ratio_fn=find_closest_aspect_ratio): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio_fn( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def dynamic_res_preprocess(image, min_patches=1, max_patches=128, res_step=16, factor_max=1., pixel_shuffle=False, min_side=None, conv_merging=False, is_video=False): + """Preprocess an image with dynamic resolution for vision transformers. + + This function resizes an image to optimize the number of patches while respecting + constraints on minimum/maximum patches, minimum side length, and compatibility + with pixel shuffle or convolution merging operations. + + The algorithm works by: + 1. Computing the initial patch grid size based on the image dimensions and res_step + 2. Scaling the patch grid to fit within the max_patches constraint + 3. Ensuring the result has at least min_patches + 4. Optionally enforcing a minimum side length constraint + 5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility + 6. Resizing the image to the computed target dimensions + + Args: + image (PIL.Image): Input image to preprocess. + min_patches (int, optional): Minimum number of patches required. Defaults to 1. + max_patches (int, optional): Maximum number of patches allowed. Defaults to 128. + res_step (int, optional): Resolution step size (patch dimension). Defaults to 16. + factor_max (float, optional): Maximum scaling factor to apply. Defaults to 1.0. + pixel_shuffle (bool, optional): Whether to ensure compatibility with pixel shuffle + operations by rounding to even patch dimensions. Defaults to False. + min_side (int, optional): Minimum side length in pixels. If specified, ensures + at least one side meets this constraint. Defaults to None. + conv_merging (bool, optional): Whether to ensure compatibility with convolution + merging by rounding to even patch dimensions. Defaults to False. + + Returns: + PIL.Image: Resized image with dimensions optimized for patch-based processing. + The output dimensions will be (target_patch_width * res_step, target_patch_height * res_step). + + Note: + The function preserves aspect ratio as much as possible while satisfying all constraints. + When constraints conflict (e.g., min_side vs max_patches), the function prioritizes + staying within max_patches while maximizing the image size. + + Example: + >>> from PIL import Image + >>> img = Image.open("example.jpg") # 800x600 image + >>> resized_img = dynamic_res_preprocess(img, min_patches=4, max_patches=64, res_step=14) + >>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14 + """ + orig_width, orig_height = image.size + + closest_patch_height = round(orig_height / res_step + 0.5) + closest_patch_width = round(orig_width / res_step + 0.5) + patches = closest_patch_height * closest_patch_width + + factor = min(math.sqrt(max_patches / patches), factor_max) + target_patch_height = math.floor(factor * closest_patch_height) + target_patch_width = math.floor(factor * closest_patch_width) + + if target_patch_height * target_patch_width < min_patches: + up_factor = math.sqrt(min_patches / (target_patch_height * target_patch_width)) + target_patch_height = math.ceil(up_factor * target_patch_height) + target_patch_width = math.ceil(up_factor * target_patch_width) + + if min_side is not None and min(target_patch_width, target_patch_height) * res_step < min_side: + if target_patch_width <= target_patch_height: + up_factor = min_side / (target_patch_width * res_step) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) + + if new_patch_height * new_patch_width > max_patches: + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + if max(max_patches // new_patch_width, 1) * res_step < min_side: + up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width)) + target_patch_height = math.floor(up_factor * target_patch_height) + target_patch_width = math.floor(up_factor * target_patch_width) + target_patch_width = new_patch_width + target_patch_height = max(max_patches // new_patch_width, 1) + else: + target_patch_height = new_patch_height + target_patch_width = new_patch_width + else: + up_factor = min_side / (target_patch_height * res_step) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) + + if new_patch_height * new_patch_width > max_patches: + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + if max(max_patches // new_patch_height, 1) * res_step < min_side: + up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width)) + target_patch_height = math.floor(up_factor * target_patch_height) + target_patch_width = math.floor(up_factor * target_patch_width) + else: + target_patch_height = new_patch_height + target_patch_width = max(max_patches // new_patch_height, 1) + else: + target_patch_height = new_patch_height + target_patch_width = new_patch_width + + # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) + # or by 4 when BOTH are enabled (two successive 2x reductions) + if pixel_shuffle or conv_merging: + required_divisor = 4 if (pixel_shuffle and conv_merging) else 2 + + rem_h = target_patch_height % required_divisor + if rem_h != 0: + inc_h = required_divisor - rem_h + if (target_patch_height + inc_h) * target_patch_width <= max_patches: + target_patch_height += inc_h + else: + target_patch_height = max(1, target_patch_height - rem_h) + + rem_w = target_patch_width % required_divisor + if rem_w != 0: + inc_w = required_divisor - rem_w + if target_patch_height * (target_patch_width + inc_w) <= max_patches: + target_patch_width += inc_w + else: + target_patch_width = max(1, target_patch_width - rem_w) + assert target_patch_height * target_patch_width <= max_patches + + #TEMP: hacky way to process video same as in training + if is_video: + # max_patches = 1024 + # min_patches = 512 + target_patch_width = 32 + target_patch_height = 32 + + + # resize the image + resized_img = image.resize((target_patch_width * res_step, target_patch_height * res_step)) + + return resized_img + + + +# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 +# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 +def _build_transform(input_size, vision_model_type): + if vision_model_type in ("siglip", "internvit", "internvit300M", "radio", "radio-g", "cradio-g"): + pixel_mean, pixel_std = pixel_statistics[vision_model_type] + + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=pixel_mean, std=pixel_std) + ]) + elif vision_model_type == "clip": + pixel_mean, pixel_std = pixel_statistics[vision_model_type] + + transform = Compose([ + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.ToTensor(), + T.Normalize(mean=pixel_mean, std=pixel_std), + ]) + elif vision_model_type.startswith("hf://"): + from megatron.core.models.huggingface.module import get_hf_model_type + + model_type = get_hf_model_type(vision_model_type) + if "siglip" in model_type: + from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor + + processor = SiglipImageProcessor(size={"height": input_size, "width": input_size}) + + def transform(x): + x = x.convert("RGB") if x.mode != "RGB" else x + x = processor(x, return_tensors="pt") + return x["pixel_values"][0] + else: + raise NotImplementedError(f"image processing not defined for huggingface model {vision_model_type}") + else: + raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}") + + return transform From 50ffea98266c3a7910306ce185e198171bf68bb6 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Thu, 11 Dec 2025 14:12:03 +0200 Subject: [PATCH 02/10] update --- image_processing.py | 2126 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 1723 insertions(+), 403 deletions(-) diff --git a/image_processing.py b/image_processing.py index 979ff681edc87..5d43e764acadd 100644 --- a/image_processing.py +++ b/image_processing.py @@ -1,12 +1,23 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. +from abc import ABC, abstractmethod +from dataclasses import dataclass import math +from typing import Callable, Optional +import numpy as np +import random +from PIL import Image +import albumentations as A +import einops import torch -from einops import rearrange from torchvision import transforms as T from torchvision.transforms import Compose from torchvision.transforms.functional import InterpolationMode +from data_loading.conversation_sample import ( + ImageMedia, + VideoFrameMedia, +) IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] @@ -24,16 +35,23 @@ pixel_statistics = { "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD), - "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "internvit300M": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), + "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), } # From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685 # Copyright (c) 2023 OpenGVLab. -def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): - best_ratio_diff = float('inf') +def find_closest_aspect_ratio( + aspect_ratio: float, + target_ratios: list[tuple[int, int]], + width: int, + height: int, + image_size: int, +) -> tuple[int, int]: + best_ratio_diff = float("inf") best_ratio = (1, 1) area = width * height for ratio in target_ratios: @@ -48,301 +66,548 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_ return best_ratio -def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): +def find_closest_area_weighted_aspect_ratio( + aspect_ratio: float, + target_ratios: list[tuple[int, int]], + width: int, + height: int, + image_size: int, +): """ Find the best number of tiles based on the aspect ratio and the area covered by the tiles. """ - best_factor = float('-inf') + best_factor = float("-inf") best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] - factor_based_on_area_n_ratio = ( - min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6) * - min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)) + factor_based_on_area_n_ratio = min( + (ratio[0] * ratio[1] * image_size * image_size) / area, 0.6 + ) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio) if factor_based_on_area_n_ratio > best_factor: best_factor = factor_based_on_area_n_ratio best_ratio = ratio return best_ratio -def process_images(sample_imgs, patch_dim, dynamic_resolution, batch_mode=False): - """Process a batch of images for multimodal training or evaluation. - - This function handles image preprocessing with support for both static and dynamic - resolution processing. For dynamic resolution, it rearranges images into patches - and computes cumulative sequence lengths for efficient batching. - - Args: - sample_imgs (List[torch.Tensor]): List of image tensors with shape (C, H, W). - patch_dim (int): Dimension of each patch (e.g., 14 for 14x14 patches). - dynamic_resolution (bool): Whether to use dynamic resolution processing. - If True, images are rearranged into patches with variable sequence lengths. - If False, images are simply stacked into a batch tensor. - batch_mode (bool, optional): Whether this is being called from training batch processing. - If True, wraps tensors in additional list dimension for consistency with batch format. - If False, returns tensors directly as used in evaluation. Defaults to False. - - Returns: - tuple: A 4-tuple containing: - - images (torch.Tensor): Processed image tensor. - For dynamic resolution: shape (1, total_patches, patch_features) if batch_mode=False, - or shape (1, total_patches, patch_features) if batch_mode=True - For static resolution: shape (batch_size, C, H, W) - - imgs_sizes (torch.Tensor): Image sizes tensor with shape (N, 2) - containing [width, height] for each image, or [[0,0]] if no images. - - vision_cu_lengths (torch.Tensor or None): Cumulative sequence lengths - for dynamic resolution. Shape (batch_size + 1,) for evaluation mode, - or shape (1, batch_size + 1) for batch mode. None for static resolution. - - vision_max_lengths (torch.Tensor or None): Maximum sequence length - among all images for dynamic resolution. Scalar tensor for evaluation mode, - or shape (1,) for batch mode. None for static resolution. - - Note: - This function is designed for processing one microbatch at a time for dynamic resolution. +# Mike's optimized ToTensor. +def _fast_to_tensor(pic) -> torch.Tensor: + np_img = np.array(pic, copy=False) + img = torch.from_numpy(np_img) + img = img.permute(2, 0, 1) # HWC to CHW + fp_img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format) + fp_img.div_(255) + return fp_img + + +@dataclass +class ImageTilingParams: + media: ImageMedia | VideoFrameMedia + num_tiles: int + num_embeddings: int + + +class ImageTilingStrategy(ABC): """ - vision_cu_lengths = None - vision_max_lengths = None - - if len(sample_imgs) > 0: - imgs_sizes = torch.tensor([[img.shape[1], img.shape[2]] for img in sample_imgs], dtype=torch.int32) - if dynamic_resolution: - def rearrange_img(x): - py = x.shape[-2] // patch_dim - px = x.shape[-1] // patch_dim - x = rearrange(x, 'c (py yy) (px xx) -> (py px) (c yy xx)', - py=py, yy=patch_dim, - px=px, xx=patch_dim, - ) - return x - imgs = [rearrange_img(img) for img in sample_imgs] + Base class for image tiling strategies. + A tiling strategy is a function that takes a list of media and returns a list of image tiling parameters. + These can then be used to apply the tiling to the media. - current_length = 0 - max_length = 0 - vision_cu_lengths = [0] - for img in imgs: - if max_length < img.shape[0]: - max_length = img.shape[0] - current_length += img.shape[0] - vision_cu_lengths.append(current_length) - - vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) - vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) - - # For batch mode, wrap in additional dimension for consistency - if batch_mode: - vision_cu_lengths = vision_cu_lengths.unsqueeze(0) # Shape: (1, batch_size + 1) - vision_max_lengths = vision_max_lengths.unsqueeze(0) # Shape: (1,) - - imgs = torch.cat(imgs, dim=0) - images = imgs.unsqueeze(0) - else: - images = torch.stack(sample_imgs) - else: - imgs_sizes = torch.tensor([[0,0]], dtype=torch.int32) - if len(sample_imgs) == 0 and batch_mode: - # For batch mode when no images, use appropriate dummy tensor - images = torch.tensor([[0]], dtype=torch.float32) - else: - images = torch.stack(sample_imgs) + Subclasses must implement the `compute_params` and `apply_params` methods. - return images, imgs_sizes, vision_cu_lengths, vision_max_lengths + The `transform` method is a convenience method that computes the transformation parameters and applies the transformation to the media. + + """ + + def transform( + self, + media_list: list[ImageMedia | VideoFrameMedia], + num_tokens_available: int | None = None, + ) -> list[torch.Tensor]: + """ + Transform the media and compute the transformation parameters. + """ + transform_media_list = self.compute_params(media_list, num_tokens_available) + return [ + self.apply_params(transform_media, **kwargs) + for transform_media in transform_media_list + ] + + @abstractmethod + def compute_params( + self, media_list: list[ImageMedia | VideoFrameMedia], num_tokens_available: int, max_num_tiles: int | None = None, **kwargs + ) -> list[ImageTilingParams]: + """ + Compute the transformation parameters and the number of tokens to use for the media. + + Args: + media_list: List of media to transform + num_tokens_available: Number of tokens available for all media + max_num_tiles: Maximum number of tiles allowed (optional, defaults to instance's max_num_tiles if not provided) + + Returns: + list of transformation parameters with the media + """ + ... + + @abstractmethod + def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]: + """ + Apply the transformation parameters to the media. + + Args: + transform_media: The media to apply the transformation to + + Returns: + list of transformed media tensors + """ + ... + + @abstractmethod + def stack( + self, images: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]: + """ + Stack the images into a single tensor. + + Args: + media_list: List of images to stack + + Returns: + tuple of (stacked media, image sizes, vision cu lengths, vision max lengths) + """ + ... -class ImageTransform: - """Image transformation.""" +class _FixedSizeStrategy(ImageTilingStrategy): + """ + Base class for fixed size image tiling strategies. + """ - def __init__(self, input_size, vision_model_type, *, dynamic_resolution=False, res_step=16, min_num_patches=1, max_num_patches=128, pixel_shuffle=False, min_side=None, conv_merging=False, match_tiling_dynamic_resolution=False, masked_tiling_dynamic_resolution=False, thumbnail_area_threshold=0.8): - self._transform = _build_transform(input_size, vision_model_type) + def __init__( + self, + vision_model_type: str, + target_width: int, + target_height: int, + embeddings_per_image: int, + ): self._vision_model_type = vision_model_type - self._dynamic_resolution = dynamic_resolution - self._res_step = res_step - self._min_num_patches = min_num_patches - self._max_num_patches = max_num_patches - self._pixel_shuffle = pixel_shuffle - self._min_side = min_side - self._conv_merging = conv_merging - self._match_tiling_dynamic_resolution = match_tiling_dynamic_resolution - self._masked_tiling_dynamic_resolution = masked_tiling_dynamic_resolution - self._thumbnail_area_threshold = thumbnail_area_threshold - - def __call__(self, img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, find_closest_aspect_ratio_fn=find_closest_aspect_ratio, is_video=False): - assert not augment, "Image augmentation not implemented." - if use_tiling: - assert img_h == img_w, "dynamic tiling expects equal tile height and width" - imgs = dynamic_preprocess( - img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail, - find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn) - imgs = [self._transform(img) for img in imgs] - elif self._masked_tiling_dynamic_resolution: - assert img_h == img_w, "masked tiling dynamic resolution expects equal tile height and width" - assert "radio" in self._vision_model_type, "Masked tiling dynamic resolution is only supported for radio models" - - # Use tiling logic to determine tile grid (nx, ny) - orig_width, orig_height = img.size - aspect_ratio = orig_width / orig_height - - target_ratios = set( - (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) - if i * j <= max_num_tiles and i * j >= 1 - ) - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) - - tiling = find_closest_aspect_ratio_fn( - aspect_ratio, target_ratios, orig_width, orig_height, img_h - ) - - # Resize and split into tiles of size (img_h x img_h) - target_width = img_h * tiling[0] - target_height = img_w * tiling[1] - blocks = tiling[0] * tiling[1] - - resized_img = img.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // img_h)) * img_h, - (i // (target_width // img_h)) * img_h, - ((i % (target_width // img_h)) + 1) * img_h, - ((i // (target_width // img_h)) + 1) * img_h, - ) - tile_img = resized_img.crop(box) - processed_images.append(tile_img) - assert len(processed_images) == blocks - - # Optional thumbnail - if use_thumbnail and blocks != 1: - thumbnail_img = img.resize((img_h, img_h)) - processed_images.append(thumbnail_img) - - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - transform = T.Compose([ - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std), - ]) - imgs = [transform(im) for im in processed_images] - elif self._match_tiling_dynamic_resolution: - assert img_h == img_w, "match tiling dynamic resolution expects equal tile height and width" - assert "radio" in self._vision_model_type, "Match tiling dynamic resolution is only supported for radio models" - - # Use tiling logic to determine optimal dimensions - orig_width, orig_height = img.size - aspect_ratio = orig_width / orig_height - - # Calculate target ratios (same logic as tiling) - target_ratios = set( - (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) if - i * j <= max_num_tiles and i * j >= 1) - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) - - # Find the closest aspect ratio to the target - target_aspect_ratio = find_closest_aspect_ratio_fn( - aspect_ratio, target_ratios, orig_width, orig_height, img_h) - - # Calculate the target width and height using tiling logic - target_width = img_h * target_aspect_ratio[0] - target_height = img_w * target_aspect_ratio[1] - - # Resize image to target dimensions (same as tiling, but don't split) - resized_img = img.resize((target_width, target_height)) - - # Process as single dynamic resolution image - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - transform = T.Compose([ - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std), - ]) - processed_images = [resized_img] - - # Add thumbnail if use_thumbnail=True and there's more than 1 tile - blocks = target_aspect_ratio[0] * target_aspect_ratio[1] - if use_thumbnail and blocks != 1: - thumbnail_img = img.resize((img_h, img_h)) - processed_images.append(thumbnail_img) - - imgs = [transform(img) for img in processed_images] - elif self._dynamic_resolution: - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - transform = T.Compose([ - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std), - ]) - processed_img = dynamic_res_preprocess(img, min_patches=self._min_num_patches, max_patches=self._max_num_patches, res_step=self._res_step, pixel_shuffle=self._pixel_shuffle, min_side=self._min_side, conv_merging=self._conv_merging, is_video=is_video) - processed_images = [processed_img] - - # Add thumbnail if enabled and image area is below threshold - if use_thumbnail: - # Calculate areas - processed_width, processed_height = processed_img.size - resized_area = processed_width * processed_height - thumbnail_area = img_h * img_h # img_h should be square thumbnail size - area_ratio = resized_area / thumbnail_area - - # Only add thumbnail if resized image area is less than threshold % of thumbnail area - if area_ratio < self._thumbnail_area_threshold: - thumbnail_img = img.resize((img_h, img_h)) # Use square thumbnail with img_h size - processed_images.append(thumbnail_img) - - imgs = [transform(img) for img in processed_images] - else: - imgs = [self._transform(img)] - - return imgs - - -# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702 -# Copyright (c) 2023 OpenGVLab. -def dynamic_preprocess( - image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, - find_closest_aspect_ratio_fn=find_closest_aspect_ratio): - orig_width, orig_height = image.size - aspect_ratio = orig_width / orig_height - - # calculate the existing image aspect ratio - target_ratios = set( - (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if - i * j <= max_num and i * j >= min_num) - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) - - # find the closest aspect ratio to the target - target_aspect_ratio = find_closest_aspect_ratio_fn( - aspect_ratio, target_ratios, orig_width, orig_height, image_size) - - # calculate the target width and height - target_width = image_size * target_aspect_ratio[0] - target_height = image_size * target_aspect_ratio[1] - blocks = target_aspect_ratio[0] * target_aspect_ratio[1] - - # resize the image - resized_img = image.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size + self._target_width = target_width + self._target_height = target_height + self._embeddings_per_image = embeddings_per_image + self._transform = self._build_transform( + (target_width, target_height), vision_model_type + ) + + # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 + # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 + @staticmethod + def _build_transform(target_size: tuple[int, int], vision_model_type: str): + """ + Build a transform for a given vision model type and target size. + """ + if vision_model_type in ("siglip", "internvit", "radio", "radio-g", "cradio-g"): + pixel_mean, pixel_std = pixel_statistics[vision_model_type] + + transform = T.Compose( + [ + T.Lambda( + lambda img: img.convert("RGB") if img.mode != "RGB" else img + ), + T.Resize( + (target_size[1], target_size[0]), + interpolation=InterpolationMode.BICUBIC, + ), + T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)), + T.Normalize(mean=pixel_mean, std=pixel_std), + ] + ) + # From the official CLIP repo. + elif vision_model_type == "clip": + pixel_mean, pixel_std = pixel_statistics[vision_model_type] + + transform = Compose( + [ + T.Resize( + (target_size[1], target_size[0]), + interpolation=InterpolationMode.BICUBIC, + ), + T.Lambda( + lambda img: img.convert("RGB") if img.mode != "RGB" else img + ), + T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)), + T.Normalize(mean=pixel_mean, std=pixel_std), + ] + ) + elif vision_model_type.startswith("hf://"): + from megatron.core.models.huggingface.module import get_hf_model_type + + model_type = get_hf_model_type(vision_model_type) + if "siglip" in model_type: + from transformers.models.siglip.image_processing_siglip import ( + SiglipImageProcessor, + ) + + processor = SiglipImageProcessor( + size={"height": target_size[1], "width": target_size[0]} + ) + + def transform(x): + x = x.convert("RGB") if x.mode != "RGB" else x + x = processor(x, return_tensors="pt") + return x["pixel_values"][0] + else: + raise NotImplementedError( + f"image processing not defined for huggingface model {vision_model_type}" + ) + else: + raise NotImplementedError( + f"image processing not defined for vision model {vision_model_type}" + ) + + return transform + + def stack( + self, images: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: + return ( + torch.stack(images) if len(images) > 0 else None, + torch.tensor( + [(img.shape[1], img.shape[2]) for img in images], dtype=torch.int32 + ) if len(images) > 0 else None, + None, + None, ) - # split the image - split_img = resized_img.crop(box) - processed_images.append(split_img) - assert len(processed_images) == blocks - if use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((image_size, image_size)) - processed_images.append(thumbnail_img) - return processed_images -def dynamic_res_preprocess(image, min_patches=1, max_patches=128, res_step=16, factor_max=1., pixel_shuffle=False, min_side=None, conv_merging=False, is_video=False): +class NoTilingStrategy(_FixedSizeStrategy): + """ + A simple image transformation that resizes the image to the target width and height. + """ + + def __init__( + self, + vision_model_type: str, + target_width: int, + target_height: int, + embeddings_per_image: int, + ): + super().__init__( + vision_model_type=vision_model_type, + target_width=target_width, + target_height=target_height, + embeddings_per_image=embeddings_per_image, + ) + + def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]: + return [self._transform(transform_media.media.value)] + + def compute_params( + self, + media_list: list[ImageMedia | VideoFrameMedia], + num_tokens_available: Optional[int] = None, + max_num_tiles: int | None = None, + **kwargs, + ) -> list[ImageTilingParams]: + return [ + ImageTilingParams( + media=media, num_tiles=1, num_embeddings=self._embeddings_per_image + ) + for media in media_list + ] + + def __str__(self): + return f"SimpleImageTransform(vision_model_type={self._vision_model_type}, num_tokens_per_image={self._embeddings_per_image})" + + +@dataclass +class ImageTilingParamsV1(ImageTilingParams): + tiling: tuple[int, int] + + +class ImageTilingStrategyV1(_FixedSizeStrategy): + """Tiling image transformation. + + This transformation splits the image into a grid of tiles and applies the transformation to each tile. + """ + + # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 + # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 + + def __init__( + self, + vision_model_type: str, + tile_size: int, + use_thumbnail: bool, + min_num_tiles: int, + max_num_tiles: int, + embeddings_per_tile: int, + find_closest_aspect_ratio_fn=find_closest_aspect_ratio, + ): + super().__init__( + vision_model_type=vision_model_type, + target_width=tile_size, + target_height=tile_size, + embeddings_per_image=embeddings_per_tile, + ) + + # print(f"Transformation params: {vision_model_type=}, {use_tiling=}, {tile_size=}, {use_thumbnail=}, {augment=}, {min_num_tiles=}, {max_num_tiles=}, {find_closest_aspect_ratio_fn=}") + self._tile_size = tile_size + self._use_thumbnail = use_thumbnail + self._min_num_tiles = min_num_tiles + self._max_num_tiles = max_num_tiles + self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn + + # Calculate all possible aspect ratios for each max_num_tiles. + self.target_ratios = { + max_num_tiles: sorted( + set( + (x, y) + for n in range(self._min_num_tiles, max_num_tiles + 1) + for x in range(1, n + 1) + for y in range(1, n + 1) + if x * y <= max_num_tiles and x * y >= self._min_num_tiles + ), + key=lambda x: x[0] * x[1], + ) + for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1) + } + + self.transform = A.Compose([ + A.OneOf([ + A.GaussNoise(var_limit=(5.0, 30.0)), + A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5)), + ], p=0.3), + A.OneOf([ + A.MedianBlur(blur_limit=5), + A.GaussianBlur(blur_limit=5), + ], p=0.2), + A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.5), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=15, val_shift_limit=15, p=0.3), + A.ImageCompression(quality_lower=70, quality_upper=100, p=0.3), + ]) + + def apply_params(self, transform_media: ImageTilingParams, data_augment: bool = False, **kwargs) -> list[torch.Tensor]: + assert isinstance(transform_media, ImageTilingParamsV1) + image = transform_media.media.value + + if data_augment: + image = self.transform(image=np.asarray(image))["image"] + image = Image.fromarray(image) + + # calculate the target width and height + target_width = self._tile_size * transform_media.tiling[0] + target_height = self._tile_size * transform_media.tiling[1] + blocks = transform_media.tiling[0] * transform_media.tiling[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // self._tile_size)) * self._tile_size, + (i // (target_width // self._tile_size)) * self._tile_size, + ((i % (target_width // self._tile_size)) + 1) * self._tile_size, + ((i // (target_width // self._tile_size)) + 1) * self._tile_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if self._use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((self._tile_size, self._tile_size)) + processed_images.append(thumbnail_img) + + return [self._transform(img) for img in processed_images] + + def compute_params( + self, + media_list: list[ImageMedia | VideoFrameMedia], + num_tokens_available: Optional[int] = None, + max_num_tiles: int | None = None, + data_augment: bool = False, + tiling_augment_prob: float = 0.4, + **kwargs, + ) -> list[ImageTilingParamsV1]: + # Use provided max_num_tiles or fall back to instance's max_num_tiles + # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value + effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles + effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles) + + max_num_tiles_to_use = min( + num_tokens_available // self._embeddings_per_image, effective_max_num_tiles + ) + + # calculate the existing image aspect ratio + target_ratios = self.target_ratios[max_num_tiles_to_use] + + params = [] + for media in media_list: + if isinstance(media, ImageMedia): + img_size = (media.width, media.height) + elif isinstance(media, VideoFrameMedia): + img_size = (media.video_width, media.video_height) + else: + raise ValueError(f"Unsupported media type: {type(media)}") + + aspect_ratio = img_size[0] / img_size[1] + + # find the closest aspect ratio to the target + tiling = self._find_closest_aspect_ratio_fn( + aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size + ) + if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob: + tiling = self.augment_tiling(tiling) + num_tiles = tiling[0] * tiling[1] + if self._use_thumbnail and num_tiles != 1: + num_tiles += 1 + + params.append( + ImageTilingParamsV1( + media=media, + num_tiles=num_tiles, + num_embeddings=num_tiles * self._embeddings_per_image, + tiling=tiling, + ) + ) + + return params + + def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]: + def num_tiles(tiling: tuple[int, int]) -> int: + return tiling[0] * tiling[1] + + def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]: + if random.random() < minus_prob: + # Minus one + if tiling[0] == 1 and tiling[1] == 1: + return tiling + elif tiling[0] == 1: + return (tiling[0], tiling[1] - 1) + elif tiling[1] == 1: + return (tiling[0] - 1, tiling[1]) + else: + if random.random() < 0.5: + return (tiling[0] - 1, tiling[1]) + else: + return (tiling[0], tiling[1] - 1) + else: + # Plus one + if num_tiles(tiling) < self._max_num_tiles: + tiling0 = (tiling[0] + 1, tiling[1]) + tiling1 = (tiling[0], tiling[1] + 1) + if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles: + return tiling + elif num_tiles(tiling0) > self._max_num_tiles: + return tiling1 + elif num_tiles(tiling1) > self._max_num_tiles: + return tiling0 + else: + if random.random() < 0.5: + return tiling0 + else: + return tiling1 + return tiling + + new_tiling = plus_minus_one(tiling) + return new_tiling + + def __str__(self): + return f"TilingImageTransform(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, embeddings_per_tile={self._embeddings_per_image}, find_closest_aspect_ratio_fn={self._find_closest_aspect_ratio_fn})" + + +class TileDegradationStrategy(ImageTilingStrategy): + """Strategy for tiling images and video frames, each with their own tiling strategy, while trying to match the + number of tokens left in the sample by reducing the number of tiles if needed. + """ + + # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 + # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 + + def __init__( + self, + image_strategy: ImageTilingStrategy, + video_frame_strategy: ImageTilingStrategy, + embeddings_per_tile: int, + max_num_tiles: int, + tile_degradation_map: dict[int, int] = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}, + ): + self._image_strategy = image_strategy + self._video_frame_strategy = video_frame_strategy + self._embeddings_per_tile = embeddings_per_tile + self._max_num_tiles = max_num_tiles + self._tile_degradation_map = tile_degradation_map + + def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]: + if isinstance(transform_media.media, ImageMedia): + return self._image_strategy.apply_params(transform_media, **kwargs) + elif isinstance(transform_media.media, VideoFrameMedia): + return self._video_frame_strategy.apply_params(transform_media, **kwargs) + else: + raise ValueError(f"Unsupported media type: {type(transform_media.media)}") + + def compute_params( + self, + media_list: list[ImageMedia | VideoFrameMedia], + num_tokens_available: int | None = None, + max_num_tiles: int | None = None, + **kwargs, + ) -> list[ImageTilingParams]: + # Use provided max_num_tiles or fall back to instance's max_num_tiles + effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles + max_num_tiles_to_use = effective_max_num_tiles + degradation_map = self._tile_degradation_map + + while True: + params = [] + img_num_tiles = [] + for media in media_list: + if isinstance(media, ImageMedia): + media_params = self._image_strategy.compute_params( + [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs + )[0] + elif isinstance(media, VideoFrameMedia): + max_num_tiles_to_use = 1 + media_params = self._video_frame_strategy.compute_params( + [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs + )[0] + else: + raise ValueError(f"Unsupported media type: {type(media)}") + img_num_tiles.append(media_params.num_tiles) + params.append(media_params) + if max_num_tiles_to_use == 1 or num_tokens_available is None: + break + if sum(img_num_tiles) * self._embeddings_per_tile > num_tokens_available: + if max_num_tiles_to_use in degradation_map: + max_num_tiles_to_use = degradation_map[max_num_tiles_to_use] + else: + # End of degradation + break + else: + break + return params + + def stack( + self, images: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: + return self._image_strategy.stack(images) + + def __str__(self): + return f"TileDegradationImageTransform(max_num_tiles={self._max_num_tiles}, image_transform={self._image_strategy}, video_frame_transform={self._video_frame_strategy})" + + +@dataclass +class DynamicResolutionParams(ImageTilingParams): + patch_size: tuple[int, int] + + +class DynamicResolutionImageTilingStrategy(ImageTilingStrategy): """Preprocess an image with dynamic resolution for vision transformers. - + This function resizes an image to optimize the number of patches while respecting constraints on minimum/maximum patches, minimum side length, and compatibility with pixel shuffle or convolution merging operations. - + The algorithm works by: 1. Computing the initial patch grid size based on the image dimensions and res_step 2. Scaling the patch grid to fit within the max_patches constraint @@ -350,159 +615,1214 @@ def dynamic_res_preprocess(image, min_patches=1, max_patches=128, res_step=16, f 4. Optionally enforcing a minimum side length constraint 5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility 6. Resizing the image to the computed target dimensions - - Args: - image (PIL.Image): Input image to preprocess. - min_patches (int, optional): Minimum number of patches required. Defaults to 1. - max_patches (int, optional): Maximum number of patches allowed. Defaults to 128. - res_step (int, optional): Resolution step size (patch dimension). Defaults to 16. - factor_max (float, optional): Maximum scaling factor to apply. Defaults to 1.0. - pixel_shuffle (bool, optional): Whether to ensure compatibility with pixel shuffle - operations by rounding to even patch dimensions. Defaults to False. - min_side (int, optional): Minimum side length in pixels. If specified, ensures - at least one side meets this constraint. Defaults to None. - conv_merging (bool, optional): Whether to ensure compatibility with convolution - merging by rounding to even patch dimensions. Defaults to False. - - Returns: - PIL.Image: Resized image with dimensions optimized for patch-based processing. - The output dimensions will be (target_patch_width * res_step, target_patch_height * res_step). - + Note: The function preserves aspect ratio as much as possible while satisfying all constraints. When constraints conflict (e.g., min_side vs max_patches), the function prioritizes staying within max_patches while maximizing the image size. - + Example: >>> from PIL import Image >>> img = Image.open("example.jpg") # 800x600 image - >>> resized_img = dynamic_res_preprocess(img, min_patches=4, max_patches=64, res_step=14) + >>> strategy = DynamicResolutionImageTilingStrategy(vision_model_type="radio", min_patches=4, max_patches=64, res_step=14, get_num_embeddings=lambda x, y: x * y * 2) + >>> params = strategy.compute_params([img]) + >>> img_tensor = strategy.apply_params(params[0]) >>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14 """ - orig_width, orig_height = image.size - closest_patch_height = round(orig_height / res_step + 0.5) - closest_patch_width = round(orig_width / res_step + 0.5) - patches = closest_patch_height * closest_patch_width + def __init__( + self, + vision_model_type: str, + min_num_patches: int, + patch_size: int, + get_num_embeddings: Callable[[int, int], int], + factor_max: float = 1.0, + pixel_shuffle: bool = False, + min_side: int | None = None, + conv_merging: bool = False, + use_thumbnail: bool = False, + thumbnail_size: int = 448, + thumbnail_area_threshold: float = 0.8, + max_num_patches: int = 0, + apply_data_augment: bool = False, + ): + """ + Args: + vision_model_type: Vision model type. + min_num_patches: Minimum number of patches required. Defaults to 1. + max_num_patches: Maximum number of patches allowed. Defaults to 0 (no maximum). + patch_size: Resolution step size (patch dimension). Defaults to 16. + get_num_embeddings: Function to get the number of embeddings from the patch size (width, height). + factor_max: Maximum scaling factor to apply. Defaults to 1.0. + pixel_shuffle: Whether to ensure compatibility with pixel shuffle operations by rounding to even patch + dimensions. Defaults to False. + min_side: Minimum side length in pixels. If specified, ensures at least one side meets this constraint. + Defaults to None. + conv_merging: Whether to ensure compatibility with convolution merging by rounding to even patch dimensions. + Defaults to False. + use_thumbnail: Whether to add a thumbnail image when processing. Defaults to False. + thumbnail_size: Size of the thumbnail image (width and height). Defaults to 448. + thumbnail_area_threshold: Maximum area percentage (0.0-1.0) of the resized image relative to thumbnail area + for which to add a thumbnail. If the resized image area is larger than this threshold of the thumbnail + area, no thumbnail will be added. Defaults to 0.8 (80%). + apply_data_augment: Whether to apply data augmentation to the image. Defaults to False. + """ + assert "radio" in vision_model_type, ( + "Dynamic resolution is only supported for radio models" + ) + self._vision_model_type = vision_model_type + self._min_num_patches = min_num_patches + self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf") + self._patch_size = patch_size + self._get_num_embeddings = get_num_embeddings + self._factor_max = factor_max + self._pixel_shuffle = pixel_shuffle + self._min_side = min_side + self._conv_merging = conv_merging + self._use_thumbnail = use_thumbnail + self._thumbnail_size = thumbnail_size + self._thumbnail_area_threshold = thumbnail_area_threshold + pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] + self._transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)), + T.Normalize(mean=pixel_mean, std=pixel_std), + ] + ) + self._apply_data_augment = apply_data_augment - factor = min(math.sqrt(max_patches / patches), factor_max) - target_patch_height = math.floor(factor * closest_patch_height) - target_patch_width = math.floor(factor * closest_patch_width) + def apply_params(self, params: DynamicResolutionParams, **kwargs) -> list[torch.Tensor]: + # resize the image + resized_img = params.media.value.resize( + ( + params.patch_size[0] * self._patch_size, + params.patch_size[1] * self._patch_size, + ) + ) + processed_images = [resized_img] + + # Add thumbnail if enabled and image area is below threshold + if self._use_thumbnail: + # Calculate areas + resized_area = resized_img.size[0] * resized_img.size[1] + thumbnail_area = self._thumbnail_size * self._thumbnail_size + area_ratio = resized_area / thumbnail_area + + # Only add thumbnail if resized image area is less than threshold % of thumbnail area + if area_ratio < self._thumbnail_area_threshold: + thumbnail_img = params.media.value.resize((self._thumbnail_size, self._thumbnail_size)) + processed_images.append(thumbnail_img) + + return [self._transform(img) for img in processed_images] - if target_patch_height * target_patch_width < min_patches: - up_factor = math.sqrt(min_patches / (target_patch_height * target_patch_width)) - target_patch_height = math.ceil(up_factor * target_patch_height) - target_patch_width = math.ceil(up_factor * target_patch_width) - - if min_side is not None and min(target_patch_width, target_patch_height) * res_step < min_side: - if target_patch_width <= target_patch_height: - up_factor = min_side / (target_patch_width * res_step) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) - - if new_patch_height * new_patch_width > max_patches: - # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches - if max(max_patches // new_patch_width, 1) * res_step < min_side: - up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width)) - target_patch_height = math.floor(up_factor * target_patch_height) - target_patch_width = math.floor(up_factor * target_patch_width) - target_patch_width = new_patch_width - target_patch_height = max(max_patches // new_patch_width, 1) - else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width + def process_media( + self, + media: ImageMedia | VideoFrameMedia, + num_tokens_available: int, + data_augment: bool = False, + tiling_augment_prob: float = 0.4, + ) -> DynamicResolutionParams: + """Process a single media item and return its parameters. + + Args: + media: The media item to process + num_tokens_available: Number of tokens available for this media + data_augment: Whether to apply data augmentation to the image. Defaults to False. + Returns: + DynamicResolutionParams for the media + """ + current_num_tokens_available = num_tokens_available + if isinstance(media, ImageMedia): + orig_width, orig_height = media.width, media.height + elif isinstance(media, VideoFrameMedia): + orig_width, orig_height = media.video_width, media.video_height + # current_num_tokens_available = 1024 #TEMP: hack for video else: - up_factor = min_side / (target_patch_height * res_step) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) + raise ValueError(f"Unsupported media type: {type(media)}") - if new_patch_height * new_patch_width > max_patches: - # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches - if max(max_patches // new_patch_height, 1) * res_step < min_side: - up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width)) - target_patch_height = math.floor(up_factor * target_patch_height) - target_patch_width = math.floor(up_factor * target_patch_width) + closest_patch_height = round(orig_height / self._patch_size + 0.5) + closest_patch_width = round(orig_width / self._patch_size + 0.5) + patches = closest_patch_height * closest_patch_width + + factor = min(math.sqrt(current_num_tokens_available / patches), self._factor_max) + target_patch_height = math.floor(factor * closest_patch_height) + target_patch_width = math.floor(factor * closest_patch_width) + + # We only consider self._min_num_patches if it is greater than current_num_tokens_available. + if current_num_tokens_available > self._min_num_patches and target_patch_height * target_patch_width < self._min_num_patches: + up_factor = math.sqrt( + self._min_num_patches / (target_patch_height * target_patch_width) + ) + target_patch_height = math.ceil(up_factor * target_patch_height) + target_patch_width = math.ceil(up_factor * target_patch_width) + + if ( + self._min_side is not None + and min(target_patch_width, target_patch_height) * self._patch_size + < self._min_side + ): + if target_patch_width <= target_patch_height: + up_factor = self._min_side / (target_patch_width * self._patch_size) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) + + if new_patch_height * new_patch_width > current_num_tokens_available: + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + if ( + max(current_num_tokens_available // new_patch_width, 1) + * self._patch_size + < self._min_side + ): + up_factor = math.sqrt( + current_num_tokens_available + / (target_patch_height * target_patch_width) + ) + target_patch_height = math.floor( + up_factor * target_patch_height + ) + target_patch_width = math.floor( + up_factor * target_patch_width + ) + target_patch_width = new_patch_width + target_patch_height = max( + current_num_tokens_available // new_patch_width, 1 + ) else: target_patch_height = new_patch_height - target_patch_width = max(max_patches // new_patch_height, 1) + target_patch_width = new_patch_width else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width + up_factor = self._min_side / ( + target_patch_height * self._patch_size + ) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) - # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) - # or by 4 when BOTH are enabled (two successive 2x reductions) - if pixel_shuffle or conv_merging: - required_divisor = 4 if (pixel_shuffle and conv_merging) else 2 + if new_patch_height * new_patch_width > current_num_tokens_available: + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + if ( + max(current_num_tokens_available // new_patch_height, 1) + * self._patch_size + < self._min_side + ): + up_factor = math.sqrt( + current_num_tokens_available + / (target_patch_height * target_patch_width) + ) + target_patch_height = math.floor( + up_factor * target_patch_height + ) + target_patch_width = math.floor( + up_factor * target_patch_width + ) + else: + target_patch_height = new_patch_height + target_patch_width = max( + current_num_tokens_available // new_patch_height, 1 + ) + else: + target_patch_height = new_patch_height + target_patch_width = new_patch_width - rem_h = target_patch_height % required_divisor - if rem_h != 0: - inc_h = required_divisor - rem_h - if (target_patch_height + inc_h) * target_patch_width <= max_patches: - target_patch_height += inc_h - else: - target_patch_height = max(1, target_patch_height - rem_h) + # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) + # or by 4 when BOTH are enabled (two successive 2x reductions) + if self._pixel_shuffle or self._conv_merging: + required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2 - rem_w = target_patch_width % required_divisor - if rem_w != 0: - inc_w = required_divisor - rem_w - if target_patch_height * (target_patch_width + inc_w) <= max_patches: - target_patch_width += inc_w - else: - target_patch_width = max(1, target_patch_width - rem_w) - assert target_patch_height * target_patch_width <= max_patches + rem_h = target_patch_height % required_divisor + if rem_h != 0: + inc_h = required_divisor - rem_h + if (target_patch_height + inc_h) * target_patch_width <= current_num_tokens_available: + target_patch_height += inc_h + else: + target_patch_height = max(required_divisor, target_patch_height - rem_h) - #TEMP: hacky way to process video same as in training - if is_video: - # max_patches = 1024 - # min_patches = 512 - target_patch_width = 32 - target_patch_height = 32 + rem_w = target_patch_width % required_divisor + if rem_w != 0: + inc_w = required_divisor - rem_w + if target_patch_height * (target_patch_width + inc_w) <= current_num_tokens_available: + target_patch_width += inc_w + else: + target_patch_width = max(required_divisor, target_patch_width - rem_w) + + if data_augment and self._apply_data_augment and random.random() < tiling_augment_prob: + target_patch_width, target_patch_height = self.augment_resolution(target_patch_width, target_patch_height, current_num_tokens_available) + + #TEMP: hack for video + if isinstance(media, VideoFrameMedia): + target_patch_width = 32 + target_patch_height = 32 + + # Calculate embeddings for the main dynamic resolution image + num_embeddings = self._get_num_embeddings( + target_patch_width * self._patch_size, + target_patch_height * self._patch_size, + ) + + token_count = target_patch_width * target_patch_height + # Add thumbnail embeddings if enabled and image area is below threshold + num_tiles = 1 # Base dynamic resolution image + if self._use_thumbnail: + # Calculate areas + resized_area = (target_patch_width * self._patch_size) * (target_patch_height * self._patch_size) + thumbnail_area = self._thumbnail_size * self._thumbnail_size + area_ratio = resized_area / thumbnail_area + + # Only add thumbnail if resized image area is less than threshold % of thumbnail area + if area_ratio < self._thumbnail_area_threshold: + num_tiles += 1 # Add 1 for thumbnail + # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) + num_embeddings += self._get_num_embeddings(self._thumbnail_size, self._thumbnail_size) + token_count += self._thumbnail_size // self._patch_size * self._thumbnail_size // self._patch_size - # resize the image - resized_img = image.resize((target_patch_width * res_step, target_patch_height * res_step)) + return DynamicResolutionParams( + media=media, + num_tiles=num_tiles, + num_embeddings=num_embeddings, + patch_size=(target_patch_width, target_patch_height), + ), token_count + + def augment_resolution(self, target_patch_width: int, target_patch_height: int, current_num_tokens_available: int) -> tuple[int, int]: - return resized_img + min_num_patch_one_side = 32 - - -# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 -# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 -def _build_transform(input_size, vision_model_type): - if vision_model_type in ("siglip", "internvit", "internvit300M", "radio", "radio-g", "cradio-g"): - pixel_mean, pixel_std = pixel_statistics[vision_model_type] - - transform = T.Compose([ - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), - T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std) - ]) - elif vision_model_type == "clip": - pixel_mean, pixel_std = pixel_statistics[vision_model_type] - - transform = Compose([ - T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std), - ]) - elif vision_model_type.startswith("hf://"): - from megatron.core.models.huggingface.module import get_hf_model_type - - model_type = get_hf_model_type(vision_model_type) - if "siglip" in model_type: - from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor - - processor = SiglipImageProcessor(size={"height": input_size, "width": input_size}) - - def transform(x): - x = x.convert("RGB") if x.mode != "RGB" else x - x = processor(x, return_tensors="pt") - return x["pixel_values"][0] + if random.random() < 0.5: + # Minus one + if target_patch_width <= min_num_patch_one_side and target_patch_height <= min_num_patch_one_side: + return target_patch_width, target_patch_height + elif target_patch_width <= min_num_patch_one_side: + return target_patch_width, target_patch_height - min_num_patch_one_side + elif target_patch_height <= min_num_patch_one_side: + return target_patch_width - min_num_patch_one_side, target_patch_height + else: + if random.random() < 0.5: + return target_patch_width - min_num_patch_one_side, target_patch_height + else: + return target_patch_width, target_patch_height - min_num_patch_one_side else: - raise NotImplementedError(f"image processing not defined for huggingface model {vision_model_type}") - else: - raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}") + # Plus one + if target_patch_width * target_patch_height < current_num_tokens_available: + if random.random() < 0.5: + return target_patch_width + min_num_patch_one_side, target_patch_height + else: + return target_patch_width, target_patch_height + min_num_patch_one_side + return target_patch_width, target_patch_height - return transform + def compute_params( + self, + media_list: list[ImageMedia | VideoFrameMedia], + num_tokens_available: int | None = None, + max_num_tiles: int | None = None, + data_augment: bool = False, + **kwargs, + ) -> list[ImageTilingParams]: + """Compute parameters for all media with iterative token budgeting. + + Args: + media_list: List of media items to process + num_tokens_available: Total number of tokens available across all media + max_num_tiles: Maximum number of tiles (unused in this implementation) + data_augment: Whether to apply data augmentation to the image. Defaults to False. + Returns: + List of ImageTilingParams for each media item + """ + num_tokens_available = num_tokens_available * (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1) + # When the number of available token is too small, allow self._min_num_patches per media and + # let the sample be truncated. + num_tokens_available = max(num_tokens_available, self._min_num_patches * len(media_list)) + + # Clip the number of tokens available per media to be between min and max patches. + num_tokens_available_per_media = [ + max(min(num_tokens_available, self._max_num_patches), self._min_num_patches) + for _ in range(len(media_list))] + + # In theory this could be a while True loop, but in case the process_media method slightly + # changes, I want to make sure we don't get stuck in an infinite loop. + for _ in range(10): + # Step 1: Process each media with current token budget + params = [] + token_counts = [] + + for media, tokens_for_media in zip(media_list, num_tokens_available_per_media): + param, token_count = self.process_media(media, tokens_for_media, data_augment=data_augment) + params.append(param) + token_counts.append(token_count) + + # Step 2: Check if total tokens is within budget + total_tokens = sum(token_counts) + + if total_tokens <= num_tokens_available: + # We're within budget, return the params + return params + + # Step 3: We're over budget, need to scale down + # Calculate scaling factor to get under budget + scaling_factor = num_tokens_available / total_tokens + + # Recalculate token budgets for each media based on scaling + # Each media gets a proportional share of the total budget + scaled_down_num_tokens_available_per_media = [ + max(self._min_num_patches, int(token_count * scaling_factor)) + for token_count in token_counts + ] + scaled_down = any([ + scaled_down_num_tokens_available_per_media[i] < num_tokens_available_per_media[i] + for i in range(len(num_tokens_available_per_media))]) + # If there was not scaling down, we're stuck just use min_num_patches per media, else + # try with the scaled down num_tokens_available_per_media. + if not scaled_down: + num_tokens_available_per_media = [self._min_num_patches] * len(media_list) + else: + num_tokens_available_per_media = scaled_down_num_tokens_available_per_media + return params + + def stack( + self, images: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: + imgs_sizes = torch.tensor( + [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 + ) + + def rearrange_img(x): + py = x.shape[-2] // self._patch_size + px = x.shape[-1] // self._patch_size + x = einops.rearrange( + x, + "c (py yy) (px xx) -> (py px) (c yy xx)", + py=py, + yy=self._patch_size, + px=px, + xx=self._patch_size, + ) + return x + + if len(images) > 0: + imgs = [rearrange_img(img) for img in images] + + current_length = 0 + max_length = 0 + vision_cu_lengths = [0] + for img in imgs: + if max_length < img.shape[0]: + max_length = img.shape[0] + current_length += img.shape[0] + vision_cu_lengths.append(current_length) + + vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) + vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) + + return ( + torch.cat(imgs, dim=0).unsqueeze(0), + imgs_sizes, + vision_cu_lengths, + vision_max_lengths, + ) + else: + return ( + torch.tensor([[0]], dtype=torch.float32), + torch.tensor([[0,0]], dtype=torch.int32), + None, + None, + ) + + def __str__(self): + return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})" + + +@dataclass +class MatchTilingDynamicResolutionParams(ImageTilingParams): + tiling: tuple[int, int] + + +class MatchTilingDynamicResolutionStrategy(ImageTilingStrategy): + """ + Strategy that uses tiling logic to determine optimal image dimensions but processes + the image as a single dynamic resolution image instead of splitting into tiles. + + This combines the aspect ratio optimization from ImageTilingStrategyV1 with the + dynamic resolution processing from DynamicResolutionImageTilingStrategy. + + Also includes tile degradation logic similar to TileDegradationStrategy. + """ + + def __init__( + self, + vision_model_type: str, + tile_size: int, + use_thumbnail: bool, + min_num_tiles: int, + max_num_tiles: int, + embeddings_per_tile: int, + patch_size: int, + get_num_embeddings: Callable[[int, int], int], + find_closest_aspect_ratio_fn=find_closest_aspect_ratio, + pixel_shuffle: bool = False, + conv_merging: bool = False, + tile_degradation_map: dict[int, int] = None, + video_frame_strategy: ImageTilingStrategy = None, + enable_tile_degradation: bool = True, + ): + """ + Args: + vision_model_type: Vision model type (should support dynamic resolution) + tile_size: Size of each tile for tiling calculation + use_thumbnail: Whether tiling logic should include thumbnail + min_num_tiles: Minimum number of tiles for tiling calculation + max_num_tiles: Maximum number of tiles for tiling calculation + embeddings_per_tile: Embeddings per tile for tiling calculation + patch_size: Patch size for dynamic resolution processing + get_num_embeddings: Function to get number of embeddings from dimensions + find_closest_aspect_ratio_fn: Function to find closest aspect ratio + pixel_shuffle: Whether to ensure compatibility with pixel shuffle + conv_merging: Whether to ensure compatibility with convolution merging + tile_degradation_map: Map for degrading tiles when tokens are insufficient + video_frame_strategy: Strategy for processing video frames + enable_tile_degradation: Whether to enable tile degradation (default: True) + """ + assert "radio" in vision_model_type, ( + "MatchTilingDynamicResolution is only supported for radio models" + ) + + self._vision_model_type = vision_model_type + self._tile_size = tile_size + self._use_thumbnail = use_thumbnail + self._min_num_tiles = min_num_tiles + self._max_num_tiles = max_num_tiles + self._embeddings_per_tile = embeddings_per_tile + self._patch_size = patch_size + self._get_num_embeddings = get_num_embeddings + self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn + self._pixel_shuffle = pixel_shuffle + self._conv_merging = conv_merging + self._enable_tile_degradation = enable_tile_degradation + + # Tile degradation logic (similar to TileDegradationStrategy) + if tile_degradation_map is None: + self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1} + else: + self._tile_degradation_map = tile_degradation_map + + # Video frame strategy (similar to TileDegradationStrategy) + if video_frame_strategy is None: + self._video_frame_strategy = NoTilingStrategy( + vision_model_type=vision_model_type, + target_width=tile_size, + target_height=tile_size, + embeddings_per_image=embeddings_per_tile, + ) + else: + self._video_frame_strategy = video_frame_strategy + + # Calculate all possible aspect ratios for each max_num_tiles (borrowed from ImageTilingStrategyV1) + self.target_ratios = { + max_num_tiles: sorted( + set( + (x, y) + for n in range(self._min_num_tiles, max_num_tiles + 1) + for x in range(1, n + 1) + for y in range(1, n + 1) + if x * y <= max_num_tiles and x * y >= self._min_num_tiles + ), + key=lambda x: x[0] * x[1], + ) + for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1) + } + + # Set up transform for dynamic resolution processing + pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] + self._transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), + T.Normalize(mean=pixel_mean, std=pixel_std), + ] + ) + + def apply_params(self, params: MatchTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]: + # Handle video frames using the video frame strategy + if isinstance(params.media, VideoFrameMedia): + return self._video_frame_strategy.apply_params(params, **kwargs) + + # Handle images with dynamic resolution processing + image = params.media.value + # Calculate the target width and height (same logic as ImageTilingStrategyV1) + target_width = self._tile_size * params.tiling[0] + target_height = self._tile_size * params.tiling[1] + + # Resize the image to the target dimensions (same as ImageTilingStrategyV1) + resized_img = image.resize((target_width, target_height)) + + # Process as single dynamic resolution image + processed_images = [resized_img] + + # Add thumbnail if use_thumbnail=True and there's more than 1 tile (same as ImageTilingStrategyV1) + blocks = params.tiling[0] * params.tiling[1] + if self._use_thumbnail and blocks != 1: + thumbnail_img = image.resize((self._tile_size, self._tile_size)) + processed_images.append(thumbnail_img) + + return [self._transform(img) for img in processed_images] + + def compute_params( + self, + media_list: list[ImageMedia | VideoFrameMedia], + num_tokens_available: int | None = None, + max_num_tiles: int | None = None, + **kwargs, + ) -> list[MatchTilingDynamicResolutionParams]: + # Implement tile degradation logic similar to TileDegradationStrategy + # Use provided max_num_tiles or fall back to instance's max_num_tiles + # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value + effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles + effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles) + max_num_tiles_to_use = effective_max_num_tiles + degradation_map = self._tile_degradation_map + + while True: + params = [] + total_embeddings_needed = 0 + + for media in media_list: + if isinstance(media, ImageMedia): + # Use tiling logic for images + img_size = (media.width, media.height) + aspect_ratio = img_size[0] / img_size[1] + + # Find the closest aspect ratio to the target + target_ratios = self.target_ratios[max_num_tiles_to_use] + tiling = self._find_closest_aspect_ratio_fn( + aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size + ) + + # Calculate target dimensions for dynamic resolution processing + target_width = self._tile_size * tiling[0] + target_height = self._tile_size * tiling[1] + num_embeddings = self._get_num_embeddings(target_width, target_height) + + # Account for thumbnail (same logic as ImageTilingStrategyV1) + num_tiles = 1 # Base dynamic resolution image + blocks = tiling[0] * tiling[1] + if self._use_thumbnail and blocks != 1: + num_tiles += 1 # Add 1 for thumbnail + # Add embeddings for thumbnail (tile_size x tile_size) + num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size) + + media_params = MatchTilingDynamicResolutionParams( + media=media, + num_tiles=num_tiles, + num_embeddings=num_embeddings, + tiling=tiling, + ) + elif isinstance(media, VideoFrameMedia): + # Use video frame strategy for video frames (always 1 tile) + video_params = self._video_frame_strategy.compute_params( + [media], 1 * self._embeddings_per_tile + )[0] + media_params = MatchTilingDynamicResolutionParams( + media=media, + num_tiles=video_params.num_tiles, + num_embeddings=video_params.num_embeddings, + tiling=(1, 1), # Video frames always use 1x1 tiling + ) + else: + raise ValueError(f"Unsupported media type: {type(media)}") + + params.append(media_params) + total_embeddings_needed += media_params.num_embeddings + + # Check if we need to degrade (only if degradation is enabled) + if not self._enable_tile_degradation: + break + if max_num_tiles_to_use == 1 or num_tokens_available is None: + break + if total_embeddings_needed > num_tokens_available: + if max_num_tiles_to_use in degradation_map: + max_num_tiles_to_use = degradation_map[max_num_tiles_to_use] + # Recalculate target ratios for the new max_num_tiles_to_use + if max_num_tiles_to_use not in self.target_ratios: + self.target_ratios[max_num_tiles_to_use] = sorted( + set( + (x, y) + for n in range(self._min_num_tiles, max_num_tiles_to_use + 1) + for x in range(1, n + 1) + for y in range(1, n + 1) + if x * y <= max_num_tiles_to_use and x * y >= self._min_num_tiles + ), + key=lambda x: x[0] * x[1], + ) + else: + # End of degradation + break + else: + break + + return params + + def stack( + self, images: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]: + """Stack images using dynamic resolution approach with sequence packing""" + imgs_sizes = torch.tensor( + [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 + ) + + def rearrange_img(x): + py = x.shape[-2] // self._patch_size + px = x.shape[-1] // self._patch_size + x = einops.rearrange( + x, + "c (py yy) (px xx) -> (py px) (c yy xx)", + py=py, + yy=self._patch_size, + px=px, + xx=self._patch_size, + ) + return x + + if len(images) > 0: + imgs = [rearrange_img(img) for img in images] + + current_length = 0 + max_length = 0 + vision_cu_lengths = [0] + for img in imgs: + if max_length < img.shape[0]: + max_length = img.shape[0] + current_length += img.shape[0] + vision_cu_lengths.append(current_length) + + vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) + vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) + + return ( + torch.cat(imgs, dim=0).unsqueeze(0), + imgs_sizes, + vision_cu_lengths, + vision_max_lengths, + ) + else: + return ( + torch.tensor([[0]], dtype=torch.float32), + torch.tensor([[0,0]], dtype=torch.int32), + None, + None, + ) + + def __str__(self): + return f"MatchTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})" + + +@dataclass +class MaskedTilingDynamicResolutionParams(ImageTilingParams): + tiling: tuple[int, int] + + +class MaskedTilingDynamicResolutionStrategy(ImageTilingStrategy): + """ + Like MatchTilingDynamicResolutionStrategy, but ensures tiles are isolated in the + vision encoder by emitting per-tile packed samples (block-diagonal attention across tiles). + """ + + def __init__( + self, + vision_model_type: str, + tile_size: int, + use_thumbnail: bool, + min_num_tiles: int, + max_num_tiles: int, + embeddings_per_tile: int, + patch_size: int, + get_num_embeddings: Callable[[int, int], int], + find_closest_aspect_ratio_fn=find_closest_aspect_ratio, + pixel_shuffle: bool = False, + conv_merging: bool = False, + tile_degradation_map: dict[int, int] = None, + video_frame_strategy: ImageTilingStrategy = None, + enable_tile_degradation: bool = True, + ): + assert "radio" in vision_model_type, ( + "MaskedTilingDynamicResolution is only supported for radio models" + ) + + self._vision_model_type = vision_model_type + self._tile_size = tile_size + self._use_thumbnail = use_thumbnail + self._min_num_tiles = min_num_tiles + self._max_num_tiles = max_num_tiles + self._embeddings_per_tile = embeddings_per_tile + self._patch_size = patch_size + self._get_num_embeddings = get_num_embeddings + self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn + self._pixel_shuffle = pixel_shuffle + self._conv_merging = conv_merging + self._enable_tile_degradation = enable_tile_degradation + + if tile_degradation_map is None: + self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1} + else: + self._tile_degradation_map = tile_degradation_map + + if video_frame_strategy is None: + self._video_frame_strategy = NoTilingStrategy( + vision_model_type=vision_model_type, + target_width=tile_size, + target_height=tile_size, + embeddings_per_image=embeddings_per_tile, + ) + else: + self._video_frame_strategy = video_frame_strategy + + self.target_ratios = { + max_num_tiles: sorted( + set( + (x, y) + for n in range(self._min_num_tiles, max_num_tiles + 1) + for x in range(1, n + 1) + for y in range(1, n + 1) + if x * y <= max_num_tiles and x * y >= self._min_num_tiles + ), + key=lambda x: x[0] * x[1], + ) + for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1) + } + + pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] + self._transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), + T.Normalize(mean=pixel_mean, std=pixel_std), + ] + ) + + def apply_params(self, params: MaskedTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]: + # Handle video frames using the video frame strategy + if isinstance(params.media, VideoFrameMedia): + return self._video_frame_strategy.apply_params(params, **kwargs) + + image = params.media.value + nx, ny = params.tiling + target_width = self._tile_size * nx + target_height = self._tile_size * ny + + resized_img = image.resize((target_width, target_height)) + + processed_images = [] + # Emit per-tile images (each becomes an isolated packed sample later) + for j in range(ny): + for i in range(nx): + box = ( + i * self._tile_size, + j * self._tile_size, + (i + 1) * self._tile_size, + (j + 1) * self._tile_size, + ) + tile_img = resized_img.crop(box) + processed_images.append(tile_img) + + if self._use_thumbnail and (nx * ny) != 1: + thumbnail_img = image.resize((self._tile_size, self._tile_size)) + processed_images.append(thumbnail_img) + + return [self._transform(img) for img in processed_images] + + def compute_params( + self, + media_list: list[ImageMedia | VideoFrameMedia], + num_tokens_available: int | None = None, + max_num_tiles: int | None = None, + data_augment: bool = False, + tiling_augment_prob: float = 0.4, + **kwargs, + ) -> list[MaskedTilingDynamicResolutionParams]: + effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles + effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles) + max_num_tiles_to_use = effective_max_num_tiles + degradation_map = self._tile_degradation_map + + while True: + params = [] + total_embeddings_needed = 0 + + for media in media_list: + if isinstance(media, ImageMedia): + img_size = (media.width, media.height) + aspect_ratio = img_size[0] / img_size[1] + + target_ratios = self.target_ratios[max_num_tiles_to_use] + tiling = self._find_closest_aspect_ratio_fn( + aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size + ) + + # Apply tiling augmentation if enabled + if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob: + tiling = self.augment_tiling(tiling) + + blocks = tiling[0] * tiling[1] + # Each tile is tile_size x tile_size + per_tile_emb = self._get_num_embeddings(self._tile_size, self._tile_size) + num_embeddings = blocks * per_tile_emb + + num_tiles = blocks + if self._use_thumbnail and blocks != 1: + num_tiles += 1 + num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size) + + media_params = MaskedTilingDynamicResolutionParams( + media=media, + num_tiles=num_tiles, + num_embeddings=num_embeddings, + tiling=tiling, + ) + elif isinstance(media, VideoFrameMedia): + video_params = self._video_frame_strategy.compute_params( + [media], 1 * self._embeddings_per_tile + )[0] + media_params = MaskedTilingDynamicResolutionParams( + media=media, + num_tiles=video_params.num_tiles, + num_embeddings=video_params.num_embeddings, + tiling=(1, 1), + ) + else: + raise ValueError(f"Unsupported media type: {type(media)}") + + params.append(media_params) + total_embeddings_needed += media_params.num_embeddings + + if not self._enable_tile_degradation: + break + if max_num_tiles_to_use == 1 or num_tokens_available is None: + break + if total_embeddings_needed > num_tokens_available: + if max_num_tiles_to_use in degradation_map: + max_num_tiles_to_use = degradation_map[max_num_tiles_to_use] + if max_num_tiles_to_use not in self.target_ratios: + self.target_ratios[max_num_tiles_to_use] = sorted( + set( + (x, y) + for n in range(self._min_num_tiles, max_num_tiles_to_use + 1) + for x in range(1, n + 1) + for y in range(1, n + 1) + if x * y <= max_num_tiles_to_use and x * y >= self._min_num_tiles + ), + key=lambda x: x[0] * x[1], + ) + else: + break + else: + break + + return params + + def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]: + def num_tiles(tiling: tuple[int, int]) -> int: + return tiling[0] * tiling[1] + + def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]: + if random.random() < minus_prob: + # Minus one + if tiling[0] == 1 and tiling[1] == 1: + return tiling + elif tiling[0] == 1: + return (tiling[0], tiling[1] - 1) + elif tiling[1] == 1: + return (tiling[0] - 1, tiling[1]) + else: + if random.random() < 0.5: + return (tiling[0] - 1, tiling[1]) + else: + return (tiling[0], tiling[1] - 1) + else: + # Plus one + if num_tiles(tiling) < self._max_num_tiles: + tiling0 = (tiling[0] + 1, tiling[1]) + tiling1 = (tiling[0], tiling[1] + 1) + if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles: + return tiling + elif num_tiles(tiling0) > self._max_num_tiles: + return tiling1 + elif num_tiles(tiling1) > self._max_num_tiles: + return tiling0 + else: + if random.random() < 0.5: + return tiling0 + else: + return tiling1 + return tiling + + new_tiling = plus_minus_one(tiling) + return new_tiling + + def stack( + self, images: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]: + # Identical to dynamic resolution packing; each tile is already an independent image sample + imgs_sizes = torch.tensor( + [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 + ) + + def rearrange_img(x): + py = x.shape[-2] // self._patch_size + px = x.shape[-1] // self._patch_size + x = einops.rearrange( + x, + "c (py yy) (px xx) -> (py px) (c yy xx)", + py=py, + yy=self._patch_size, + px=px, + xx=self._patch_size, + ) + return x + + if len(images) > 0: + imgs = [rearrange_img(img) for img in images] + + current_length = 0 + max_length = 0 + vision_cu_lengths = [0] + for img in imgs: + if max_length < img.shape[0]: + max_length = img.shape[0] + current_length += img.shape[0] + vision_cu_lengths.append(current_length) + + vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) + vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) + + return ( + torch.cat(imgs, dim=0).unsqueeze(0), + imgs_sizes, + vision_cu_lengths, + vision_max_lengths, + ) + else: + return ( + torch.tensor([[0]], dtype=torch.float32), + torch.tensor([[0,0]], dtype=torch.int32), + None, + None, + ) + + def __str__(self): + return f"MaskedTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})" + +def create_image_tiling_strategy(args): + """ + Create an image tiling strategy based on the provided arguments. + + This function encapsulates the logic for creating the appropriate image tiling strategy + based on the training/evaluation configuration. It can be used by both training (task_encoder) + and evaluation code outside of data_loading/. + + Args: + args: Arguments object with the following relevant attributes: + - img_h, img_w: Image height and width + - patch_dim: Patch dimension + - vision_model_type: Vision model type (e.g., 'radio', 'clip', 'siglip') + - disable_vision_class_token: Whether to disable vision class token + - pixel_shuffle: Whether to use pixel shuffle + - use_tile_tags: Whether to use tile tags + - max_num_tiles: Maximum number of tiles + - tokenizer_prompt_format: Tokenizer prompt format + - image_break_token: Image break token (optional) + - conv_merging: Whether to use convolution merging + - dynamic_resolution: Whether to use dynamic resolution + - match_tiling_dynamic_resolution: Whether to match tiling with dynamic resolution + - use_area_weighted_aspect_ratio: Whether to use area-weighted aspect ratio + - use_thumbnail: Whether to use thumbnail + - dynamic_resolution_min_patches: Minimum number of patches for dynamic resolution + - dynamic_resolution_min_side: Minimum side length for dynamic resolution (optional) + - thumbnail_area_threshold: Thumbnail area threshold (optional) + - use_tiling: Whether to use tiling + + Returns: + ImageTilingStrategy: The created image tiling strategy + """ + from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings + + assert args.img_h == args.img_w, "img_h and img_w must be the same" + + match_tiling_dynamic_resolution = args.match_tiling_dynamic_resolution + masked_tiling_dynamic_resolution = getattr(args, "masked_tiling_dynamic_resolution", False) + dynamic_resolution = args.dynamic_resolution + use_tiling = args.use_tiling + use_area_weighted_aspect_ratio = args.use_area_weighted_aspect_ratio + + if match_tiling_dynamic_resolution: + assert dynamic_resolution, "must enable --dynamic-resolution if using --match-tiling-dynamic-resolution" + assert not use_tiling, "cannot use --use-tiling and --match-tiling-dynamic-resolution together" + if masked_tiling_dynamic_resolution: + assert dynamic_resolution, "must enable --dynamic-resolution if using --masked-tiling-dynamic-resolution" + assert not use_tiling, "cannot use --use-tiling and --masked-tiling-dynamic-resolution together" + assert not match_tiling_dynamic_resolution, "cannot combine --masked-tiling-dynamic-resolution with --match-tiling-dynamic-resolution" + + if dynamic_resolution: + if masked_tiling_dynamic_resolution: + num_image_embeddings_per_tile = get_num_image_embeddings( + img_h=args.img_h, + img_w=args.img_w, + patch_dim=args.patch_dim, + vision_model_type=args.vision_model_type, + disable_vision_class_token=args.disable_vision_class_token, + class_token_len=1, + pixel_shuffle=args.pixel_shuffle, + use_tile_tags=args.use_tile_tags, + max_num_tiles=args.max_num_tiles, + tokenizer_type=args.tokenizer_prompt_format, + use_image_break_token=args.image_break_token is not None, + conv_merging=args.conv_merging, + ) + image_tiling_strategy = MaskedTilingDynamicResolutionStrategy( + vision_model_type=args.vision_model_type, + tile_size=args.img_h, + use_thumbnail=args.use_thumbnail, + min_num_tiles=1, + max_num_tiles=args.max_num_tiles, + embeddings_per_tile=num_image_embeddings_per_tile, + patch_size=args.patch_dim, + get_num_embeddings=lambda width, height: get_num_image_embeddings( + img_h=height, + img_w=width, + patch_dim=args.patch_dim, + vision_model_type=args.vision_model_type, + disable_vision_class_token=args.disable_vision_class_token, + class_token_len=1, + pixel_shuffle=args.pixel_shuffle, + use_tile_tags=args.use_tile_tags, + max_num_tiles=args.max_num_tiles, + tokenizer_type=args.tokenizer_prompt_format, + use_image_break_token=args.image_break_token is not None, + conv_merging=args.conv_merging, + ), + find_closest_aspect_ratio_fn=( + find_closest_area_weighted_aspect_ratio + if use_area_weighted_aspect_ratio + else find_closest_aspect_ratio + ), + pixel_shuffle=args.pixel_shuffle, + conv_merging=args.conv_merging, + ) + elif match_tiling_dynamic_resolution: + num_image_embeddings_per_tile = get_num_image_embeddings( + img_h=args.img_h, + img_w=args.img_w, + patch_dim=args.patch_dim, + vision_model_type=args.vision_model_type, + disable_vision_class_token=args.disable_vision_class_token, + class_token_len=1, + pixel_shuffle=args.pixel_shuffle, + use_tile_tags=args.use_tile_tags, + max_num_tiles=args.max_num_tiles, + tokenizer_type=args.tokenizer_prompt_format, + use_image_break_token=args.image_break_token is not None, + conv_merging=args.conv_merging, + ) + image_tiling_strategy = MatchTilingDynamicResolutionStrategy( + vision_model_type=args.vision_model_type, + tile_size=args.img_h, + use_thumbnail=args.use_thumbnail, + min_num_tiles=1, + max_num_tiles=args.max_num_tiles, + embeddings_per_tile=num_image_embeddings_per_tile, + patch_size=args.patch_dim, + get_num_embeddings=lambda width, height: get_num_image_embeddings( + img_h=height, + img_w=width, + patch_dim=args.patch_dim, + vision_model_type=args.vision_model_type, + disable_vision_class_token=args.disable_vision_class_token, + class_token_len=1, + pixel_shuffle=args.pixel_shuffle, + use_tile_tags=args.use_tile_tags, + max_num_tiles=args.max_num_tiles, + tokenizer_type=args.tokenizer_prompt_format, + use_image_break_token=args.image_break_token is not None, + conv_merging=args.conv_merging, + ), + find_closest_aspect_ratio_fn=( + find_closest_area_weighted_aspect_ratio + if use_area_weighted_aspect_ratio + else find_closest_aspect_ratio + ), + pixel_shuffle=args.pixel_shuffle, + conv_merging=args.conv_merging, + ) + else: + image_tiling_strategy = DynamicResolutionImageTilingStrategy( + vision_model_type=args.vision_model_type, + min_num_patches=args.dynamic_resolution_min_patches, + patch_size=args.patch_dim, + get_num_embeddings=lambda width, height: get_num_image_embeddings( + img_h=height, + img_w=width, + patch_dim=args.patch_dim, + vision_model_type=args.vision_model_type, + disable_vision_class_token=args.disable_vision_class_token, + class_token_len=1, + pixel_shuffle=args.pixel_shuffle, + use_tile_tags=args.use_tile_tags, + max_num_tiles=args.max_num_tiles, + tokenizer_type=args.tokenizer_prompt_format, + use_image_break_token=args.image_break_token is not None, + conv_merging=args.conv_merging, + ), + pixel_shuffle=args.pixel_shuffle, + min_side=args.dynamic_resolution_min_side, + conv_merging=args.conv_merging, + use_thumbnail=args.use_thumbnail, + thumbnail_size=args.img_h, + thumbnail_area_threshold=args.thumbnail_area_threshold, + max_num_patches=args.dynamic_resolution_max_patches, + apply_data_augment=args.apply_data_augment, + ) + else: + num_image_embeddings_per_tile = get_num_image_embeddings( + img_h=args.img_h, + img_w=args.img_w, + patch_dim=args.patch_dim, + vision_model_type=args.vision_model_type, + disable_vision_class_token=args.disable_vision_class_token, + class_token_len=1, + pixel_shuffle=args.pixel_shuffle, + use_tile_tags=args.use_tile_tags, + max_num_tiles=args.max_num_tiles, + tokenizer_type=args.tokenizer_prompt_format, + use_image_break_token=args.image_break_token is not None, + conv_merging=args.conv_merging, + ) + if use_tiling: + image_strategy = ImageTilingStrategyV1( + vision_model_type=args.vision_model_type, + tile_size=args.img_h, + use_thumbnail=args.use_thumbnail, + min_num_tiles=1, + max_num_tiles=args.max_num_tiles, + embeddings_per_tile=num_image_embeddings_per_tile, + find_closest_aspect_ratio_fn=( + find_closest_area_weighted_aspect_ratio + if use_area_weighted_aspect_ratio + else find_closest_aspect_ratio + ), + ) + else: + image_strategy = NoTilingStrategy( + vision_model_type=args.vision_model_type, + embeddings_per_image=num_image_embeddings_per_tile, + target_width=args.img_w, + target_height=args.img_h, + ) + image_tiling_strategy = TileDegradationStrategy( + image_strategy=image_strategy, + video_frame_strategy=NoTilingStrategy( + vision_model_type=args.vision_model_type, + embeddings_per_image=num_image_embeddings_per_tile, + target_width=args.img_w, + target_height=args.img_h, + ), + embeddings_per_tile=num_image_embeddings_per_tile, + max_num_tiles=args.max_num_tiles, + ) + + return image_tiling_strategy From 1bceb28678fcf6f7a5c39cd082222023b393e5e5 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Thu, 11 Dec 2025 17:03:35 +0200 Subject: [PATCH 03/10] import image processing into model_executor/models/nano_nemotron_vl.py --- .../model_executor/models/nano_nemotron_vl.py | 489 +++++++++++++++++- 1 file changed, 488 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 6dfab595e5b92..1fbbf06d0a695 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -7,10 +7,14 @@ # LICENSE is in root directory. # -------------------------------------------------------- +import math +import random +from dataclasses import dataclass + import copy import warnings from abc import ABC, abstractmethod -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence, Callable from typing import Annotated, Any, Literal, TypeAlias, TypeVar import numpy.typing as npt @@ -20,6 +24,7 @@ import torch.nn as nn import torchvision.transforms as T from PIL import Image from transformers import BatchFeature, PretrainedConfig, TensorType +import einops from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions @@ -84,6 +89,488 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely # Alternative: Set a specific higher limit # Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels +IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] +IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] +SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] +SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] +CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] +CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] +RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060] +RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250] + + +pixel_statistics = { + "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), + "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), + "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD), + "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), + "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), +} + + +@dataclass +class DynamicResolutionParams: + image: Image.Image + num_tiles: int + num_embeddings: int + patch_size: tuple[int, int] + + +class DynamicResolutionImageTilingStrategy: + def __init__( + self, + vision_model_type: str, + min_num_patches: int, + patch_size: int, + get_num_embeddings: Callable[[int, int], int], + factor_max: float = 1.0, + pixel_shuffle: bool = False, + min_side: int | None = None, + conv_merging: bool = False, + use_thumbnail: bool = False, + thumbnail_size: int = 448, + thumbnail_area_threshold: float = 0.8, + max_num_patches: int = 0, + apply_data_augment: bool = False, + ): + assert "radio" in vision_model_type, ( + "Dynamic resolution is only supported for radio models" + ) + self._vision_model_type = vision_model_type + self._min_num_patches = min_num_patches + self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf") + self._patch_size = patch_size + self._get_num_embeddings = get_num_embeddings + self._factor_max = factor_max + self._pixel_shuffle = pixel_shuffle + self._min_side = min_side + self._conv_merging = conv_merging + self._use_thumbnail = use_thumbnail + self._thumbnail_size = thumbnail_size + self._thumbnail_area_threshold = thumbnail_area_threshold + pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] + self._transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)), + T.Normalize(mean=pixel_mean, std=pixel_std), + ] + ) + self._apply_data_augment = apply_data_augment + + def apply_params( + self, params: DynamicResolutionParams, **kwargs + ) -> list[torch.Tensor]: + # resize the image + resized_img = params.image.resize( + ( + params.patch_size[0] * self._patch_size, + params.patch_size[1] * self._patch_size, + ) + ) + processed_images = [resized_img] + + # Add thumbnail if enabled and image area is below threshold + if self._use_thumbnail: + # Calculate areas + resized_area = resized_img.size[0] * resized_img.size[1] + thumbnail_area = self._thumbnail_size * self._thumbnail_size + area_ratio = resized_area / thumbnail_area + + # Only add thumbnail if resized image area is less than threshold % of thumbnail area + if area_ratio < self._thumbnail_area_threshold: + thumbnail_img = params.image.resize( + (self._thumbnail_size, self._thumbnail_size) + ) + processed_images.append(thumbnail_img) + + return [self._transform(img) for img in processed_images] + + def process_media( + self, + image: Image.Image, + num_tokens_available: int, + data_augment: bool = False, + tiling_augment_prob: float = 0.4, + ) -> DynamicResolutionParams: + """Process a single media item and return its parameters. + + Args: + media: The media item to process + num_tokens_available: Number of tokens available for this media + data_augment: Whether to apply data augmentation to the image. Defaults to False. + Returns: + DynamicResolutionParams for the media + """ + current_num_tokens_available = num_tokens_available + assert isinstance(image, Image.Image), ( + "Dynamic resolution is only supported for image media" + ) + orig_width, orig_height = image.width, image.height + + closest_patch_height = round(orig_height / self._patch_size + 0.5) + closest_patch_width = round(orig_width / self._patch_size + 0.5) + patches = closest_patch_height * closest_patch_width + + factor = min( + math.sqrt(current_num_tokens_available / patches), self._factor_max + ) + target_patch_height = math.floor(factor * closest_patch_height) + target_patch_width = math.floor(factor * closest_patch_width) + + # We only consider self._min_num_patches if it is greater than current_num_tokens_available. + if ( + current_num_tokens_available > self._min_num_patches + and target_patch_height * target_patch_width < self._min_num_patches + ): + up_factor = math.sqrt( + self._min_num_patches / (target_patch_height * target_patch_width) + ) + target_patch_height = math.ceil(up_factor * target_patch_height) + target_patch_width = math.ceil(up_factor * target_patch_width) + + if ( + self._min_side is not None + and min(target_patch_width, target_patch_height) * self._patch_size + < self._min_side + ): + if target_patch_width <= target_patch_height: + up_factor = self._min_side / (target_patch_width * self._patch_size) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) + + if new_patch_height * new_patch_width > current_num_tokens_available: + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + if ( + max(current_num_tokens_available // new_patch_width, 1) + * self._patch_size + < self._min_side + ): + up_factor = math.sqrt( + current_num_tokens_available + / (target_patch_height * target_patch_width) + ) + target_patch_height = math.floor( + up_factor * target_patch_height + ) + target_patch_width = math.floor(up_factor * target_patch_width) + target_patch_width = new_patch_width + target_patch_height = max( + current_num_tokens_available // new_patch_width, 1 + ) + else: + target_patch_height = new_patch_height + target_patch_width = new_patch_width + else: + up_factor = self._min_side / (target_patch_height * self._patch_size) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) + + if new_patch_height * new_patch_width > current_num_tokens_available: + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + if ( + max(current_num_tokens_available // new_patch_height, 1) + * self._patch_size + < self._min_side + ): + up_factor = math.sqrt( + current_num_tokens_available + / (target_patch_height * target_patch_width) + ) + target_patch_height = math.floor( + up_factor * target_patch_height + ) + target_patch_width = math.floor(up_factor * target_patch_width) + else: + target_patch_height = new_patch_height + target_patch_width = max( + current_num_tokens_available // new_patch_height, 1 + ) + else: + target_patch_height = new_patch_height + target_patch_width = new_patch_width + + # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) + # or by 4 when BOTH are enabled (two successive 2x reductions) + if self._pixel_shuffle or self._conv_merging: + required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2 + + rem_h = target_patch_height % required_divisor + if rem_h != 0: + inc_h = required_divisor - rem_h + if ( + target_patch_height + inc_h + ) * target_patch_width <= current_num_tokens_available: + target_patch_height += inc_h + else: + target_patch_height = max( + required_divisor, target_patch_height - rem_h + ) + + rem_w = target_patch_width % required_divisor + if rem_w != 0: + inc_w = required_divisor - rem_w + if ( + target_patch_height * (target_patch_width + inc_w) + <= current_num_tokens_available + ): + target_patch_width += inc_w + else: + target_patch_width = max( + required_divisor, target_patch_width - rem_w + ) + + if ( + data_augment + and self._apply_data_augment + and random.random() < tiling_augment_prob + ): + target_patch_width, target_patch_height = self.augment_resolution( + target_patch_width, target_patch_height, current_num_tokens_available + ) + + assert isinstance(image, Image.Image), ( + "Dynamic resolution is only supported for image media" + ) + + # Calculate embeddings for the main dynamic resolution image + num_embeddings = self._get_num_embeddings( + target_patch_width * self._patch_size, + target_patch_height * self._patch_size, + ) + + token_count = target_patch_width * target_patch_height + + # Add thumbnail embeddings if enabled and image area is below threshold + num_tiles = 1 # Base dynamic resolution image + if self._use_thumbnail: + # Calculate areas + resized_area = (target_patch_width * self._patch_size) * ( + target_patch_height * self._patch_size + ) + thumbnail_area = self._thumbnail_size * self._thumbnail_size + area_ratio = resized_area / thumbnail_area + + # Only add thumbnail if resized image area is less than threshold % of thumbnail area + if area_ratio < self._thumbnail_area_threshold: + num_tiles += 1 # Add 1 for thumbnail + # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) + num_embeddings += self._get_num_embeddings( + self._thumbnail_size, self._thumbnail_size + ) + token_count += ( + self._thumbnail_size + // self._patch_size + * self._thumbnail_size + // self._patch_size + ) + + return DynamicResolutionParams( + image=image, + num_tiles=num_tiles, + num_embeddings=num_embeddings, + patch_size=(target_patch_width, target_patch_height), + ), token_count + + def augment_resolution( + self, + target_patch_width: int, + target_patch_height: int, + current_num_tokens_available: int, + ) -> tuple[int, int]: + min_num_patch_one_side = 32 + + if random.random() < 0.5: + # Minus one + if ( + target_patch_width <= min_num_patch_one_side + and target_patch_height <= min_num_patch_one_side + ): + return target_patch_width, target_patch_height + elif target_patch_width <= min_num_patch_one_side: + return target_patch_width, target_patch_height - min_num_patch_one_side + elif target_patch_height <= min_num_patch_one_side: + return target_patch_width - min_num_patch_one_side, target_patch_height + else: + if random.random() < 0.5: + return ( + target_patch_width - min_num_patch_one_side, + target_patch_height, + ) + else: + return ( + target_patch_width, + target_patch_height - min_num_patch_one_side, + ) + else: + # Plus one + if target_patch_width * target_patch_height < current_num_tokens_available: + if random.random() < 0.5: + return ( + target_patch_width + min_num_patch_one_side, + target_patch_height, + ) + else: + return ( + target_patch_width, + target_patch_height + min_num_patch_one_side, + ) + return target_patch_width, target_patch_height + + def compute_params( + self, + media_list: list[Image.Image], + num_tokens_available: int | None = None, + max_num_tiles: int | None = None, + data_augment: bool = False, + **kwargs, + ) -> list[DynamicResolutionParams]: + """Compute parameters for all media with iterative token budgeting. + + Args: + media_list: List of media items to process + num_tokens_available: Total number of tokens available across all media + max_num_tiles: Maximum number of tiles (unused in this implementation) + data_augment: Whether to apply data augmentation to the image. Defaults to False. + Returns: + List of ImageTilingParams for each media item + """ + num_tokens_available = ( + num_tokens_available + * (4 if self._pixel_shuffle else 1) + * (4 if self._conv_merging else 1) + ) + # When the number of available token is too small, allow self._min_num_patches per media and + # let the sample be truncated. + num_tokens_available = max( + num_tokens_available, self._min_num_patches * len(media_list) + ) + + # Clip the number of tokens available per media to be between min and max patches. + num_tokens_available_per_media = [ + max(min(num_tokens_available, self._max_num_patches), self._min_num_patches) + for _ in range(len(media_list)) + ] + + # In theory this could be a while True loop, but in case the process_media method slightly + # changes, I want to make sure we don't get stuck in an infinite loop. + for _ in range(10): + # Step 1: Process each media with current token budget + params = [] + token_counts = [] + + for media, tokens_for_media in zip( + media_list, num_tokens_available_per_media + ): + param, token_count = self.process_media( + media, tokens_for_media, data_augment=data_augment + ) + params.append(param) + token_counts.append(token_count) + + # Step 2: Check if total tokens is within budget + total_tokens = sum(token_counts) + + if total_tokens <= num_tokens_available: + # We're within budget, return the params + return params + + # Step 3: We're over budget, need to scale down + # Calculate scaling factor to get under budget + scaling_factor = num_tokens_available / total_tokens + + # Recalculate token budgets for each media based on scaling + # Each media gets a proportional share of the total budget + scaled_down_num_tokens_available_per_media = [ + max(self._min_num_patches, int(token_count * scaling_factor)) + for token_count in token_counts + ] + scaled_down = any( + [ + scaled_down_num_tokens_available_per_media[i] + < num_tokens_available_per_media[i] + for i in range(len(num_tokens_available_per_media)) + ] + ) + # If there was not scaling down, we're stuck just use min_num_patches per media, else + # try with the scaled down num_tokens_available_per_media. + if not scaled_down: + num_tokens_available_per_media = [self._min_num_patches] * len( + media_list + ) + else: + num_tokens_available_per_media = ( + scaled_down_num_tokens_available_per_media + ) + return params + + def stack( + self, images: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: + imgs_sizes = torch.tensor( + [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 + ) + + def rearrange_img(x): + py = x.shape[-2] // self._patch_size + px = x.shape[-1] // self._patch_size + x = einops.rearrange( + x, + "c (py yy) (px xx) -> (py px) (c yy xx)", + py=py, + yy=self._patch_size, + px=px, + xx=self._patch_size, + ) + return x + + if len(images) > 0: + imgs = [rearrange_img(img) for img in images] + + current_length = 0 + max_length = 0 + vision_cu_lengths = [0] + for img in imgs: + if max_length < img.shape[0]: + max_length = img.shape[0] + current_length += img.shape[0] + vision_cu_lengths.append(current_length) + + vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) + vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) + + return ( + torch.cat(imgs, dim=0).unsqueeze(0), + imgs_sizes, + vision_cu_lengths, + vision_max_lengths, + ) + else: + return ( + torch.tensor([[0]], dtype=torch.float32), + torch.tensor([[0, 0]], dtype=torch.int32), + None, + None, + ) + + def __str__(self): + return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})" + + + +image_tiling_strategy = DynamicResolutionImageTilingStrategy( + vision_model_type="radio", + min_num_patches=4, + patch_size=16, + get_num_embeddings=lambda x, y: x * y * 2, + max_num_patches=64, +) + + IMG_START = "" IMG_END = "" IMG_CONTEXT = "" From 4e558858b8d0ffeb1d783e6ea4bc75432bb57bc5 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Thu, 11 Dec 2025 17:32:08 +0200 Subject: [PATCH 04/10] reformat commits --- .../model_executor/models/nano_nemotron_vl.py | 56 ++++++++++++------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 1fbbf06d0a695..82423891c11c7 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -7,16 +7,16 @@ # LICENSE is in root directory. # -------------------------------------------------------- +import copy import math import random -from dataclasses import dataclass - -import copy import warnings from abc import ABC, abstractmethod -from collections.abc import Iterable, Mapping, Sequence, Callable +from collections.abc import Callable, Iterable, Mapping, Sequence +from dataclasses import dataclass from typing import Annotated, Any, Literal, TypeAlias, TypeVar +import einops import numpy.typing as npt import regex as re import torch @@ -24,7 +24,6 @@ import torch.nn as nn import torchvision.transforms as T from PIL import Image from transformers import BatchFeature, PretrainedConfig, TensorType -import einops from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions @@ -181,7 +180,8 @@ class DynamicResolutionImageTilingStrategy: thumbnail_area = self._thumbnail_size * self._thumbnail_size area_ratio = resized_area / thumbnail_area - # Only add thumbnail if resized image area is less than threshold % of thumbnail area + # Only add thumbnail if resized image area is less than threshold % of + # thumbnail area if area_ratio < self._thumbnail_area_threshold: thumbnail_img = params.image.resize( (self._thumbnail_size, self._thumbnail_size) @@ -198,11 +198,11 @@ class DynamicResolutionImageTilingStrategy: tiling_augment_prob: float = 0.4, ) -> DynamicResolutionParams: """Process a single media item and return its parameters. - Args: media: The media item to process num_tokens_available: Number of tokens available for this media - data_augment: Whether to apply data augmentation to the image. Defaults to False. + data_augment: Whether to apply data augmentation to the image. Defaults to + False. Returns: DynamicResolutionParams for the media """ @@ -222,7 +222,8 @@ class DynamicResolutionImageTilingStrategy: target_patch_height = math.floor(factor * closest_patch_height) target_patch_width = math.floor(factor * closest_patch_width) - # We only consider self._min_num_patches if it is greater than current_num_tokens_available. + # We only consider self._min_num_patches if it is greater than + # current_num_tokens_available. if ( current_num_tokens_available > self._min_num_patches and target_patch_height * target_patch_width < self._min_num_patches @@ -244,7 +245,8 @@ class DynamicResolutionImageTilingStrategy: new_patch_width = math.ceil(up_factor * target_patch_width) if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + # If only one side can be min_side, make as big as possible at + # native aspect ratio while staying below max_patches if ( max(current_num_tokens_available // new_patch_width, 1) * self._patch_size @@ -271,7 +273,8 @@ class DynamicResolutionImageTilingStrategy: new_patch_width = math.ceil(up_factor * target_patch_width) if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + # If only one side can be min_side, make as big as possible at + # native aspect ratio while staying below max_patches if ( max(current_num_tokens_available // new_patch_height, 1) * self._patch_size @@ -355,7 +358,8 @@ class DynamicResolutionImageTilingStrategy: thumbnail_area = self._thumbnail_size * self._thumbnail_size area_ratio = resized_area / thumbnail_area - # Only add thumbnail if resized image area is less than threshold % of thumbnail area + # Only add thumbnail if resized image area is less than threshold % of + # thumbnail area if area_ratio < self._thumbnail_area_threshold: num_tiles += 1 # Add 1 for thumbnail # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) @@ -435,7 +439,8 @@ class DynamicResolutionImageTilingStrategy: media_list: List of media items to process num_tokens_available: Total number of tokens available across all media max_num_tiles: Maximum number of tiles (unused in this implementation) - data_augment: Whether to apply data augmentation to the image. Defaults to False. + data_augment: Whether to apply data augmentation to the image. Defaults to + False. Returns: List of ImageTilingParams for each media item """ @@ -444,19 +449,21 @@ class DynamicResolutionImageTilingStrategy: * (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1) ) - # When the number of available token is too small, allow self._min_num_patches per media and - # let the sample be truncated. + # When the number of available token is too small, allow self._min_num_patches + # per media and let the sample be truncated. num_tokens_available = max( num_tokens_available, self._min_num_patches * len(media_list) ) - # Clip the number of tokens available per media to be between min and max patches. + # Clip the number of tokens available per media to be between min and max + # patches. num_tokens_available_per_media = [ max(min(num_tokens_available, self._max_num_patches), self._min_num_patches) for _ in range(len(media_list)) ] - # In theory this could be a while True loop, but in case the process_media method slightly + # In theory this could be a while True loop, but in case the process_media + # method slightly # changes, I want to make sure we don't get stuck in an infinite loop. for _ in range(10): # Step 1: Process each media with current token budget @@ -496,8 +503,8 @@ class DynamicResolutionImageTilingStrategy: for i in range(len(num_tokens_available_per_media)) ] ) - # If there was not scaling down, we're stuck just use min_num_patches per media, else - # try with the scaled down num_tokens_available_per_media. + # If there was not scaling down, we're stuck just use min_num_patches per + # media, else try with the scaled down num_tokens_available_per_media. if not scaled_down: num_tokens_available_per_media = [self._min_num_patches] * len( media_list @@ -558,8 +565,15 @@ class DynamicResolutionImageTilingStrategy: ) def __str__(self): - return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})" - + return f"DynamicResolutionImageTransform(\ + vision_model_type={self._vision_model_type}, \ + min_num_patches={self._min_num_patches}, \ + patch_size={self._patch_size}, \ + pixel_shuffle={self._pixel_shuffle}, \ + conv_merging={self._conv_merging}, \ + use_thumbnail={self._use_thumbnail}, \ + thumbnail_size={self._thumbnail_size}, \ + thumbnail_area_threshold={self._thumbnail_area_threshold})" image_tiling_strategy = DynamicResolutionImageTilingStrategy( From 7b55a619e4dff16dba7f6e6c3e1b2c278fd68290 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:04:01 +0200 Subject: [PATCH 05/10] rewire --- image_processing.py | 1828 ----------------- .../model_executor/models/nano_nemotron_vl.py | 16 +- 2 files changed, 8 insertions(+), 1836 deletions(-) delete mode 100644 image_processing.py diff --git a/image_processing.py b/image_processing.py deleted file mode 100644 index 5d43e764acadd..0000000000000 --- a/image_processing.py +++ /dev/null @@ -1,1828 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. -from abc import ABC, abstractmethod -from dataclasses import dataclass -import math -from typing import Callable, Optional -import numpy as np -import random -from PIL import Image -import albumentations as A - -import einops -import torch -from torchvision import transforms as T -from torchvision.transforms import Compose -from torchvision.transforms.functional import InterpolationMode - -from data_loading.conversation_sample import ( - ImageMedia, - VideoFrameMedia, -) - -IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] -IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] -SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] -SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] -CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] -CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] -RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060] -RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250] - - -pixel_statistics = { - "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), - "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), - "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD), - "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), - "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), -} - - -# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685 -# Copyright (c) 2023 OpenGVLab. -def find_closest_aspect_ratio( - aspect_ratio: float, - target_ratios: list[tuple[int, int]], - width: int, - height: int, - image_size: int, -) -> tuple[int, int]: - best_ratio_diff = float("inf") - best_ratio = (1, 1) - area = width * height - for ratio in target_ratios: - target_aspect_ratio = ratio[0] / ratio[1] - ratio_diff = abs(aspect_ratio - target_aspect_ratio) - if ratio_diff < best_ratio_diff: - best_ratio_diff = ratio_diff - best_ratio = ratio - elif ratio_diff == best_ratio_diff: - if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: - best_ratio = ratio - return best_ratio - - -def find_closest_area_weighted_aspect_ratio( - aspect_ratio: float, - target_ratios: list[tuple[int, int]], - width: int, - height: int, - image_size: int, -): - """ - Find the best number of tiles based on the aspect ratio and the area covered by the tiles. - """ - best_factor = float("-inf") - best_ratio = (1, 1) - area = width * height - for ratio in target_ratios: - target_aspect_ratio = ratio[0] / ratio[1] - factor_based_on_area_n_ratio = min( - (ratio[0] * ratio[1] * image_size * image_size) / area, 0.6 - ) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio) - if factor_based_on_area_n_ratio > best_factor: - best_factor = factor_based_on_area_n_ratio - best_ratio = ratio - return best_ratio - - -# Mike's optimized ToTensor. -def _fast_to_tensor(pic) -> torch.Tensor: - np_img = np.array(pic, copy=False) - img = torch.from_numpy(np_img) - img = img.permute(2, 0, 1) # HWC to CHW - fp_img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format) - fp_img.div_(255) - return fp_img - - -@dataclass -class ImageTilingParams: - media: ImageMedia | VideoFrameMedia - num_tiles: int - num_embeddings: int - - -class ImageTilingStrategy(ABC): - """ - Base class for image tiling strategies. - A tiling strategy is a function that takes a list of media and returns a list of image tiling parameters. - These can then be used to apply the tiling to the media. - - Subclasses must implement the `compute_params` and `apply_params` methods. - - The `transform` method is a convenience method that computes the transformation parameters and applies the transformation to the media. - - """ - - def transform( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - ) -> list[torch.Tensor]: - """ - Transform the media and compute the transformation parameters. - """ - transform_media_list = self.compute_params(media_list, num_tokens_available) - return [ - self.apply_params(transform_media, **kwargs) - for transform_media in transform_media_list - ] - - @abstractmethod - def compute_params( - self, media_list: list[ImageMedia | VideoFrameMedia], num_tokens_available: int, max_num_tiles: int | None = None, **kwargs - ) -> list[ImageTilingParams]: - """ - Compute the transformation parameters and the number of tokens to use for the media. - - Args: - media_list: List of media to transform - num_tokens_available: Number of tokens available for all media - max_num_tiles: Maximum number of tiles allowed (optional, defaults to instance's max_num_tiles if not provided) - - Returns: - list of transformation parameters with the media - """ - ... - - @abstractmethod - def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]: - """ - Apply the transformation parameters to the media. - - Args: - transform_media: The media to apply the transformation to - - Returns: - list of transformed media tensors - """ - ... - - @abstractmethod - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]: - """ - Stack the images into a single tensor. - - Args: - media_list: List of images to stack - - Returns: - tuple of (stacked media, image sizes, vision cu lengths, vision max lengths) - """ - ... - - -class _FixedSizeStrategy(ImageTilingStrategy): - """ - Base class for fixed size image tiling strategies. - """ - - def __init__( - self, - vision_model_type: str, - target_width: int, - target_height: int, - embeddings_per_image: int, - ): - self._vision_model_type = vision_model_type - self._target_width = target_width - self._target_height = target_height - self._embeddings_per_image = embeddings_per_image - self._transform = self._build_transform( - (target_width, target_height), vision_model_type - ) - - # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 - # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 - @staticmethod - def _build_transform(target_size: tuple[int, int], vision_model_type: str): - """ - Build a transform for a given vision model type and target size. - """ - if vision_model_type in ("siglip", "internvit", "radio", "radio-g", "cradio-g"): - pixel_mean, pixel_std = pixel_statistics[vision_model_type] - - transform = T.Compose( - [ - T.Lambda( - lambda img: img.convert("RGB") if img.mode != "RGB" else img - ), - T.Resize( - (target_size[1], target_size[0]), - interpolation=InterpolationMode.BICUBIC, - ), - T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - # From the official CLIP repo. - elif vision_model_type == "clip": - pixel_mean, pixel_std = pixel_statistics[vision_model_type] - - transform = Compose( - [ - T.Resize( - (target_size[1], target_size[0]), - interpolation=InterpolationMode.BICUBIC, - ), - T.Lambda( - lambda img: img.convert("RGB") if img.mode != "RGB" else img - ), - T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - elif vision_model_type.startswith("hf://"): - from megatron.core.models.huggingface.module import get_hf_model_type - - model_type = get_hf_model_type(vision_model_type) - if "siglip" in model_type: - from transformers.models.siglip.image_processing_siglip import ( - SiglipImageProcessor, - ) - - processor = SiglipImageProcessor( - size={"height": target_size[1], "width": target_size[0]} - ) - - def transform(x): - x = x.convert("RGB") if x.mode != "RGB" else x - x = processor(x, return_tensors="pt") - return x["pixel_values"][0] - else: - raise NotImplementedError( - f"image processing not defined for huggingface model {vision_model_type}" - ) - else: - raise NotImplementedError( - f"image processing not defined for vision model {vision_model_type}" - ) - - return transform - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: - return ( - torch.stack(images) if len(images) > 0 else None, - torch.tensor( - [(img.shape[1], img.shape[2]) for img in images], dtype=torch.int32 - ) if len(images) > 0 else None, - None, - None, - ) - - -class NoTilingStrategy(_FixedSizeStrategy): - """ - A simple image transformation that resizes the image to the target width and height. - """ - - def __init__( - self, - vision_model_type: str, - target_width: int, - target_height: int, - embeddings_per_image: int, - ): - super().__init__( - vision_model_type=vision_model_type, - target_width=target_width, - target_height=target_height, - embeddings_per_image=embeddings_per_image, - ) - - def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]: - return [self._transform(transform_media.media.value)] - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: Optional[int] = None, - max_num_tiles: int | None = None, - **kwargs, - ) -> list[ImageTilingParams]: - return [ - ImageTilingParams( - media=media, num_tiles=1, num_embeddings=self._embeddings_per_image - ) - for media in media_list - ] - - def __str__(self): - return f"SimpleImageTransform(vision_model_type={self._vision_model_type}, num_tokens_per_image={self._embeddings_per_image})" - - -@dataclass -class ImageTilingParamsV1(ImageTilingParams): - tiling: tuple[int, int] - - -class ImageTilingStrategyV1(_FixedSizeStrategy): - """Tiling image transformation. - - This transformation splits the image into a grid of tiles and applies the transformation to each tile. - """ - - # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 - # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 - - def __init__( - self, - vision_model_type: str, - tile_size: int, - use_thumbnail: bool, - min_num_tiles: int, - max_num_tiles: int, - embeddings_per_tile: int, - find_closest_aspect_ratio_fn=find_closest_aspect_ratio, - ): - super().__init__( - vision_model_type=vision_model_type, - target_width=tile_size, - target_height=tile_size, - embeddings_per_image=embeddings_per_tile, - ) - - # print(f"Transformation params: {vision_model_type=}, {use_tiling=}, {tile_size=}, {use_thumbnail=}, {augment=}, {min_num_tiles=}, {max_num_tiles=}, {find_closest_aspect_ratio_fn=}") - self._tile_size = tile_size - self._use_thumbnail = use_thumbnail - self._min_num_tiles = min_num_tiles - self._max_num_tiles = max_num_tiles - self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn - - # Calculate all possible aspect ratios for each max_num_tiles. - self.target_ratios = { - max_num_tiles: sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1) - } - - self.transform = A.Compose([ - A.OneOf([ - A.GaussNoise(var_limit=(5.0, 30.0)), - A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5)), - ], p=0.3), - A.OneOf([ - A.MedianBlur(blur_limit=5), - A.GaussianBlur(blur_limit=5), - ], p=0.2), - A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.5), - A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=15, val_shift_limit=15, p=0.3), - A.ImageCompression(quality_lower=70, quality_upper=100, p=0.3), - ]) - - def apply_params(self, transform_media: ImageTilingParams, data_augment: bool = False, **kwargs) -> list[torch.Tensor]: - assert isinstance(transform_media, ImageTilingParamsV1) - image = transform_media.media.value - - if data_augment: - image = self.transform(image=np.asarray(image))["image"] - image = Image.fromarray(image) - - # calculate the target width and height - target_width = self._tile_size * transform_media.tiling[0] - target_height = self._tile_size * transform_media.tiling[1] - blocks = transform_media.tiling[0] * transform_media.tiling[1] - - # resize the image - resized_img = image.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // self._tile_size)) * self._tile_size, - (i // (target_width // self._tile_size)) * self._tile_size, - ((i % (target_width // self._tile_size)) + 1) * self._tile_size, - ((i // (target_width // self._tile_size)) + 1) * self._tile_size, - ) - # split the image - split_img = resized_img.crop(box) - processed_images.append(split_img) - assert len(processed_images) == blocks - if self._use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((self._tile_size, self._tile_size)) - processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: Optional[int] = None, - max_num_tiles: int | None = None, - data_augment: bool = False, - tiling_augment_prob: float = 0.4, - **kwargs, - ) -> list[ImageTilingParamsV1]: - # Use provided max_num_tiles or fall back to instance's max_num_tiles - # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value - effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles - effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles) - - max_num_tiles_to_use = min( - num_tokens_available // self._embeddings_per_image, effective_max_num_tiles - ) - - # calculate the existing image aspect ratio - target_ratios = self.target_ratios[max_num_tiles_to_use] - - params = [] - for media in media_list: - if isinstance(media, ImageMedia): - img_size = (media.width, media.height) - elif isinstance(media, VideoFrameMedia): - img_size = (media.video_width, media.video_height) - else: - raise ValueError(f"Unsupported media type: {type(media)}") - - aspect_ratio = img_size[0] / img_size[1] - - # find the closest aspect ratio to the target - tiling = self._find_closest_aspect_ratio_fn( - aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size - ) - if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob: - tiling = self.augment_tiling(tiling) - num_tiles = tiling[0] * tiling[1] - if self._use_thumbnail and num_tiles != 1: - num_tiles += 1 - - params.append( - ImageTilingParamsV1( - media=media, - num_tiles=num_tiles, - num_embeddings=num_tiles * self._embeddings_per_image, - tiling=tiling, - ) - ) - - return params - - def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]: - def num_tiles(tiling: tuple[int, int]) -> int: - return tiling[0] * tiling[1] - - def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]: - if random.random() < minus_prob: - # Minus one - if tiling[0] == 1 and tiling[1] == 1: - return tiling - elif tiling[0] == 1: - return (tiling[0], tiling[1] - 1) - elif tiling[1] == 1: - return (tiling[0] - 1, tiling[1]) - else: - if random.random() < 0.5: - return (tiling[0] - 1, tiling[1]) - else: - return (tiling[0], tiling[1] - 1) - else: - # Plus one - if num_tiles(tiling) < self._max_num_tiles: - tiling0 = (tiling[0] + 1, tiling[1]) - tiling1 = (tiling[0], tiling[1] + 1) - if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles: - return tiling - elif num_tiles(tiling0) > self._max_num_tiles: - return tiling1 - elif num_tiles(tiling1) > self._max_num_tiles: - return tiling0 - else: - if random.random() < 0.5: - return tiling0 - else: - return tiling1 - return tiling - - new_tiling = plus_minus_one(tiling) - return new_tiling - - def __str__(self): - return f"TilingImageTransform(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, embeddings_per_tile={self._embeddings_per_image}, find_closest_aspect_ratio_fn={self._find_closest_aspect_ratio_fn})" - - -class TileDegradationStrategy(ImageTilingStrategy): - """Strategy for tiling images and video frames, each with their own tiling strategy, while trying to match the - number of tokens left in the sample by reducing the number of tiles if needed. - """ - - # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 - # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 - - def __init__( - self, - image_strategy: ImageTilingStrategy, - video_frame_strategy: ImageTilingStrategy, - embeddings_per_tile: int, - max_num_tiles: int, - tile_degradation_map: dict[int, int] = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}, - ): - self._image_strategy = image_strategy - self._video_frame_strategy = video_frame_strategy - self._embeddings_per_tile = embeddings_per_tile - self._max_num_tiles = max_num_tiles - self._tile_degradation_map = tile_degradation_map - - def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]: - if isinstance(transform_media.media, ImageMedia): - return self._image_strategy.apply_params(transform_media, **kwargs) - elif isinstance(transform_media.media, VideoFrameMedia): - return self._video_frame_strategy.apply_params(transform_media, **kwargs) - else: - raise ValueError(f"Unsupported media type: {type(transform_media.media)}") - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - **kwargs, - ) -> list[ImageTilingParams]: - # Use provided max_num_tiles or fall back to instance's max_num_tiles - effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles - max_num_tiles_to_use = effective_max_num_tiles - degradation_map = self._tile_degradation_map - - while True: - params = [] - img_num_tiles = [] - for media in media_list: - if isinstance(media, ImageMedia): - media_params = self._image_strategy.compute_params( - [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs - )[0] - elif isinstance(media, VideoFrameMedia): - max_num_tiles_to_use = 1 - media_params = self._video_frame_strategy.compute_params( - [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs - )[0] - else: - raise ValueError(f"Unsupported media type: {type(media)}") - img_num_tiles.append(media_params.num_tiles) - params.append(media_params) - if max_num_tiles_to_use == 1 or num_tokens_available is None: - break - if sum(img_num_tiles) * self._embeddings_per_tile > num_tokens_available: - if max_num_tiles_to_use in degradation_map: - max_num_tiles_to_use = degradation_map[max_num_tiles_to_use] - else: - # End of degradation - break - else: - break - return params - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: - return self._image_strategy.stack(images) - - def __str__(self): - return f"TileDegradationImageTransform(max_num_tiles={self._max_num_tiles}, image_transform={self._image_strategy}, video_frame_transform={self._video_frame_strategy})" - - -@dataclass -class DynamicResolutionParams(ImageTilingParams): - patch_size: tuple[int, int] - - -class DynamicResolutionImageTilingStrategy(ImageTilingStrategy): - """Preprocess an image with dynamic resolution for vision transformers. - - This function resizes an image to optimize the number of patches while respecting - constraints on minimum/maximum patches, minimum side length, and compatibility - with pixel shuffle or convolution merging operations. - - The algorithm works by: - 1. Computing the initial patch grid size based on the image dimensions and res_step - 2. Scaling the patch grid to fit within the max_patches constraint - 3. Ensuring the result has at least min_patches - 4. Optionally enforcing a minimum side length constraint - 5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility - 6. Resizing the image to the computed target dimensions - - Note: - The function preserves aspect ratio as much as possible while satisfying all constraints. - When constraints conflict (e.g., min_side vs max_patches), the function prioritizes - staying within max_patches while maximizing the image size. - - Example: - >>> from PIL import Image - >>> img = Image.open("example.jpg") # 800x600 image - >>> strategy = DynamicResolutionImageTilingStrategy(vision_model_type="radio", min_patches=4, max_patches=64, res_step=14, get_num_embeddings=lambda x, y: x * y * 2) - >>> params = strategy.compute_params([img]) - >>> img_tensor = strategy.apply_params(params[0]) - >>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14 - """ - - def __init__( - self, - vision_model_type: str, - min_num_patches: int, - patch_size: int, - get_num_embeddings: Callable[[int, int], int], - factor_max: float = 1.0, - pixel_shuffle: bool = False, - min_side: int | None = None, - conv_merging: bool = False, - use_thumbnail: bool = False, - thumbnail_size: int = 448, - thumbnail_area_threshold: float = 0.8, - max_num_patches: int = 0, - apply_data_augment: bool = False, - ): - """ - Args: - vision_model_type: Vision model type. - min_num_patches: Minimum number of patches required. Defaults to 1. - max_num_patches: Maximum number of patches allowed. Defaults to 0 (no maximum). - patch_size: Resolution step size (patch dimension). Defaults to 16. - get_num_embeddings: Function to get the number of embeddings from the patch size (width, height). - factor_max: Maximum scaling factor to apply. Defaults to 1.0. - pixel_shuffle: Whether to ensure compatibility with pixel shuffle operations by rounding to even patch - dimensions. Defaults to False. - min_side: Minimum side length in pixels. If specified, ensures at least one side meets this constraint. - Defaults to None. - conv_merging: Whether to ensure compatibility with convolution merging by rounding to even patch dimensions. - Defaults to False. - use_thumbnail: Whether to add a thumbnail image when processing. Defaults to False. - thumbnail_size: Size of the thumbnail image (width and height). Defaults to 448. - thumbnail_area_threshold: Maximum area percentage (0.0-1.0) of the resized image relative to thumbnail area - for which to add a thumbnail. If the resized image area is larger than this threshold of the thumbnail - area, no thumbnail will be added. Defaults to 0.8 (80%). - apply_data_augment: Whether to apply data augmentation to the image. Defaults to False. - """ - assert "radio" in vision_model_type, ( - "Dynamic resolution is only supported for radio models" - ) - self._vision_model_type = vision_model_type - self._min_num_patches = min_num_patches - self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf") - self._patch_size = patch_size - self._get_num_embeddings = get_num_embeddings - self._factor_max = factor_max - self._pixel_shuffle = pixel_shuffle - self._min_side = min_side - self._conv_merging = conv_merging - self._use_thumbnail = use_thumbnail - self._thumbnail_size = thumbnail_size - self._thumbnail_area_threshold = thumbnail_area_threshold - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - self._apply_data_augment = apply_data_augment - - def apply_params(self, params: DynamicResolutionParams, **kwargs) -> list[torch.Tensor]: - # resize the image - resized_img = params.media.value.resize( - ( - params.patch_size[0] * self._patch_size, - params.patch_size[1] * self._patch_size, - ) - ) - processed_images = [resized_img] - - # Add thumbnail if enabled and image area is below threshold - if self._use_thumbnail: - # Calculate areas - resized_area = resized_img.size[0] * resized_img.size[1] - thumbnail_area = self._thumbnail_size * self._thumbnail_size - area_ratio = resized_area / thumbnail_area - - # Only add thumbnail if resized image area is less than threshold % of thumbnail area - if area_ratio < self._thumbnail_area_threshold: - thumbnail_img = params.media.value.resize((self._thumbnail_size, self._thumbnail_size)) - processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def process_media( - self, - media: ImageMedia | VideoFrameMedia, - num_tokens_available: int, - data_augment: bool = False, - tiling_augment_prob: float = 0.4, - ) -> DynamicResolutionParams: - """Process a single media item and return its parameters. - - Args: - media: The media item to process - num_tokens_available: Number of tokens available for this media - data_augment: Whether to apply data augmentation to the image. Defaults to False. - Returns: - DynamicResolutionParams for the media - """ - current_num_tokens_available = num_tokens_available - if isinstance(media, ImageMedia): - orig_width, orig_height = media.width, media.height - elif isinstance(media, VideoFrameMedia): - orig_width, orig_height = media.video_width, media.video_height - # current_num_tokens_available = 1024 #TEMP: hack for video - else: - raise ValueError(f"Unsupported media type: {type(media)}") - - closest_patch_height = round(orig_height / self._patch_size + 0.5) - closest_patch_width = round(orig_width / self._patch_size + 0.5) - patches = closest_patch_height * closest_patch_width - - factor = min(math.sqrt(current_num_tokens_available / patches), self._factor_max) - target_patch_height = math.floor(factor * closest_patch_height) - target_patch_width = math.floor(factor * closest_patch_width) - - # We only consider self._min_num_patches if it is greater than current_num_tokens_available. - if current_num_tokens_available > self._min_num_patches and target_patch_height * target_patch_width < self._min_num_patches: - up_factor = math.sqrt( - self._min_num_patches / (target_patch_height * target_patch_width) - ) - target_patch_height = math.ceil(up_factor * target_patch_height) - target_patch_width = math.ceil(up_factor * target_patch_width) - - if ( - self._min_side is not None - and min(target_patch_width, target_patch_height) * self._patch_size - < self._min_side - ): - if target_patch_width <= target_patch_height: - up_factor = self._min_side / (target_patch_width * self._patch_size) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) - - if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches - if ( - max(current_num_tokens_available // new_patch_width, 1) - * self._patch_size - < self._min_side - ): - up_factor = math.sqrt( - current_num_tokens_available - / (target_patch_height * target_patch_width) - ) - target_patch_height = math.floor( - up_factor * target_patch_height - ) - target_patch_width = math.floor( - up_factor * target_patch_width - ) - target_patch_width = new_patch_width - target_patch_height = max( - current_num_tokens_available // new_patch_width, 1 - ) - else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width - else: - up_factor = self._min_side / ( - target_patch_height * self._patch_size - ) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) - - if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches - if ( - max(current_num_tokens_available // new_patch_height, 1) - * self._patch_size - < self._min_side - ): - up_factor = math.sqrt( - current_num_tokens_available - / (target_patch_height * target_patch_width) - ) - target_patch_height = math.floor( - up_factor * target_patch_height - ) - target_patch_width = math.floor( - up_factor * target_patch_width - ) - else: - target_patch_height = new_patch_height - target_patch_width = max( - current_num_tokens_available // new_patch_height, 1 - ) - else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width - - # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) - # or by 4 when BOTH are enabled (two successive 2x reductions) - if self._pixel_shuffle or self._conv_merging: - required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2 - - rem_h = target_patch_height % required_divisor - if rem_h != 0: - inc_h = required_divisor - rem_h - if (target_patch_height + inc_h) * target_patch_width <= current_num_tokens_available: - target_patch_height += inc_h - else: - target_patch_height = max(required_divisor, target_patch_height - rem_h) - - rem_w = target_patch_width % required_divisor - if rem_w != 0: - inc_w = required_divisor - rem_w - if target_patch_height * (target_patch_width + inc_w) <= current_num_tokens_available: - target_patch_width += inc_w - else: - target_patch_width = max(required_divisor, target_patch_width - rem_w) - - if data_augment and self._apply_data_augment and random.random() < tiling_augment_prob: - target_patch_width, target_patch_height = self.augment_resolution(target_patch_width, target_patch_height, current_num_tokens_available) - - #TEMP: hack for video - if isinstance(media, VideoFrameMedia): - target_patch_width = 32 - target_patch_height = 32 - - # Calculate embeddings for the main dynamic resolution image - num_embeddings = self._get_num_embeddings( - target_patch_width * self._patch_size, - target_patch_height * self._patch_size, - ) - - token_count = target_patch_width * target_patch_height - - # Add thumbnail embeddings if enabled and image area is below threshold - num_tiles = 1 # Base dynamic resolution image - if self._use_thumbnail: - # Calculate areas - resized_area = (target_patch_width * self._patch_size) * (target_patch_height * self._patch_size) - thumbnail_area = self._thumbnail_size * self._thumbnail_size - area_ratio = resized_area / thumbnail_area - - # Only add thumbnail if resized image area is less than threshold % of thumbnail area - if area_ratio < self._thumbnail_area_threshold: - num_tiles += 1 # Add 1 for thumbnail - # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) - num_embeddings += self._get_num_embeddings(self._thumbnail_size, self._thumbnail_size) - token_count += self._thumbnail_size // self._patch_size * self._thumbnail_size // self._patch_size - - return DynamicResolutionParams( - media=media, - num_tiles=num_tiles, - num_embeddings=num_embeddings, - patch_size=(target_patch_width, target_patch_height), - ), token_count - - def augment_resolution(self, target_patch_width: int, target_patch_height: int, current_num_tokens_available: int) -> tuple[int, int]: - - min_num_patch_one_side = 32 - - if random.random() < 0.5: - # Minus one - if target_patch_width <= min_num_patch_one_side and target_patch_height <= min_num_patch_one_side: - return target_patch_width, target_patch_height - elif target_patch_width <= min_num_patch_one_side: - return target_patch_width, target_patch_height - min_num_patch_one_side - elif target_patch_height <= min_num_patch_one_side: - return target_patch_width - min_num_patch_one_side, target_patch_height - else: - if random.random() < 0.5: - return target_patch_width - min_num_patch_one_side, target_patch_height - else: - return target_patch_width, target_patch_height - min_num_patch_one_side - else: - # Plus one - if target_patch_width * target_patch_height < current_num_tokens_available: - if random.random() < 0.5: - return target_patch_width + min_num_patch_one_side, target_patch_height - else: - return target_patch_width, target_patch_height + min_num_patch_one_side - return target_patch_width, target_patch_height - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - data_augment: bool = False, - **kwargs, - ) -> list[ImageTilingParams]: - """Compute parameters for all media with iterative token budgeting. - - Args: - media_list: List of media items to process - num_tokens_available: Total number of tokens available across all media - max_num_tiles: Maximum number of tiles (unused in this implementation) - data_augment: Whether to apply data augmentation to the image. Defaults to False. - Returns: - List of ImageTilingParams for each media item - """ - num_tokens_available = num_tokens_available * (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1) - # When the number of available token is too small, allow self._min_num_patches per media and - # let the sample be truncated. - num_tokens_available = max(num_tokens_available, self._min_num_patches * len(media_list)) - - # Clip the number of tokens available per media to be between min and max patches. - num_tokens_available_per_media = [ - max(min(num_tokens_available, self._max_num_patches), self._min_num_patches) - for _ in range(len(media_list))] - - # In theory this could be a while True loop, but in case the process_media method slightly - # changes, I want to make sure we don't get stuck in an infinite loop. - for _ in range(10): - # Step 1: Process each media with current token budget - params = [] - token_counts = [] - - for media, tokens_for_media in zip(media_list, num_tokens_available_per_media): - param, token_count = self.process_media(media, tokens_for_media, data_augment=data_augment) - params.append(param) - token_counts.append(token_count) - - # Step 2: Check if total tokens is within budget - total_tokens = sum(token_counts) - - if total_tokens <= num_tokens_available: - # We're within budget, return the params - return params - - # Step 3: We're over budget, need to scale down - # Calculate scaling factor to get under budget - scaling_factor = num_tokens_available / total_tokens - - # Recalculate token budgets for each media based on scaling - # Each media gets a proportional share of the total budget - scaled_down_num_tokens_available_per_media = [ - max(self._min_num_patches, int(token_count * scaling_factor)) - for token_count in token_counts - ] - scaled_down = any([ - scaled_down_num_tokens_available_per_media[i] < num_tokens_available_per_media[i] - for i in range(len(num_tokens_available_per_media))]) - # If there was not scaling down, we're stuck just use min_num_patches per media, else - # try with the scaled down num_tokens_available_per_media. - if not scaled_down: - num_tokens_available_per_media = [self._min_num_patches] * len(media_list) - else: - num_tokens_available_per_media = scaled_down_num_tokens_available_per_media - return params - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: - imgs_sizes = torch.tensor( - [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 - ) - - def rearrange_img(x): - py = x.shape[-2] // self._patch_size - px = x.shape[-1] // self._patch_size - x = einops.rearrange( - x, - "c (py yy) (px xx) -> (py px) (c yy xx)", - py=py, - yy=self._patch_size, - px=px, - xx=self._patch_size, - ) - return x - - if len(images) > 0: - imgs = [rearrange_img(img) for img in images] - - current_length = 0 - max_length = 0 - vision_cu_lengths = [0] - for img in imgs: - if max_length < img.shape[0]: - max_length = img.shape[0] - current_length += img.shape[0] - vision_cu_lengths.append(current_length) - - vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) - vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) - - return ( - torch.cat(imgs, dim=0).unsqueeze(0), - imgs_sizes, - vision_cu_lengths, - vision_max_lengths, - ) - else: - return ( - torch.tensor([[0]], dtype=torch.float32), - torch.tensor([[0,0]], dtype=torch.int32), - None, - None, - ) - - def __str__(self): - return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})" - - -@dataclass -class MatchTilingDynamicResolutionParams(ImageTilingParams): - tiling: tuple[int, int] - - -class MatchTilingDynamicResolutionStrategy(ImageTilingStrategy): - """ - Strategy that uses tiling logic to determine optimal image dimensions but processes - the image as a single dynamic resolution image instead of splitting into tiles. - - This combines the aspect ratio optimization from ImageTilingStrategyV1 with the - dynamic resolution processing from DynamicResolutionImageTilingStrategy. - - Also includes tile degradation logic similar to TileDegradationStrategy. - """ - - def __init__( - self, - vision_model_type: str, - tile_size: int, - use_thumbnail: bool, - min_num_tiles: int, - max_num_tiles: int, - embeddings_per_tile: int, - patch_size: int, - get_num_embeddings: Callable[[int, int], int], - find_closest_aspect_ratio_fn=find_closest_aspect_ratio, - pixel_shuffle: bool = False, - conv_merging: bool = False, - tile_degradation_map: dict[int, int] = None, - video_frame_strategy: ImageTilingStrategy = None, - enable_tile_degradation: bool = True, - ): - """ - Args: - vision_model_type: Vision model type (should support dynamic resolution) - tile_size: Size of each tile for tiling calculation - use_thumbnail: Whether tiling logic should include thumbnail - min_num_tiles: Minimum number of tiles for tiling calculation - max_num_tiles: Maximum number of tiles for tiling calculation - embeddings_per_tile: Embeddings per tile for tiling calculation - patch_size: Patch size for dynamic resolution processing - get_num_embeddings: Function to get number of embeddings from dimensions - find_closest_aspect_ratio_fn: Function to find closest aspect ratio - pixel_shuffle: Whether to ensure compatibility with pixel shuffle - conv_merging: Whether to ensure compatibility with convolution merging - tile_degradation_map: Map for degrading tiles when tokens are insufficient - video_frame_strategy: Strategy for processing video frames - enable_tile_degradation: Whether to enable tile degradation (default: True) - """ - assert "radio" in vision_model_type, ( - "MatchTilingDynamicResolution is only supported for radio models" - ) - - self._vision_model_type = vision_model_type - self._tile_size = tile_size - self._use_thumbnail = use_thumbnail - self._min_num_tiles = min_num_tiles - self._max_num_tiles = max_num_tiles - self._embeddings_per_tile = embeddings_per_tile - self._patch_size = patch_size - self._get_num_embeddings = get_num_embeddings - self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn - self._pixel_shuffle = pixel_shuffle - self._conv_merging = conv_merging - self._enable_tile_degradation = enable_tile_degradation - - # Tile degradation logic (similar to TileDegradationStrategy) - if tile_degradation_map is None: - self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1} - else: - self._tile_degradation_map = tile_degradation_map - - # Video frame strategy (similar to TileDegradationStrategy) - if video_frame_strategy is None: - self._video_frame_strategy = NoTilingStrategy( - vision_model_type=vision_model_type, - target_width=tile_size, - target_height=tile_size, - embeddings_per_image=embeddings_per_tile, - ) - else: - self._video_frame_strategy = video_frame_strategy - - # Calculate all possible aspect ratios for each max_num_tiles (borrowed from ImageTilingStrategyV1) - self.target_ratios = { - max_num_tiles: sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1) - } - - # Set up transform for dynamic resolution processing - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - - def apply_params(self, params: MatchTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]: - # Handle video frames using the video frame strategy - if isinstance(params.media, VideoFrameMedia): - return self._video_frame_strategy.apply_params(params, **kwargs) - - # Handle images with dynamic resolution processing - image = params.media.value - # Calculate the target width and height (same logic as ImageTilingStrategyV1) - target_width = self._tile_size * params.tiling[0] - target_height = self._tile_size * params.tiling[1] - - # Resize the image to the target dimensions (same as ImageTilingStrategyV1) - resized_img = image.resize((target_width, target_height)) - - # Process as single dynamic resolution image - processed_images = [resized_img] - - # Add thumbnail if use_thumbnail=True and there's more than 1 tile (same as ImageTilingStrategyV1) - blocks = params.tiling[0] * params.tiling[1] - if self._use_thumbnail and blocks != 1: - thumbnail_img = image.resize((self._tile_size, self._tile_size)) - processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - **kwargs, - ) -> list[MatchTilingDynamicResolutionParams]: - # Implement tile degradation logic similar to TileDegradationStrategy - # Use provided max_num_tiles or fall back to instance's max_num_tiles - # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value - effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles - effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles) - max_num_tiles_to_use = effective_max_num_tiles - degradation_map = self._tile_degradation_map - - while True: - params = [] - total_embeddings_needed = 0 - - for media in media_list: - if isinstance(media, ImageMedia): - # Use tiling logic for images - img_size = (media.width, media.height) - aspect_ratio = img_size[0] / img_size[1] - - # Find the closest aspect ratio to the target - target_ratios = self.target_ratios[max_num_tiles_to_use] - tiling = self._find_closest_aspect_ratio_fn( - aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size - ) - - # Calculate target dimensions for dynamic resolution processing - target_width = self._tile_size * tiling[0] - target_height = self._tile_size * tiling[1] - num_embeddings = self._get_num_embeddings(target_width, target_height) - - # Account for thumbnail (same logic as ImageTilingStrategyV1) - num_tiles = 1 # Base dynamic resolution image - blocks = tiling[0] * tiling[1] - if self._use_thumbnail and blocks != 1: - num_tiles += 1 # Add 1 for thumbnail - # Add embeddings for thumbnail (tile_size x tile_size) - num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size) - - media_params = MatchTilingDynamicResolutionParams( - media=media, - num_tiles=num_tiles, - num_embeddings=num_embeddings, - tiling=tiling, - ) - elif isinstance(media, VideoFrameMedia): - # Use video frame strategy for video frames (always 1 tile) - video_params = self._video_frame_strategy.compute_params( - [media], 1 * self._embeddings_per_tile - )[0] - media_params = MatchTilingDynamicResolutionParams( - media=media, - num_tiles=video_params.num_tiles, - num_embeddings=video_params.num_embeddings, - tiling=(1, 1), # Video frames always use 1x1 tiling - ) - else: - raise ValueError(f"Unsupported media type: {type(media)}") - - params.append(media_params) - total_embeddings_needed += media_params.num_embeddings - - # Check if we need to degrade (only if degradation is enabled) - if not self._enable_tile_degradation: - break - if max_num_tiles_to_use == 1 or num_tokens_available is None: - break - if total_embeddings_needed > num_tokens_available: - if max_num_tiles_to_use in degradation_map: - max_num_tiles_to_use = degradation_map[max_num_tiles_to_use] - # Recalculate target ratios for the new max_num_tiles_to_use - if max_num_tiles_to_use not in self.target_ratios: - self.target_ratios[max_num_tiles_to_use] = sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles_to_use + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles_to_use and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - else: - # End of degradation - break - else: - break - - return params - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]: - """Stack images using dynamic resolution approach with sequence packing""" - imgs_sizes = torch.tensor( - [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 - ) - - def rearrange_img(x): - py = x.shape[-2] // self._patch_size - px = x.shape[-1] // self._patch_size - x = einops.rearrange( - x, - "c (py yy) (px xx) -> (py px) (c yy xx)", - py=py, - yy=self._patch_size, - px=px, - xx=self._patch_size, - ) - return x - - if len(images) > 0: - imgs = [rearrange_img(img) for img in images] - - current_length = 0 - max_length = 0 - vision_cu_lengths = [0] - for img in imgs: - if max_length < img.shape[0]: - max_length = img.shape[0] - current_length += img.shape[0] - vision_cu_lengths.append(current_length) - - vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) - vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) - - return ( - torch.cat(imgs, dim=0).unsqueeze(0), - imgs_sizes, - vision_cu_lengths, - vision_max_lengths, - ) - else: - return ( - torch.tensor([[0]], dtype=torch.float32), - torch.tensor([[0,0]], dtype=torch.int32), - None, - None, - ) - - def __str__(self): - return f"MatchTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})" - - -@dataclass -class MaskedTilingDynamicResolutionParams(ImageTilingParams): - tiling: tuple[int, int] - - -class MaskedTilingDynamicResolutionStrategy(ImageTilingStrategy): - """ - Like MatchTilingDynamicResolutionStrategy, but ensures tiles are isolated in the - vision encoder by emitting per-tile packed samples (block-diagonal attention across tiles). - """ - - def __init__( - self, - vision_model_type: str, - tile_size: int, - use_thumbnail: bool, - min_num_tiles: int, - max_num_tiles: int, - embeddings_per_tile: int, - patch_size: int, - get_num_embeddings: Callable[[int, int], int], - find_closest_aspect_ratio_fn=find_closest_aspect_ratio, - pixel_shuffle: bool = False, - conv_merging: bool = False, - tile_degradation_map: dict[int, int] = None, - video_frame_strategy: ImageTilingStrategy = None, - enable_tile_degradation: bool = True, - ): - assert "radio" in vision_model_type, ( - "MaskedTilingDynamicResolution is only supported for radio models" - ) - - self._vision_model_type = vision_model_type - self._tile_size = tile_size - self._use_thumbnail = use_thumbnail - self._min_num_tiles = min_num_tiles - self._max_num_tiles = max_num_tiles - self._embeddings_per_tile = embeddings_per_tile - self._patch_size = patch_size - self._get_num_embeddings = get_num_embeddings - self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn - self._pixel_shuffle = pixel_shuffle - self._conv_merging = conv_merging - self._enable_tile_degradation = enable_tile_degradation - - if tile_degradation_map is None: - self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1} - else: - self._tile_degradation_map = tile_degradation_map - - if video_frame_strategy is None: - self._video_frame_strategy = NoTilingStrategy( - vision_model_type=vision_model_type, - target_width=tile_size, - target_height=tile_size, - embeddings_per_image=embeddings_per_tile, - ) - else: - self._video_frame_strategy = video_frame_strategy - - self.target_ratios = { - max_num_tiles: sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1) - } - - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - - def apply_params(self, params: MaskedTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]: - # Handle video frames using the video frame strategy - if isinstance(params.media, VideoFrameMedia): - return self._video_frame_strategy.apply_params(params, **kwargs) - - image = params.media.value - nx, ny = params.tiling - target_width = self._tile_size * nx - target_height = self._tile_size * ny - - resized_img = image.resize((target_width, target_height)) - - processed_images = [] - # Emit per-tile images (each becomes an isolated packed sample later) - for j in range(ny): - for i in range(nx): - box = ( - i * self._tile_size, - j * self._tile_size, - (i + 1) * self._tile_size, - (j + 1) * self._tile_size, - ) - tile_img = resized_img.crop(box) - processed_images.append(tile_img) - - if self._use_thumbnail and (nx * ny) != 1: - thumbnail_img = image.resize((self._tile_size, self._tile_size)) - processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def compute_params( - self, - media_list: list[ImageMedia | VideoFrameMedia], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - data_augment: bool = False, - tiling_augment_prob: float = 0.4, - **kwargs, - ) -> list[MaskedTilingDynamicResolutionParams]: - effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles - effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles) - max_num_tiles_to_use = effective_max_num_tiles - degradation_map = self._tile_degradation_map - - while True: - params = [] - total_embeddings_needed = 0 - - for media in media_list: - if isinstance(media, ImageMedia): - img_size = (media.width, media.height) - aspect_ratio = img_size[0] / img_size[1] - - target_ratios = self.target_ratios[max_num_tiles_to_use] - tiling = self._find_closest_aspect_ratio_fn( - aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size - ) - - # Apply tiling augmentation if enabled - if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob: - tiling = self.augment_tiling(tiling) - - blocks = tiling[0] * tiling[1] - # Each tile is tile_size x tile_size - per_tile_emb = self._get_num_embeddings(self._tile_size, self._tile_size) - num_embeddings = blocks * per_tile_emb - - num_tiles = blocks - if self._use_thumbnail and blocks != 1: - num_tiles += 1 - num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size) - - media_params = MaskedTilingDynamicResolutionParams( - media=media, - num_tiles=num_tiles, - num_embeddings=num_embeddings, - tiling=tiling, - ) - elif isinstance(media, VideoFrameMedia): - video_params = self._video_frame_strategy.compute_params( - [media], 1 * self._embeddings_per_tile - )[0] - media_params = MaskedTilingDynamicResolutionParams( - media=media, - num_tiles=video_params.num_tiles, - num_embeddings=video_params.num_embeddings, - tiling=(1, 1), - ) - else: - raise ValueError(f"Unsupported media type: {type(media)}") - - params.append(media_params) - total_embeddings_needed += media_params.num_embeddings - - if not self._enable_tile_degradation: - break - if max_num_tiles_to_use == 1 or num_tokens_available is None: - break - if total_embeddings_needed > num_tokens_available: - if max_num_tiles_to_use in degradation_map: - max_num_tiles_to_use = degradation_map[max_num_tiles_to_use] - if max_num_tiles_to_use not in self.target_ratios: - self.target_ratios[max_num_tiles_to_use] = sorted( - set( - (x, y) - for n in range(self._min_num_tiles, max_num_tiles_to_use + 1) - for x in range(1, n + 1) - for y in range(1, n + 1) - if x * y <= max_num_tiles_to_use and x * y >= self._min_num_tiles - ), - key=lambda x: x[0] * x[1], - ) - else: - break - else: - break - - return params - - def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]: - def num_tiles(tiling: tuple[int, int]) -> int: - return tiling[0] * tiling[1] - - def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]: - if random.random() < minus_prob: - # Minus one - if tiling[0] == 1 and tiling[1] == 1: - return tiling - elif tiling[0] == 1: - return (tiling[0], tiling[1] - 1) - elif tiling[1] == 1: - return (tiling[0] - 1, tiling[1]) - else: - if random.random() < 0.5: - return (tiling[0] - 1, tiling[1]) - else: - return (tiling[0], tiling[1] - 1) - else: - # Plus one - if num_tiles(tiling) < self._max_num_tiles: - tiling0 = (tiling[0] + 1, tiling[1]) - tiling1 = (tiling[0], tiling[1] + 1) - if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles: - return tiling - elif num_tiles(tiling0) > self._max_num_tiles: - return tiling1 - elif num_tiles(tiling1) > self._max_num_tiles: - return tiling0 - else: - if random.random() < 0.5: - return tiling0 - else: - return tiling1 - return tiling - - new_tiling = plus_minus_one(tiling) - return new_tiling - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]: - # Identical to dynamic resolution packing; each tile is already an independent image sample - imgs_sizes = torch.tensor( - [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 - ) - - def rearrange_img(x): - py = x.shape[-2] // self._patch_size - px = x.shape[-1] // self._patch_size - x = einops.rearrange( - x, - "c (py yy) (px xx) -> (py px) (c yy xx)", - py=py, - yy=self._patch_size, - px=px, - xx=self._patch_size, - ) - return x - - if len(images) > 0: - imgs = [rearrange_img(img) for img in images] - - current_length = 0 - max_length = 0 - vision_cu_lengths = [0] - for img in imgs: - if max_length < img.shape[0]: - max_length = img.shape[0] - current_length += img.shape[0] - vision_cu_lengths.append(current_length) - - vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) - vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) - - return ( - torch.cat(imgs, dim=0).unsqueeze(0), - imgs_sizes, - vision_cu_lengths, - vision_max_lengths, - ) - else: - return ( - torch.tensor([[0]], dtype=torch.float32), - torch.tensor([[0,0]], dtype=torch.int32), - None, - None, - ) - - def __str__(self): - return f"MaskedTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})" - -def create_image_tiling_strategy(args): - """ - Create an image tiling strategy based on the provided arguments. - - This function encapsulates the logic for creating the appropriate image tiling strategy - based on the training/evaluation configuration. It can be used by both training (task_encoder) - and evaluation code outside of data_loading/. - - Args: - args: Arguments object with the following relevant attributes: - - img_h, img_w: Image height and width - - patch_dim: Patch dimension - - vision_model_type: Vision model type (e.g., 'radio', 'clip', 'siglip') - - disable_vision_class_token: Whether to disable vision class token - - pixel_shuffle: Whether to use pixel shuffle - - use_tile_tags: Whether to use tile tags - - max_num_tiles: Maximum number of tiles - - tokenizer_prompt_format: Tokenizer prompt format - - image_break_token: Image break token (optional) - - conv_merging: Whether to use convolution merging - - dynamic_resolution: Whether to use dynamic resolution - - match_tiling_dynamic_resolution: Whether to match tiling with dynamic resolution - - use_area_weighted_aspect_ratio: Whether to use area-weighted aspect ratio - - use_thumbnail: Whether to use thumbnail - - dynamic_resolution_min_patches: Minimum number of patches for dynamic resolution - - dynamic_resolution_min_side: Minimum side length for dynamic resolution (optional) - - thumbnail_area_threshold: Thumbnail area threshold (optional) - - use_tiling: Whether to use tiling - - Returns: - ImageTilingStrategy: The created image tiling strategy - """ - from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings - - assert args.img_h == args.img_w, "img_h and img_w must be the same" - - match_tiling_dynamic_resolution = args.match_tiling_dynamic_resolution - masked_tiling_dynamic_resolution = getattr(args, "masked_tiling_dynamic_resolution", False) - dynamic_resolution = args.dynamic_resolution - use_tiling = args.use_tiling - use_area_weighted_aspect_ratio = args.use_area_weighted_aspect_ratio - - if match_tiling_dynamic_resolution: - assert dynamic_resolution, "must enable --dynamic-resolution if using --match-tiling-dynamic-resolution" - assert not use_tiling, "cannot use --use-tiling and --match-tiling-dynamic-resolution together" - if masked_tiling_dynamic_resolution: - assert dynamic_resolution, "must enable --dynamic-resolution if using --masked-tiling-dynamic-resolution" - assert not use_tiling, "cannot use --use-tiling and --masked-tiling-dynamic-resolution together" - assert not match_tiling_dynamic_resolution, "cannot combine --masked-tiling-dynamic-resolution with --match-tiling-dynamic-resolution" - - if dynamic_resolution: - if masked_tiling_dynamic_resolution: - num_image_embeddings_per_tile = get_num_image_embeddings( - img_h=args.img_h, - img_w=args.img_w, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ) - image_tiling_strategy = MaskedTilingDynamicResolutionStrategy( - vision_model_type=args.vision_model_type, - tile_size=args.img_h, - use_thumbnail=args.use_thumbnail, - min_num_tiles=1, - max_num_tiles=args.max_num_tiles, - embeddings_per_tile=num_image_embeddings_per_tile, - patch_size=args.patch_dim, - get_num_embeddings=lambda width, height: get_num_image_embeddings( - img_h=height, - img_w=width, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ), - find_closest_aspect_ratio_fn=( - find_closest_area_weighted_aspect_ratio - if use_area_weighted_aspect_ratio - else find_closest_aspect_ratio - ), - pixel_shuffle=args.pixel_shuffle, - conv_merging=args.conv_merging, - ) - elif match_tiling_dynamic_resolution: - num_image_embeddings_per_tile = get_num_image_embeddings( - img_h=args.img_h, - img_w=args.img_w, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ) - image_tiling_strategy = MatchTilingDynamicResolutionStrategy( - vision_model_type=args.vision_model_type, - tile_size=args.img_h, - use_thumbnail=args.use_thumbnail, - min_num_tiles=1, - max_num_tiles=args.max_num_tiles, - embeddings_per_tile=num_image_embeddings_per_tile, - patch_size=args.patch_dim, - get_num_embeddings=lambda width, height: get_num_image_embeddings( - img_h=height, - img_w=width, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ), - find_closest_aspect_ratio_fn=( - find_closest_area_weighted_aspect_ratio - if use_area_weighted_aspect_ratio - else find_closest_aspect_ratio - ), - pixel_shuffle=args.pixel_shuffle, - conv_merging=args.conv_merging, - ) - else: - image_tiling_strategy = DynamicResolutionImageTilingStrategy( - vision_model_type=args.vision_model_type, - min_num_patches=args.dynamic_resolution_min_patches, - patch_size=args.patch_dim, - get_num_embeddings=lambda width, height: get_num_image_embeddings( - img_h=height, - img_w=width, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ), - pixel_shuffle=args.pixel_shuffle, - min_side=args.dynamic_resolution_min_side, - conv_merging=args.conv_merging, - use_thumbnail=args.use_thumbnail, - thumbnail_size=args.img_h, - thumbnail_area_threshold=args.thumbnail_area_threshold, - max_num_patches=args.dynamic_resolution_max_patches, - apply_data_augment=args.apply_data_augment, - ) - else: - num_image_embeddings_per_tile = get_num_image_embeddings( - img_h=args.img_h, - img_w=args.img_w, - patch_dim=args.patch_dim, - vision_model_type=args.vision_model_type, - disable_vision_class_token=args.disable_vision_class_token, - class_token_len=1, - pixel_shuffle=args.pixel_shuffle, - use_tile_tags=args.use_tile_tags, - max_num_tiles=args.max_num_tiles, - tokenizer_type=args.tokenizer_prompt_format, - use_image_break_token=args.image_break_token is not None, - conv_merging=args.conv_merging, - ) - if use_tiling: - image_strategy = ImageTilingStrategyV1( - vision_model_type=args.vision_model_type, - tile_size=args.img_h, - use_thumbnail=args.use_thumbnail, - min_num_tiles=1, - max_num_tiles=args.max_num_tiles, - embeddings_per_tile=num_image_embeddings_per_tile, - find_closest_aspect_ratio_fn=( - find_closest_area_weighted_aspect_ratio - if use_area_weighted_aspect_ratio - else find_closest_aspect_ratio - ), - ) - else: - image_strategy = NoTilingStrategy( - vision_model_type=args.vision_model_type, - embeddings_per_image=num_image_embeddings_per_tile, - target_width=args.img_w, - target_height=args.img_h, - ) - image_tiling_strategy = TileDegradationStrategy( - image_strategy=image_strategy, - video_frame_strategy=NoTilingStrategy( - vision_model_type=args.vision_model_type, - embeddings_per_image=num_image_embeddings_per_tile, - target_width=args.img_w, - target_height=args.img_h, - ), - embeddings_per_tile=num_image_embeddings_per_tile, - max_num_tiles=args.max_num_tiles, - ) - - return image_tiling_strategy diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 82423891c11c7..3b8a3841cf938 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -113,7 +113,7 @@ pixel_statistics = { @dataclass class DynamicResolutionParams: - image: Image.Image + media: Image.Image num_tiles: int num_embeddings: int patch_size: tuple[int, int] @@ -165,7 +165,7 @@ class DynamicResolutionImageTilingStrategy: self, params: DynamicResolutionParams, **kwargs ) -> list[torch.Tensor]: # resize the image - resized_img = params.image.resize( + resized_img = params.media.resize( ( params.patch_size[0] * self._patch_size, params.patch_size[1] * self._patch_size, @@ -183,7 +183,7 @@ class DynamicResolutionImageTilingStrategy: # Only add thumbnail if resized image area is less than threshold % of # thumbnail area if area_ratio < self._thumbnail_area_threshold: - thumbnail_img = params.image.resize( + thumbnail_img = params.media.resize( (self._thumbnail_size, self._thumbnail_size) ) processed_images.append(thumbnail_img) @@ -192,7 +192,7 @@ class DynamicResolutionImageTilingStrategy: def process_media( self, - image: Image.Image, + media: Image.Image, num_tokens_available: int, data_augment: bool = False, tiling_augment_prob: float = 0.4, @@ -207,10 +207,10 @@ class DynamicResolutionImageTilingStrategy: DynamicResolutionParams for the media """ current_num_tokens_available = num_tokens_available - assert isinstance(image, Image.Image), ( + assert isinstance(media, Image.Image), ( "Dynamic resolution is only supported for image media" ) - orig_width, orig_height = image.width, image.height + orig_width, orig_height = media.width, media.height closest_patch_height = round(orig_height / self._patch_size + 0.5) closest_patch_width = round(orig_width / self._patch_size + 0.5) @@ -336,7 +336,7 @@ class DynamicResolutionImageTilingStrategy: target_patch_width, target_patch_height, current_num_tokens_available ) - assert isinstance(image, Image.Image), ( + assert isinstance(media, Image.Image), ( "Dynamic resolution is only supported for image media" ) @@ -374,7 +374,7 @@ class DynamicResolutionImageTilingStrategy: ) return DynamicResolutionParams( - image=image, + media=media, num_tiles=num_tiles, num_embeddings=num_embeddings, patch_size=(target_patch_width, target_patch_height), From 3be49316c91fcd6a7f9579d21a2fb7307e9c4e1e Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Mon, 15 Dec 2025 17:37:46 +0200 Subject: [PATCH 06/10] refactor --- .../model_executor/models/nano_nemotron_vl.py | 1060 ++++++++--------- 1 file changed, 511 insertions(+), 549 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 3b8a3841cf938..3bac36744ef5e 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -12,7 +12,7 @@ import math import random import warnings from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from typing import Annotated, Any, Literal, TypeAlias, TypeVar @@ -88,501 +88,9 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely # Alternative: Set a specific higher limit # Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels -IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] -IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] -SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] -SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] -CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] -CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] -RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060] -RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250] - -pixel_statistics = { - "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), - "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), - "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD), - "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), - "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), -} - - -@dataclass -class DynamicResolutionParams: - media: Image.Image - num_tiles: int - num_embeddings: int - patch_size: tuple[int, int] - - -class DynamicResolutionImageTilingStrategy: - def __init__( - self, - vision_model_type: str, - min_num_patches: int, - patch_size: int, - get_num_embeddings: Callable[[int, int], int], - factor_max: float = 1.0, - pixel_shuffle: bool = False, - min_side: int | None = None, - conv_merging: bool = False, - use_thumbnail: bool = False, - thumbnail_size: int = 448, - thumbnail_area_threshold: float = 0.8, - max_num_patches: int = 0, - apply_data_augment: bool = False, - ): - assert "radio" in vision_model_type, ( - "Dynamic resolution is only supported for radio models" - ) - self._vision_model_type = vision_model_type - self._min_num_patches = min_num_patches - self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf") - self._patch_size = patch_size - self._get_num_embeddings = get_num_embeddings - self._factor_max = factor_max - self._pixel_shuffle = pixel_shuffle - self._min_side = min_side - self._conv_merging = conv_merging - self._use_thumbnail = use_thumbnail - self._thumbnail_size = thumbnail_size - self._thumbnail_area_threshold = thumbnail_area_threshold - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - self._apply_data_augment = apply_data_augment - - def apply_params( - self, params: DynamicResolutionParams, **kwargs - ) -> list[torch.Tensor]: - # resize the image - resized_img = params.media.resize( - ( - params.patch_size[0] * self._patch_size, - params.patch_size[1] * self._patch_size, - ) - ) - processed_images = [resized_img] - - # Add thumbnail if enabled and image area is below threshold - if self._use_thumbnail: - # Calculate areas - resized_area = resized_img.size[0] * resized_img.size[1] - thumbnail_area = self._thumbnail_size * self._thumbnail_size - area_ratio = resized_area / thumbnail_area - - # Only add thumbnail if resized image area is less than threshold % of - # thumbnail area - if area_ratio < self._thumbnail_area_threshold: - thumbnail_img = params.media.resize( - (self._thumbnail_size, self._thumbnail_size) - ) - processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def process_media( - self, - media: Image.Image, - num_tokens_available: int, - data_augment: bool = False, - tiling_augment_prob: float = 0.4, - ) -> DynamicResolutionParams: - """Process a single media item and return its parameters. - Args: - media: The media item to process - num_tokens_available: Number of tokens available for this media - data_augment: Whether to apply data augmentation to the image. Defaults to - False. - Returns: - DynamicResolutionParams for the media - """ - current_num_tokens_available = num_tokens_available - assert isinstance(media, Image.Image), ( - "Dynamic resolution is only supported for image media" - ) - orig_width, orig_height = media.width, media.height - - closest_patch_height = round(orig_height / self._patch_size + 0.5) - closest_patch_width = round(orig_width / self._patch_size + 0.5) - patches = closest_patch_height * closest_patch_width - - factor = min( - math.sqrt(current_num_tokens_available / patches), self._factor_max - ) - target_patch_height = math.floor(factor * closest_patch_height) - target_patch_width = math.floor(factor * closest_patch_width) - - # We only consider self._min_num_patches if it is greater than - # current_num_tokens_available. - if ( - current_num_tokens_available > self._min_num_patches - and target_patch_height * target_patch_width < self._min_num_patches - ): - up_factor = math.sqrt( - self._min_num_patches / (target_patch_height * target_patch_width) - ) - target_patch_height = math.ceil(up_factor * target_patch_height) - target_patch_width = math.ceil(up_factor * target_patch_width) - - if ( - self._min_side is not None - and min(target_patch_width, target_patch_height) * self._patch_size - < self._min_side - ): - if target_patch_width <= target_patch_height: - up_factor = self._min_side / (target_patch_width * self._patch_size) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) - - if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at - # native aspect ratio while staying below max_patches - if ( - max(current_num_tokens_available // new_patch_width, 1) - * self._patch_size - < self._min_side - ): - up_factor = math.sqrt( - current_num_tokens_available - / (target_patch_height * target_patch_width) - ) - target_patch_height = math.floor( - up_factor * target_patch_height - ) - target_patch_width = math.floor(up_factor * target_patch_width) - target_patch_width = new_patch_width - target_patch_height = max( - current_num_tokens_available // new_patch_width, 1 - ) - else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width - else: - up_factor = self._min_side / (target_patch_height * self._patch_size) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) - - if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at - # native aspect ratio while staying below max_patches - if ( - max(current_num_tokens_available // new_patch_height, 1) - * self._patch_size - < self._min_side - ): - up_factor = math.sqrt( - current_num_tokens_available - / (target_patch_height * target_patch_width) - ) - target_patch_height = math.floor( - up_factor * target_patch_height - ) - target_patch_width = math.floor(up_factor * target_patch_width) - else: - target_patch_height = new_patch_height - target_patch_width = max( - current_num_tokens_available // new_patch_height, 1 - ) - else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width - - # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) - # or by 4 when BOTH are enabled (two successive 2x reductions) - if self._pixel_shuffle or self._conv_merging: - required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2 - - rem_h = target_patch_height % required_divisor - if rem_h != 0: - inc_h = required_divisor - rem_h - if ( - target_patch_height + inc_h - ) * target_patch_width <= current_num_tokens_available: - target_patch_height += inc_h - else: - target_patch_height = max( - required_divisor, target_patch_height - rem_h - ) - - rem_w = target_patch_width % required_divisor - if rem_w != 0: - inc_w = required_divisor - rem_w - if ( - target_patch_height * (target_patch_width + inc_w) - <= current_num_tokens_available - ): - target_patch_width += inc_w - else: - target_patch_width = max( - required_divisor, target_patch_width - rem_w - ) - - if ( - data_augment - and self._apply_data_augment - and random.random() < tiling_augment_prob - ): - target_patch_width, target_patch_height = self.augment_resolution( - target_patch_width, target_patch_height, current_num_tokens_available - ) - - assert isinstance(media, Image.Image), ( - "Dynamic resolution is only supported for image media" - ) - - # Calculate embeddings for the main dynamic resolution image - num_embeddings = self._get_num_embeddings( - target_patch_width * self._patch_size, - target_patch_height * self._patch_size, - ) - - token_count = target_patch_width * target_patch_height - - # Add thumbnail embeddings if enabled and image area is below threshold - num_tiles = 1 # Base dynamic resolution image - if self._use_thumbnail: - # Calculate areas - resized_area = (target_patch_width * self._patch_size) * ( - target_patch_height * self._patch_size - ) - thumbnail_area = self._thumbnail_size * self._thumbnail_size - area_ratio = resized_area / thumbnail_area - - # Only add thumbnail if resized image area is less than threshold % of - # thumbnail area - if area_ratio < self._thumbnail_area_threshold: - num_tiles += 1 # Add 1 for thumbnail - # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) - num_embeddings += self._get_num_embeddings( - self._thumbnail_size, self._thumbnail_size - ) - token_count += ( - self._thumbnail_size - // self._patch_size - * self._thumbnail_size - // self._patch_size - ) - - return DynamicResolutionParams( - media=media, - num_tiles=num_tiles, - num_embeddings=num_embeddings, - patch_size=(target_patch_width, target_patch_height), - ), token_count - - def augment_resolution( - self, - target_patch_width: int, - target_patch_height: int, - current_num_tokens_available: int, - ) -> tuple[int, int]: - min_num_patch_one_side = 32 - - if random.random() < 0.5: - # Minus one - if ( - target_patch_width <= min_num_patch_one_side - and target_patch_height <= min_num_patch_one_side - ): - return target_patch_width, target_patch_height - elif target_patch_width <= min_num_patch_one_side: - return target_patch_width, target_patch_height - min_num_patch_one_side - elif target_patch_height <= min_num_patch_one_side: - return target_patch_width - min_num_patch_one_side, target_patch_height - else: - if random.random() < 0.5: - return ( - target_patch_width - min_num_patch_one_side, - target_patch_height, - ) - else: - return ( - target_patch_width, - target_patch_height - min_num_patch_one_side, - ) - else: - # Plus one - if target_patch_width * target_patch_height < current_num_tokens_available: - if random.random() < 0.5: - return ( - target_patch_width + min_num_patch_one_side, - target_patch_height, - ) - else: - return ( - target_patch_width, - target_patch_height + min_num_patch_one_side, - ) - return target_patch_width, target_patch_height - - def compute_params( - self, - media_list: list[Image.Image], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - data_augment: bool = False, - **kwargs, - ) -> list[DynamicResolutionParams]: - """Compute parameters for all media with iterative token budgeting. - - Args: - media_list: List of media items to process - num_tokens_available: Total number of tokens available across all media - max_num_tiles: Maximum number of tiles (unused in this implementation) - data_augment: Whether to apply data augmentation to the image. Defaults to - False. - Returns: - List of ImageTilingParams for each media item - """ - num_tokens_available = ( - num_tokens_available - * (4 if self._pixel_shuffle else 1) - * (4 if self._conv_merging else 1) - ) - # When the number of available token is too small, allow self._min_num_patches - # per media and let the sample be truncated. - num_tokens_available = max( - num_tokens_available, self._min_num_patches * len(media_list) - ) - - # Clip the number of tokens available per media to be between min and max - # patches. - num_tokens_available_per_media = [ - max(min(num_tokens_available, self._max_num_patches), self._min_num_patches) - for _ in range(len(media_list)) - ] - - # In theory this could be a while True loop, but in case the process_media - # method slightly - # changes, I want to make sure we don't get stuck in an infinite loop. - for _ in range(10): - # Step 1: Process each media with current token budget - params = [] - token_counts = [] - - for media, tokens_for_media in zip( - media_list, num_tokens_available_per_media - ): - param, token_count = self.process_media( - media, tokens_for_media, data_augment=data_augment - ) - params.append(param) - token_counts.append(token_count) - - # Step 2: Check if total tokens is within budget - total_tokens = sum(token_counts) - - if total_tokens <= num_tokens_available: - # We're within budget, return the params - return params - - # Step 3: We're over budget, need to scale down - # Calculate scaling factor to get under budget - scaling_factor = num_tokens_available / total_tokens - - # Recalculate token budgets for each media based on scaling - # Each media gets a proportional share of the total budget - scaled_down_num_tokens_available_per_media = [ - max(self._min_num_patches, int(token_count * scaling_factor)) - for token_count in token_counts - ] - scaled_down = any( - [ - scaled_down_num_tokens_available_per_media[i] - < num_tokens_available_per_media[i] - for i in range(len(num_tokens_available_per_media)) - ] - ) - # If there was not scaling down, we're stuck just use min_num_patches per - # media, else try with the scaled down num_tokens_available_per_media. - if not scaled_down: - num_tokens_available_per_media = [self._min_num_patches] * len( - media_list - ) - else: - num_tokens_available_per_media = ( - scaled_down_num_tokens_available_per_media - ) - return params - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: - imgs_sizes = torch.tensor( - [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 - ) - - def rearrange_img(x): - py = x.shape[-2] // self._patch_size - px = x.shape[-1] // self._patch_size - x = einops.rearrange( - x, - "c (py yy) (px xx) -> (py px) (c yy xx)", - py=py, - yy=self._patch_size, - px=px, - xx=self._patch_size, - ) - return x - - if len(images) > 0: - imgs = [rearrange_img(img) for img in images] - - current_length = 0 - max_length = 0 - vision_cu_lengths = [0] - for img in imgs: - if max_length < img.shape[0]: - max_length = img.shape[0] - current_length += img.shape[0] - vision_cu_lengths.append(current_length) - - vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) - vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) - - return ( - torch.cat(imgs, dim=0).unsqueeze(0), - imgs_sizes, - vision_cu_lengths, - vision_max_lengths, - ) - else: - return ( - torch.tensor([[0]], dtype=torch.float32), - torch.tensor([[0, 0]], dtype=torch.int32), - None, - None, - ) - - def __str__(self): - return f"DynamicResolutionImageTransform(\ - vision_model_type={self._vision_model_type}, \ - min_num_patches={self._min_num_patches}, \ - patch_size={self._patch_size}, \ - pixel_shuffle={self._pixel_shuffle}, \ - conv_merging={self._conv_merging}, \ - use_thumbnail={self._use_thumbnail}, \ - thumbnail_size={self._thumbnail_size}, \ - thumbnail_area_threshold={self._thumbnail_area_threshold})" - - -image_tiling_strategy = DynamicResolutionImageTilingStrategy( - vision_model_type="radio", - min_num_patches=4, - patch_size=16, - get_num_embeddings=lambda x, y: x * y * 2, - max_num_patches=64, -) +# TODO(nhaber): get 2048 from config +# TODO(nhaber): does use_thumbnail=True work? IMG_START = "" @@ -753,7 +261,12 @@ def video_to_pixel_values( return torch.stack(frames_tensors) -def input_conditioner(x, norm_mean, norm_std): +def input_conditioner( + x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.Tensor +) -> torch.Tensor: + assert isinstance(x, torch.Tensor), "x must be a tensor" + assert isinstance(norm_mean, torch.Tensor), "norm_mean must be a tensor" + assert isinstance(norm_std, torch.Tensor), "norm_std must be a tensor" return (x - norm_mean) / norm_std @@ -792,15 +305,20 @@ class BaseNanoNemotronVLProcessor(ABC): self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES image_size: int = config.force_image_size - patch_size: int = config.patch_size + self.patch_size: int = getattr(config, "patch_size", 16) + self.downsample_ratio: float = self.config.downsample_ratio - self.num_image_token = int( - (image_size // patch_size) ** 2 * (config.downsample_ratio**2) - ) self.image_size = image_size self.use_thumbnail: bool = config.use_thumbnail - self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1) - self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1) + self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1) + self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1) + + def num_image_token(self, *, image_width: int, image_height: int) -> int: + image_size = math.sqrt(image_width * image_height) + num_tokens = int( + (image_size // self.patch_size) ** 2 * (self.downsample_ratio**2) + ) + return num_tokens @property @abstractmethod @@ -832,10 +350,13 @@ class BaseNanoNemotronVLProcessor(ABC): use_thumbnail=self.use_thumbnail, ) - return num_patches * self.num_image_token + return num_patches * self.num_image_token( + image_width=image_width, image_height=image_height + ) def _images_to_pixel_values_lst( self, + text: list[str], images: list[Image.Image], max_num_tiles: int, ) -> list[torch.Tensor]: @@ -859,7 +380,9 @@ class BaseNanoNemotronVLProcessor(ABC): if len(images) == 0: image_inputs = {} else: - pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles) + pixel_values_lst = self._images_to_pixel_values_lst( + text=text, images=images, max_num_tiles=max_num_tiles + ) image_inputs = { "pixel_values_flat": input_conditioner( torch.cat(pixel_values_lst), self.norm_mean, self.norm_std @@ -881,7 +404,10 @@ class BaseNanoNemotronVLProcessor(ABC): for i, pixel_values in enumerate(pixel_values_lst): num_patches = pixel_values.shape[0] - feature_size = num_patches * self.num_image_token + feature_size = num_patches * self.num_image_token( + image_width=pixel_values.shape[1], + image_height=pixel_values.shape[2], + ) image_repl = self.get_image_repl(feature_size, num_patches) parts[i] = parts[i].replace("", image_repl.full) text = ["".join(parts)] @@ -894,6 +420,7 @@ class BaseNanoNemotronVLProcessor(ABC): input_item = [input_item] return input_item + @abstractmethod def __call__( self, text: str | list[str] | None = None, @@ -901,26 +428,487 @@ class BaseNanoNemotronVLProcessor(ABC): return_tensors: str | TensorType | None = None, max_num_tiles: int | None = None, ) -> BatchFeature: - # Use default if not provided - if max_num_tiles is None: - max_num_tiles = self.max_num_tiles + raise NotImplementedError - text, images = [self._make_batch_input(x) for x in (text, images)] - text, image_inputs = self._preprocess_image( - text=text, - images=images, - max_num_tiles=max_num_tiles, +@dataclass +class DynamicResolutionParams: + media: Image.Image + num_tiles: int + num_embeddings: int + patch_size: tuple[int, int] + + +class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): + CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] + CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] + + def __init__( + self, + config: PretrainedConfig, + tokenizer: TokenizerLike, + *args, + max_num_tiles: int | None = None, + min_num_patches: int = 4, + factor_max: float = 1.0, + pixel_shuffle: bool = True, + min_side: int | None = None, + conv_merging: bool = False, + use_thumbnail: bool = False, + thumbnail_size: int = 448, + thumbnail_area_threshold: float = 0.8, + apply_data_augment: bool = False, + **kwargs, + ) -> None: + super().__init__( + config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs + ) + self._min_num_patches = min_num_patches + self._factor_max = factor_max + self._pixel_shuffle = pixel_shuffle + self._min_side = min_side + self._conv_merging = conv_merging + self._use_thumbnail = use_thumbnail + self._thumbnail_size = thumbnail_size + self._thumbnail_area_threshold = thumbnail_area_threshold + self._transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)), + ] + ) + self._apply_data_augment = apply_data_augment + + self.norm_mean = torch.tensor(self.CLIP_PIXEL_MEAN).reshape(1, 3, 1, 1) + self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1) + self.downsample_ratio = 2 if pixel_shuffle else 1 + + def apply_params(self, params: DynamicResolutionParams) -> torch.Tensor: + resized_img = params.media.resize( + ( + params.patch_size[0] * self.patch_size, + params.patch_size[1] * self.patch_size, + ) + ) + # processed_images = [resized_img] + + # # Add thumbnail if enabled and image area is below threshold + # if self._use_thumbnail: + # # Calculate areas + # resized_area = resized_img.size[0] * resized_img.size[1] + # thumbnail_area = self._thumbnail_size * self._thumbnail_size + # area_ratio = resized_area / thumbnail_area + + # # Only add thumbnail if resized image area is less than threshold % of + # # thumbnail area + # if area_ratio < self._thumbnail_area_threshold: + # thumbnail_img = params.media.resize( + # (self._thumbnail_size, self._thumbnail_size) + # ) + # processed_images.append(thumbnail_img) + + return self._transform(resized_img) + + def process_media( + self, + media: Image.Image, + num_tokens_available: int, + data_augment: bool = False, + tiling_augment_prob: float = 0.4, + ) -> DynamicResolutionParams: + """Process a single media item and return its parameters. + Args: + media: The media item to process + num_tokens_available: Number of tokens available for this media + data_augment: Whether to apply data augmentation to the image. Defaults to + False. + Returns: + DynamicResolutionParams for the media + """ + current_num_tokens_available = num_tokens_available + assert isinstance(media, Image.Image), ( + "Dynamic resolution is only supported for image media" + ) + orig_width, orig_height = media.width, media.height + + closest_patch_height = round(orig_height / self.patch_size + 0.5) + closest_patch_width = round(orig_width / self.patch_size + 0.5) + patches = closest_patch_height * closest_patch_width + + factor = min( + math.sqrt(current_num_tokens_available / patches), self._factor_max + ) + target_patch_height = math.floor(factor * closest_patch_height) + target_patch_width = math.floor(factor * closest_patch_width) + + # We only consider self._min_num_patches if it is greater than + # current_num_tokens_available. + if ( + current_num_tokens_available > self._min_num_patches + and target_patch_height * target_patch_width < self._min_num_patches + ): + up_factor = math.sqrt( + self._min_num_patches / (target_patch_height * target_patch_width) + ) + target_patch_height = math.ceil(up_factor * target_patch_height) + target_patch_width = math.ceil(up_factor * target_patch_width) + + if ( + self._min_side is not None + and min(target_patch_width, target_patch_height) * self.patch_size + < self._min_side + ): + if target_patch_width <= target_patch_height: + up_factor = self._min_side / (target_patch_width * self.patch_size) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) + + if new_patch_height * new_patch_width > current_num_tokens_available: + # If only one side can be min_side, make as big as possible at + # native aspect ratio while staying below max_patches + if ( + max(current_num_tokens_available // new_patch_width, 1) + * self.patch_size + < self._min_side + ): + up_factor = math.sqrt( + current_num_tokens_available + / (target_patch_height * target_patch_width) + ) + target_patch_height = math.floor( + up_factor * target_patch_height + ) + target_patch_width = math.floor(up_factor * target_patch_width) + target_patch_width = new_patch_width + target_patch_height = max( + current_num_tokens_available // new_patch_width, 1 + ) + else: + target_patch_height = new_patch_height + target_patch_width = new_patch_width + else: + up_factor = self._min_side / (target_patch_height * self.patch_size) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) + + if new_patch_height * new_patch_width > current_num_tokens_available: + # If only one side can be min_side, make as big as possible at + # native aspect ratio while staying below max_patches + if ( + max(current_num_tokens_available // new_patch_height, 1) + * self.patch_size + < self._min_side + ): + up_factor = math.sqrt( + current_num_tokens_available + / (target_patch_height * target_patch_width) + ) + target_patch_height = math.floor( + up_factor * target_patch_height + ) + target_patch_width = math.floor(up_factor * target_patch_width) + else: + target_patch_height = new_patch_height + target_patch_width = max( + current_num_tokens_available // new_patch_height, 1 + ) + else: + target_patch_height = new_patch_height + target_patch_width = new_patch_width + + # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) + # or by 4 when BOTH are enabled (two successive 2x reductions) + if self._pixel_shuffle or self._conv_merging: + required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2 + + rem_h = target_patch_height % required_divisor + if rem_h != 0: + inc_h = required_divisor - rem_h + if ( + target_patch_height + inc_h + ) * target_patch_width <= current_num_tokens_available: + target_patch_height += inc_h + else: + target_patch_height = max( + required_divisor, target_patch_height - rem_h + ) + + rem_w = target_patch_width % required_divisor + if rem_w != 0: + inc_w = required_divisor - rem_w + if ( + target_patch_height * (target_patch_width + inc_w) + <= current_num_tokens_available + ): + target_patch_width += inc_w + else: + target_patch_width = max( + required_divisor, target_patch_width - rem_w + ) + + if ( + data_augment + and self._apply_data_augment + and random.random() < tiling_augment_prob + ): + target_patch_width, target_patch_height = self.augment_resolution( + target_patch_width, target_patch_height, current_num_tokens_available + ) + + assert isinstance(media, Image.Image), ( + "Dynamic resolution is only supported for image media" ) - text_inputs = self.tokenizer(text, add_special_tokens=False) + # Calculate embeddings for the main dynamic resolution image + num_embeddings = self.num_image_token( + image_width=target_patch_width, image_height=target_patch_height + ) - combined_outputs = {**text_inputs, **image_inputs} + token_count = target_patch_width * target_patch_height - return BatchFeature(combined_outputs, tensor_type=return_tensors) + # Add thumbnail embeddings if enabled and image area is below threshold + num_tiles = 1 # Base dynamic resolution image + if self._use_thumbnail: + # Calculate areas + resized_area = (target_patch_width * self.patch_size) * ( + target_patch_height * self.patch_size + ) + thumbnail_area = self._thumbnail_size * self._thumbnail_size + area_ratio = resized_area / thumbnail_area + + # Only add thumbnail if resized image area is less than threshold % of + # thumbnail area + if area_ratio < self._thumbnail_area_threshold: + num_tiles += 1 # Add 1 for thumbnail + # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) + num_embeddings += self.num_image_token( + image_width=self._thumbnail_size, image_height=self._thumbnail_size + ) + token_count += ( + self._thumbnail_size + // self.patch_size + * self._thumbnail_size + // self.patch_size + ) + + return DynamicResolutionParams( + media=media, + num_tiles=num_tiles, + num_embeddings=num_embeddings, + patch_size=(target_patch_width, target_patch_height), + ), token_count + + def augment_resolution( + self, + target_patch_width: int, + target_patch_height: int, + current_num_tokens_available: int, + ) -> tuple[int, int]: + min_num_patch_one_side = 32 + + if random.random() < 0.5: + # Minus one + if ( + target_patch_width <= min_num_patch_one_side + and target_patch_height <= min_num_patch_one_side + ): + return target_patch_width, target_patch_height + elif target_patch_width <= min_num_patch_one_side: + return target_patch_width, target_patch_height - min_num_patch_one_side + elif target_patch_height <= min_num_patch_one_side: + return target_patch_width - min_num_patch_one_side, target_patch_height + else: + if random.random() < 0.5: + return ( + target_patch_width - min_num_patch_one_side, + target_patch_height, + ) + else: + return ( + target_patch_width, + target_patch_height - min_num_patch_one_side, + ) + else: + # Plus one + if target_patch_width * target_patch_height < current_num_tokens_available: + if random.random() < 0.5: + return ( + target_patch_width + min_num_patch_one_side, + target_patch_height, + ) + else: + return ( + target_patch_width, + target_patch_height + min_num_patch_one_side, + ) + return target_patch_width, target_patch_height + + def compute_params( + self, + media_list: list[Image.Image], + num_tokens_available: int | None = None, + data_augment: bool = False, + ) -> list[DynamicResolutionParams]: + """Compute parameters for all media with iterative token budgeting. + + Args: + media_list: List of media items to process + num_tokens_available: Total number of tokens available across all media + data_augment: Whether to apply data augmentation to the image. Defaults to + False. + Returns: + List of ImageTilingParams for each media item + """ + num_tokens_available = ( + num_tokens_available + * (4 if self._pixel_shuffle else 1) + * (4 if self._conv_merging else 1) + ) + # When the number of available token is too small, allow self._min_num_patches + # per media and let the sample be truncated. + num_tokens_available = max( + num_tokens_available, self._min_num_patches * len(media_list) + ) + + # Clip the number of tokens available per media to be between min and max + # patches. + num_tokens_available_per_media = [ + max(num_tokens_available, self._min_num_patches) + for _ in range(len(media_list)) + ] + + # In theory this could be a while True loop, but in case the process_media + # method slightly + # changes, I want to make sure we don't get stuck in an infinite loop. + for _ in range(10): + # Step 1: Process each media with current token budget + params = [] + token_counts = [] + + for media, tokens_for_media in zip( + media_list, num_tokens_available_per_media + ): + param, token_count = self.process_media( + media, tokens_for_media, data_augment=data_augment + ) + params.append(param) + token_counts.append(token_count) + + # Step 2: Check if total tokens is within budget + total_tokens = sum(token_counts) + + if total_tokens <= num_tokens_available: + # We're within budget, return the params + return params + + # Step 3: We're over budget, need to scale down + # Calculate scaling factor to get under budget + scaling_factor = num_tokens_available / total_tokens + + # Recalculate token budgets for each media based on scaling + # Each media gets a proportional share of the total budget + scaled_down_num_tokens_available_per_media = [ + max(self._min_num_patches, int(token_count * scaling_factor)) + for token_count in token_counts + ] + scaled_down = any( + [ + scaled_down_num_tokens_available_per_media[i] + < num_tokens_available_per_media[i] + for i in range(len(num_tokens_available_per_media)) + ] + ) + # If there was not scaling down, we're stuck just use min_num_patches per + # media, else try with the scaled down num_tokens_available_per_media. + if not scaled_down: + num_tokens_available_per_media = [self._min_num_patches] * len( + media_list + ) + else: + num_tokens_available_per_media = ( + scaled_down_num_tokens_available_per_media + ) + return params + + def stack( + self, images: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: + imgs_sizes = torch.tensor( + [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 + ) + + def rearrange_img(x): + py = x.shape[-2] // self.patch_size + px = x.shape[-1] // self.patch_size + x = einops.rearrange( + x, + "c (py yy) (px xx) -> (py px) (c yy xx)", + py=py, + yy=self.patch_size, + px=px, + xx=self.patch_size, + ) + return x + + if len(images) > 0: + imgs = [rearrange_img(img) for img in images] + + current_length = 0 + max_length = 0 + vision_cu_lengths = [0] + for img in imgs: + if max_length < img.shape[0]: + max_length = img.shape[0] + current_length += img.shape[0] + vision_cu_lengths.append(current_length) + + vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) + vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) + + return ( + torch.cat(imgs, dim=0).unsqueeze(0), + imgs_sizes, + vision_cu_lengths, + vision_max_lengths, + ) + else: + return ( + torch.tensor([[0]], dtype=torch.float32), + torch.tensor([[0, 0]], dtype=torch.int32), + None, + None, + ) + + def _images_to_pixel_values_lst( + self, + text: list[str], + images: list[Image.Image], + max_num_tiles: int, + ) -> list[torch.Tensor]: + num_tokens_available = 2048 - len(text) - 4 + params_per_image = self.compute_params( + images, num_tokens_available=num_tokens_available + ) + images = [] + for param in params_per_image: + t = self.apply_params(param) + if t.ndim == 3: + t = t.unsqueeze(0) + images.append(t) + return images + + def __str__(self): + return f"DynamicResolutionImageTransform(\ + min_num_patches={self._min_num_patches}, \ + patch_size={self.patch_size}, \ + pixel_shuffle={self._pixel_shuffle}, \ + conv_merging={self._conv_merging}, \ + use_thumbnail={self._use_thumbnail}, \ + thumbnail_size={self._thumbnail_size}, \ + thumbnail_area_threshold={self._thumbnail_area_threshold})" -class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): +class NanoNemotronVLProcessor(DynamicResolutionImageTiler): """ HF Processor with extended video processing logic. Code for video processing is adapted from video example: @@ -1312,7 +1300,9 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): processor = self.get_hf_processor() # we get the CustomProcessor here max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token + max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token( + image_width=256, image_height=256 + ) # TODO(nhaber): get 256 dynamically max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) @@ -1457,7 +1447,9 @@ class NanoNemotronVLMultiModalProcessor( video_num_patches = [] def get_video_replacement_internvl(item_idx: int): - feature_size = hf_processor.num_image_token + feature_size = hf_processor.num_image_token( + image_width=256, image_height=256 + ) # TODO(nhaber): get 256 dynamically video, metadata = mm_items["video"][item_idx] num_patches = video_num_patches[item_idx] if num_patches is not None: @@ -1633,9 +1625,6 @@ class NemotronH_Nano_VL_V2( patch_size = config.patch_size self.patch_size = patch_size self.template = config.template - self.num_image_token = int( - (image_size // patch_size) ** 2 * (config.downsample_ratio**2) - ) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version self.image_tag_type = config.image_tag_type @@ -2153,33 +2142,6 @@ class NemotronH_Nano_VL_V2( if save_to_file and sys.stdout != original_stdout: sys.stdout = original_stdout - def get_model_info(self): - """ - Get basic model information as a dictionary. - """ - total_params = sum(p.numel() for p in self.parameters()) - - component_info = {} - for name, param in self.named_parameters(): - component = name.split(".")[0] - if component not in component_info: - component_info[component] = {"params": 0, "size": 0} - component_info[component]["params"] += 1 - component_info[component]["size"] += param.numel() - - return { - "model_name": "NemotronH_Nano_VL_V2", - "total_parameters": total_params, - "memory_estimate_mb": total_params * 2 / (1024**2), # bfloat16 - "components": component_info, - "config": { - "image_size": getattr(self.config, "force_image_size", None), - "patch_size": getattr(self.config, "patch_size", None), - "num_image_token": self.num_image_token, - "downsample_ratio": self.downsample_ratio, - }, - } - def get_vit_model_from_radio_config(self, hf_config): hf_config_vision = hf_config.vision_config model_name = hf_config_vision.args.get("model") From 52e5e55a19436300877a36c518ac390088973c01 Mon Sep 17 00:00:00 2001 From: Netanel Haber Date: Mon, 22 Dec 2025 03:41:37 -0800 Subject: [PATCH 07/10] it runs at least :shrug: --- .../model_executor/models/nano_nemotron_vl.py | 308 ++++++++++-------- 1 file changed, 171 insertions(+), 137 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 3bac36744ef5e..ad5a57c511fe8 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -10,7 +10,6 @@ import copy import math import random -import warnings from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass @@ -24,6 +23,7 @@ import torch.nn as nn import torchvision.transforms as T from PIL import Image from transformers import BatchFeature, PretrainedConfig, TensorType +from typing_extensions import assert_never from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions @@ -62,7 +62,6 @@ from vllm.multimodal.inputs import ( from vllm.multimodal.parse import ( ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems, MultiModalDataParser, ) @@ -91,6 +90,7 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely # TODO(nhaber): get 2048 from config # TODO(nhaber): does use_thumbnail=True work? +# TODO(nhaber): mixing images and videos will mess up the "text_prompt_length" calculation. IMG_START = "" @@ -102,6 +102,46 @@ IMG_CONTEXT = "" DEFAULT_NUM_TILES = 12 +@dataclass(kw_only=True, frozen=True) +class Dims: + height: int + width: int + + +CONV_MERGING = False # This is assumed to be False for now +PIXEL_SHUFFLE = True # This is assumed to be True for now +REDUCTION_FACTOR = 2 ** (PIXEL_SHUFFLE + CONV_MERGING) + +def width_and_height_for_max_num_tokens_available( + *, + target_num_tokens_post_shuffle: int, + patch_size: int, +) -> Dims: + """ + TODO(nhaber): optimize this so it squeezes closer to target number of tokens. + Calculate image dimensions that produce approximately `target` tokens after + pixel_shuffle. + + With pixel_shuffle enabled, each 2x2 patch grid becomes 1 token, so we + need 4*B patches to get B tokens. + + Examples: + >>> dims = width_and_height_for_max_num_tokens_available(B=8192, patch_size=16) + >>> assert dims.width, dims.height == (2880, 2880) + >>> assert ((dims.width // 16) * (dims.height // 16) // 4) == 8100 # tokens after shuffle + """ + side_pixels = math.isqrt(target_num_tokens_post_shuffle) * REDUCTION_FACTOR * patch_size + assert isinstance(side_pixels, int) and side_pixels % patch_size == 0 + return Dims(width=side_pixels, height=side_pixels) + +@dataclass +class DynamicResolutionParams: + media: Image.Image + num_tiles: int + num_embeddings: int + patch_size: tuple[int, int] + + class NanoNemotronVLImagePixelInputs(TensorSchema): """ Dimensions: @@ -313,10 +353,10 @@ class BaseNanoNemotronVLProcessor(ABC): self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1) self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1) - def num_image_token(self, *, image_width: int, image_height: int) -> int: - image_size = math.sqrt(image_width * image_height) + def num_image_token_per_tile(self, *, tile_width: int, tile_height: int) -> int: + tile_size = math.sqrt(tile_width * tile_height) num_tokens = int( - (image_size // self.patch_size) ** 2 * (self.downsample_ratio**2) + (tile_size // self.patch_size) ** 2 * (self.downsample_ratio**2) ) return num_tokens @@ -342,7 +382,7 @@ class BaseNanoNemotronVLProcessor(ABC): ) -> int: target_ratios = get_internvl_target_ratios(1, max_num_tiles) - num_patches, _, _ = calculate_internvl_targets( + num_tiles, _, _ = calculate_internvl_targets( orig_width=image_width, orig_height=image_height, target_ratios=target_ratios, @@ -350,16 +390,16 @@ class BaseNanoNemotronVLProcessor(ABC): use_thumbnail=self.use_thumbnail, ) - return num_patches * self.num_image_token( - image_width=image_width, image_height=image_height + return num_tiles * self.num_image_token_per_tile( + tile_width=image_width, tile_height=image_height ) def _images_to_pixel_values_lst( self, - text: list[str], + text_prompt_length: int, images: list[Image.Image], max_num_tiles: int, - ) -> list[torch.Tensor]: + ) -> tuple[list[torch.Tensor], list[int]]: return [ image_to_pixel_values( image, @@ -380,8 +420,19 @@ class BaseNanoNemotronVLProcessor(ABC): if len(images) == 0: image_inputs = {} else: - pixel_values_lst = self._images_to_pixel_values_lst( - text=text, images=images, max_num_tiles=max_num_tiles + assert len(text) == 1, ( + "hf_processor is called on the output of get_dummy_text, " + "which should be a single string" + ) + text_prompt_length = len( + self.tokenizer( + text[0].replace("", ""), add_special_tokens=False + )["input_ids"] + ) + pixel_values_lst, token_counts = self._images_to_pixel_values_lst( + text_prompt_length=text_prompt_length, + images=images, + max_num_tiles=max_num_tiles, ) image_inputs = { "pixel_values_flat": input_conditioner( @@ -402,12 +453,10 @@ class BaseNanoNemotronVLProcessor(ABC): "same as the number of images" ) - for i, pixel_values in enumerate(pixel_values_lst): + for i, (pixel_values, feature_size) in enumerate( + zip(pixel_values_lst, token_counts, strict=True) + ): num_patches = pixel_values.shape[0] - feature_size = num_patches * self.num_image_token( - image_width=pixel_values.shape[1], - image_height=pixel_values.shape[2], - ) image_repl = self.get_image_repl(feature_size, num_patches) parts[i] = parts[i].replace("", image_repl.full) text = ["".join(parts)] @@ -431,14 +480,6 @@ class BaseNanoNemotronVLProcessor(ABC): raise NotImplementedError -@dataclass -class DynamicResolutionParams: - media: Image.Image - num_tiles: int - num_embeddings: int - patch_size: tuple[int, int] - - class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] @@ -448,6 +489,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): config: PretrainedConfig, tokenizer: TokenizerLike, *args, + max_model_len: int, max_num_tiles: int | None = None, min_num_patches: int = 4, factor_max: float = 1.0, @@ -463,6 +505,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): super().__init__( config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs ) + self.max_model_len = max_model_len + self._min_num_patches = min_num_patches self._factor_max = factor_max self._pixel_shuffle = pixel_shuffle @@ -483,6 +527,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1) self.downsample_ratio = 2 if pixel_shuffle else 1 + feature_size_cache: dict[ + Image.Image, int + ] = {} # TODO(nhaber): Find a less silly way of doing this... Why can't this be a class variable? + def apply_params(self, params: DynamicResolutionParams) -> torch.Tensor: resized_img = params.media.resize( ( @@ -515,7 +563,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): num_tokens_available: int, data_augment: bool = False, tiling_augment_prob: float = 0.4, - ) -> DynamicResolutionParams: + ) -> tuple[DynamicResolutionParams, int]: """Process a single media item and return its parameters. Args: media: The media item to process @@ -531,8 +579,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): ) orig_width, orig_height = media.width, media.height - closest_patch_height = round(orig_height / self.patch_size + 0.5) - closest_patch_width = round(orig_width / self.patch_size + 0.5) + closest_patch_height = math.ceil( + orig_height / self.patch_size + ) # TODO(nhaber): Ask Tyler - the previous round + 0.5 code is dangerous [banker's rounding], no? If we flip this back to the round, the max_wh_fill_budget needs to do -1 for each of w;h to be safe + closest_patch_width = math.ceil(orig_width / self.patch_size) patches = closest_patch_height * closest_patch_width factor = min( @@ -660,8 +710,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): ) # Calculate embeddings for the main dynamic resolution image - num_embeddings = self.num_image_token( - image_width=target_patch_width, image_height=target_patch_height + num_embeddings_per_tile = self.num_image_token_per_tile( + tile_width=target_patch_width, tile_height=target_patch_height ) token_count = target_patch_width * target_patch_height @@ -681,8 +731,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): if area_ratio < self._thumbnail_area_threshold: num_tiles += 1 # Add 1 for thumbnail # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) - num_embeddings += self.num_image_token( - image_width=self._thumbnail_size, image_height=self._thumbnail_size + num_embeddings += self.num_image_token_per_tile( + tile_width=self._thumbnail_size, tile_height=self._thumbnail_size ) token_count += ( self._thumbnail_size @@ -694,7 +744,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): return DynamicResolutionParams( media=media, num_tiles=num_tiles, - num_embeddings=num_embeddings, + num_embeddings=num_embeddings_per_tile, patch_size=(target_patch_width, target_patch_height), ), token_count @@ -748,7 +798,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): media_list: list[Image.Image], num_tokens_available: int | None = None, data_augment: bool = False, - ) -> list[DynamicResolutionParams]: + ) -> tuple[list[DynamicResolutionParams], list[int]]: """Compute parameters for all media with iterative token budgeting. Args: @@ -782,8 +832,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): # changes, I want to make sure we don't get stuck in an infinite loop. for _ in range(10): # Step 1: Process each media with current token budget - params = [] - token_counts = [] + params: list[DynamicResolutionParams] = [] + token_counts: list[int] = [] for media, tokens_for_media in zip( media_list, num_tokens_available_per_media @@ -799,7 +849,12 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): if total_tokens <= num_tokens_available: # We're within budget, return the params - return params + # Convert from patch count to actual token count after downsampling + divisor = (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1) + adjusted_token_counts = [tc // divisor for tc in token_counts] + for param, feature_size in zip(params, adjusted_token_counts, strict=True): + self.feature_size_cache[id(param.media)] = feature_size + return params, adjusted_token_counts # Step 3: We're over budget, need to scale down # Calculate scaling factor to get under budget @@ -828,7 +883,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): num_tokens_available_per_media = ( scaled_down_num_tokens_available_per_media ) - return params + assert_never(num_tokens_available_per_media) def stack( self, images: list[torch.Tensor] @@ -879,15 +934,18 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): None, ) + def max_num_tokens_available(self, text_prompt_length: int) -> int: + return self.max_model_len - text_prompt_length - 4 + def _images_to_pixel_values_lst( self, - text: list[str], + text_prompt_length: int, images: list[Image.Image], max_num_tiles: int, - ) -> list[torch.Tensor]: - num_tokens_available = 2048 - len(text) - 4 - params_per_image = self.compute_params( - images, num_tokens_available=num_tokens_available + ) -> tuple[list[torch.Tensor], list[int]]: + num_tokens_available = self.max_num_tokens_available(text_prompt_length) + params_per_image, feature_sizes = self.compute_params( + images, num_tokens_available ) images = [] for param in params_per_image: @@ -895,17 +953,12 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): if t.ndim == 3: t = t.unsqueeze(0) images.append(t) - return images + return images, feature_sizes - def __str__(self): - return f"DynamicResolutionImageTransform(\ - min_num_patches={self._min_num_patches}, \ - patch_size={self.patch_size}, \ - pixel_shuffle={self._pixel_shuffle}, \ - conv_merging={self._conv_merging}, \ - use_thumbnail={self._use_thumbnail}, \ - thumbnail_size={self._thumbnail_size}, \ - thumbnail_area_threshold={self._thumbnail_area_threshold})" + def get_cached_feature_size(self, image: Image.Image) -> int: + feature_size = self.feature_size_cache[id(image)] + del self.feature_size_cache[id(image)] + return feature_size class NanoNemotronVLProcessor(DynamicResolutionImageTiler): @@ -920,6 +973,7 @@ class NanoNemotronVLProcessor(DynamicResolutionImageTiler): config: PretrainedConfig, tokenizer: TokenizerLike, *, + max_model_len: int, max_num_tiles: int | None = None, min_dynamic_patch: int | None = None, max_dynamic_patch: int | None = None, @@ -930,6 +984,7 @@ class NanoNemotronVLProcessor(DynamicResolutionImageTiler): super().__init__( config=config, tokenizer=tokenizer, + max_model_len=max_model_len, max_num_tiles=max_num_tiles, min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, @@ -1205,7 +1260,7 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): def get_hf_processor( self, **kwargs: object, - ) -> BaseNanoNemotronVLProcessor: + ) -> DynamicResolutionImageTiler: raise NotImplementedError def get_supported_mm_limits(self) -> Mapping[str, int | None]: @@ -1228,31 +1283,6 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): max_num_tiles=max_num_tiles, ) - def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize: - processor = self.get_hf_processor() - - base_size = processor.image_size - target_ratios = get_internvl_target_ratios(1, max_num_tiles) - - largest_feature_size, largest_feature_pinpoint = 0, None - for wr, hr in target_ratios: - width, height = base_size * wr, base_size * hr - - feat_size = self.get_num_image_tokens( - image_width=width, - image_height=height, - max_num_tiles=max_num_tiles, - processor=processor, - ) - if feat_size > largest_feature_size: - largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, height=height) - - if largest_feature_size == 0 or largest_feature_pinpoint is None: - raise ValueError("Cannot have a largest feature size of 0!") - - return largest_feature_pinpoint - def get_max_image_tokens(self) -> int: processor = self.get_hf_processor() # Use default max_num_tiles for max tokens calculation @@ -1277,7 +1307,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): @property def supports_video(self): - return self.get_hf_processor().supports_video + return False # TODO(nhaber): add video support def get_supported_mm_limits(self): video_limit = {"video": None} if self.supports_video else {} @@ -1300,8 +1330,10 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): processor = self.get_hf_processor() # we get the CustomProcessor here max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token( - image_width=256, image_height=256 + max_total_frames = ( + seq_len - max_image_tokens + ) // processor.num_image_token_per_tile( + tile_width=256, tile_height=256 ) # TODO(nhaber): get 256 dynamically max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) @@ -1313,6 +1345,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): tokenizer=self.get_tokenizer(), video_token=self.get_video_token(), video_pruning_rate=self.get_video_pruning_rate(), + max_model_len=self.ctx.model_config.max_model_len, **kwargs, ) @@ -1362,17 +1395,8 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) else: - image_size = images.get_image_size(item_idx) - # Extract max_num_tiles from kwargs, default to 12 - max_num_tiles = hf_processor_mm_kwargs.get( - "max_num_tiles", hf_processor.max_num_tiles - ) - feature_size = self.info.get_num_image_tokens( - image_width=image_size.width, - image_height=image_size.height, - max_num_tiles=max_num_tiles, - processor=hf_processor, - ) + image = images.get(item_idx) + feature_size = hf_processor.get_cached_feature_size(image) num_patches = None local_image_num_patches = image_num_patches @@ -1447,8 +1471,8 @@ class NanoNemotronVLMultiModalProcessor( video_num_patches = [] def get_video_replacement_internvl(item_idx: int): - feature_size = hf_processor.num_image_token( - image_width=256, image_height=256 + feature_size = hf_processor.num_image_token_per_tile( + tile_width=256, tile_height=256 ) # TODO(nhaber): get 256 dynamically video, metadata = mm_items["video"][item_idx] num_patches = video_num_patches[item_idx] @@ -1510,19 +1534,20 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: - # Use default max_num_tiles for dummy data generation - max_num_tiles = 12 - target_width, target_height = self.info.get_image_size_with_most_features( - max_num_tiles - ) num_images = mm_counts.get("image", 0) + processor = self.info.get_hf_processor() + B = processor.max_num_tokens_available(text_prompt_length=num_images) + target_dims = width_and_height_for_max_num_tokens_available( + target_num_tokens_post_shuffle=B, + patch_size=processor.patch_size, + ) image_overrides = mm_options.get("image") if mm_options else None return { "image": self._get_dummy_images( - width=target_width, - height=target_height, + width=target_dims.width, + height=target_dims.height, num_images=num_images, overrides=image_overrides, ) @@ -1672,33 +1697,36 @@ class NemotronH_Nano_VL_V2( IMG_CONTEXT, add_special_tokens=False ) - def pixel_shuffle(self, x, scale_factor=0.5): - n, w, h, c = x.size() - # N, W, H, C --> N, W, H * scale, C // scale - x = x.view( - n, - w, - int(h * scale_factor), - int(c / scale_factor), - ) - # N, W, H * scale, C // scale --> N, H * scale, W, C // scale - x = x.permute(0, 2, 1, 3).contiguous() - # N, H * scale, W, C // scale --> - # N, H * scale, W * scale, C // (scale ** 2) - x = x.view( - n, - int(h * scale_factor), - int(w * scale_factor), - int(c / (scale_factor * scale_factor)), - ) - if self.ps_version == "v1": - warnings.warn( - "In ps_version 'v1', the height and width have not " - "been swapped back, which results in a transposed image.", - stacklevel=2, + def pixel_shuffle_dynamic_res(self, x, *, imgs_sizes): + scale_factor = self.downsample_ratio + patch_dim = self.patch_size + seq_lens = torch.prod(imgs_sizes // patch_dim, dim=-1) + splits = torch.split(x, seq_lens.tolist(), dim=-2) + out = [] + for i, sv in enumerate(splits): + h = imgs_sizes[i][0] // patch_dim + w = imgs_sizes[i][1] // patch_dim + sv = sv.reshape(sv.shape[0], h, w, -1) + + n, h, w, c = sv.size() + + sv = sv.view(n, h, int(w * scale_factor), int(c / scale_factor)) + sv = sv.permute(0, 2, 1, 3).contiguous() + sv = sv.view( + n, + int(w * scale_factor), + int(h * scale_factor), + int(c / (scale_factor * scale_factor)), ) - else: - x = x.permute(0, 2, 1, 3).contiguous() + + if self.ps_version == "v2": + sv = sv.permute(0, 2, 1, 3).contiguous() + + sv = sv.reshape(sv.shape[0], -1, sv.shape[-1]) + out.append(sv) + + x = torch.cat(out, dim=-2) + return x def extract_feature(self, pixel_values): @@ -1710,16 +1738,22 @@ class NemotronH_Nano_VL_V2( n = pixel_values.shape[0] vit_embeds_list = [] for i in range(0, n, micro_batch_size): - vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size]) + current = pixel_values[i : i + micro_batch_size] + vit_embeds = self.vision_model(current) vit_embeds = vit_embeds.to(dtype=torch.bfloat16) - h = w = int(vit_embeds.shape[1] ** 0.5) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle( - vit_embeds, scale_factor=self.downsample_ratio - ) - vit_embeds = vit_embeds.reshape( - vit_embeds.shape[0], -1, vit_embeds.shape[-1] - ) + + # pixel_shuffle_dynamic_res expects patches concatenated along dim=-2, + # but vision model outputs (batch, patches, hidden). Process each image + # individually to handle this correctly. + _, _, h, w = current.shape + shuffled_embeds = [] + for j in range(vit_embeds.shape[0]): + single_embed = vit_embeds[j : j + 1] # (1, patches, hidden) + single_shuffled = self.pixel_shuffle_dynamic_res( + single_embed, imgs_sizes=torch.tensor([(h, w)]) + ) + shuffled_embeds.append(single_shuffled) + vit_embeds = torch.cat(shuffled_embeds, dim=0) vit_embeds = self.mlp1(vit_embeds) vit_embeds_list.append(vit_embeds) From 2a7ea9ba37d5c70b9913f7eae5caeac2c44ffec8 Mon Sep 17 00:00:00 2001 From: Netanel Haber Date: Mon, 22 Dec 2025 06:27:32 -0800 Subject: [PATCH 08/10] add logs --- .../model_executor/models/nano_nemotron_vl.py | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index ad5a57c511fe8..cc9a839b87ef5 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -112,6 +112,13 @@ CONV_MERGING = False # This is assumed to be False for now PIXEL_SHUFFLE = True # This is assumed to be True for now REDUCTION_FACTOR = 2 ** (PIXEL_SHUFFLE + CONV_MERGING) +def num_image_token_per_tile(*, tile_dims: Dims, patch_size: int, downsample_ratio: int) -> int: + tile_size = math.sqrt(tile_dims.width * tile_dims.height) + num_tokens = int( + (tile_size // patch_size) ** 2 * (downsample_ratio**2) + ) + return num_tokens + def width_and_height_for_max_num_tokens_available( *, target_num_tokens_post_shuffle: int, @@ -129,6 +136,7 @@ def width_and_height_for_max_num_tokens_available( >>> dims = width_and_height_for_max_num_tokens_available(B=8192, patch_size=16) >>> assert dims.width, dims.height == (2880, 2880) >>> assert ((dims.width // 16) * (dims.height // 16) // 4) == 8100 # tokens after shuffle + >>> assert num_image_token_per_tile(tile_dims=dims, patch_size=16, downsample_ratio=2) == 8100 """ side_pixels = math.isqrt(target_num_tokens_post_shuffle) * REDUCTION_FACTOR * patch_size assert isinstance(side_pixels, int) and side_pixels % patch_size == 0 @@ -353,13 +361,6 @@ class BaseNanoNemotronVLProcessor(ABC): self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1) self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1) - def num_image_token_per_tile(self, *, tile_width: int, tile_height: int) -> int: - tile_size = math.sqrt(tile_width * tile_height) - num_tokens = int( - (tile_size // self.patch_size) ** 2 * (self.downsample_ratio**2) - ) - return num_tokens - @property @abstractmethod def image_token_id(self) -> int: @@ -390,8 +391,10 @@ class BaseNanoNemotronVLProcessor(ABC): use_thumbnail=self.use_thumbnail, ) - return num_tiles * self.num_image_token_per_tile( - tile_width=image_width, tile_height=image_height + return num_tiles * num_image_token_per_tile( + tile_dims=Dims(width=image_width, height=image_height), + patch_size=self.patch_size, + downsample_ratio=self.downsample_ratio ) def _images_to_pixel_values_lst( @@ -710,8 +713,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): ) # Calculate embeddings for the main dynamic resolution image - num_embeddings_per_tile = self.num_image_token_per_tile( - tile_width=target_patch_width, tile_height=target_patch_height + num_embeddings_per_tile = num_image_token_per_tile( + tile_dims=Dims(width=target_patch_width, height=target_patch_height), + patch_size=self.patch_size, + downsample_ratio=self.downsample_ratio ) token_count = target_patch_width * target_patch_height @@ -731,8 +736,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): if area_ratio < self._thumbnail_area_threshold: num_tiles += 1 # Add 1 for thumbnail # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) - num_embeddings += self.num_image_token_per_tile( - tile_width=self._thumbnail_size, tile_height=self._thumbnail_size + num_embeddings_per_tile += num_image_token_per_tile( + tile_dims=Dims(width=self._thumbnail_size, height=self._thumbnail_size), + patch_size=self.patch_size, + downsample_ratio=self.downsample_ratio ) token_count += ( self._thumbnail_size @@ -947,6 +954,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): params_per_image, feature_sizes = self.compute_params( images, num_tokens_available ) + print(f"{feature_sizes=}") + print(f"{params_per_image=}") images = [] for param in params_per_image: t = self.apply_params(param) @@ -1332,8 +1341,10 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = ( seq_len - max_image_tokens - ) // processor.num_image_token_per_tile( - tile_width=256, tile_height=256 + ) // num_image_token_per_tile( + tile_dims=Dims(width=256, height=256), + patch_size=processor.patch_size, + downsample_ratio=processor.downsample_ratio ) # TODO(nhaber): get 256 dynamically max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) @@ -1471,8 +1482,10 @@ class NanoNemotronVLMultiModalProcessor( video_num_patches = [] def get_video_replacement_internvl(item_idx: int): - feature_size = hf_processor.num_image_token_per_tile( - tile_width=256, tile_height=256 + feature_size = num_image_token_per_tile( + tile_dims=Dims(width=256, height=256), + patch_size=hf_processor.patch_size, + downsample_ratio=hf_processor.downsample_ratio ) # TODO(nhaber): get 256 dynamically video, metadata = mm_items["video"][item_idx] num_patches = video_num_patches[item_idx] From eac0271b0fe3774281bdeef27c057fde6c263bc4 Mon Sep 17 00:00:00 2001 From: Netanel Haber Date: Tue, 23 Dec 2025 04:14:41 -0800 Subject: [PATCH 09/10] get num_embeddings from params + cleanup + minimize diff + ruff --- .../model_executor/models/nano_nemotron_vl.py | 298 +++++++++--------- 1 file changed, 150 insertions(+), 148 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index cc9a839b87ef5..2a41d43ab9660 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -88,7 +88,6 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely # Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels -# TODO(nhaber): get 2048 from config # TODO(nhaber): does use_thumbnail=True work? # TODO(nhaber): mixing images and videos will mess up the "text_prompt_length" calculation. @@ -102,28 +101,20 @@ IMG_CONTEXT = "" DEFAULT_NUM_TILES = 12 -@dataclass(kw_only=True, frozen=True) -class Dims: - height: int - width: int - - -CONV_MERGING = False # This is assumed to be False for now -PIXEL_SHUFFLE = True # This is assumed to be True for now -REDUCTION_FACTOR = 2 ** (PIXEL_SHUFFLE + CONV_MERGING) - -def num_image_token_per_tile(*, tile_dims: Dims, patch_size: int, downsample_ratio: int) -> int: - tile_size = math.sqrt(tile_dims.width * tile_dims.height) - num_tokens = int( - (tile_size // patch_size) ** 2 * (downsample_ratio**2) - ) +def num_image_token_per_tile( + *, width: int, height: int, patch_size: int, downsample_ratio: int +) -> int: + tile_size = math.sqrt((width // patch_size) * (height // patch_size)) + num_tokens = int(tile_size**2 // (downsample_ratio**2)) return num_tokens + def width_and_height_for_max_num_tokens_available( *, target_num_tokens_post_shuffle: int, patch_size: int, -) -> Dims: + downsample_ratio: int, +) -> tuple[int, int]: """ TODO(nhaber): optimize this so it squeezes closer to target number of tokens. Calculate image dimensions that produce approximately `target` tokens after @@ -133,14 +124,26 @@ def width_and_height_for_max_num_tokens_available( need 4*B patches to get B tokens. Examples: - >>> dims = width_and_height_for_max_num_tokens_available(B=8192, patch_size=16) - >>> assert dims.width, dims.height == (2880, 2880) - >>> assert ((dims.width // 16) * (dims.height // 16) // 4) == 8100 # tokens after shuffle - >>> assert num_image_token_per_tile(tile_dims=dims, patch_size=16, downsample_ratio=2) == 8100 + >>> width, height = width_and_height_for_max_num_tokens_available( + ... target_num_tokens_post_shuffle=8192, + ... patch_size=16, + ... downsample_ratio=2, + ... ) + >>> assert width, height == (2880, 2880) + >>> assert (width // 16) * (height // 16) // 2**2 == 8100 # tokens post-shuffle + >>> assert ( + ... num_image_token_per_tile( + ... width=width, height=height, patch_size=16, downsample_ratio=2 + ... ) + ... == 8100 + ... ) """ - side_pixels = math.isqrt(target_num_tokens_post_shuffle) * REDUCTION_FACTOR * patch_size + side_pixels = ( + math.isqrt(target_num_tokens_post_shuffle) * downsample_ratio * patch_size + ) assert isinstance(side_pixels, int) and side_pixels % patch_size == 0 - return Dims(width=side_pixels, height=side_pixels) + return side_pixels, side_pixels + @dataclass class DynamicResolutionParams: @@ -354,7 +357,7 @@ class BaseNanoNemotronVLProcessor(ABC): self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES image_size: int = config.force_image_size self.patch_size: int = getattr(config, "patch_size", 16) - self.downsample_ratio: float = self.config.downsample_ratio + # self.downsample_ratio: float = self.config.downsample_ratio self.image_size = image_size self.use_thumbnail: bool = config.use_thumbnail @@ -392,9 +395,10 @@ class BaseNanoNemotronVLProcessor(ABC): ) return num_tiles * num_image_token_per_tile( - tile_dims=Dims(width=image_width, height=image_height), + width=image_width, + height=image_height, patch_size=self.patch_size, - downsample_ratio=self.downsample_ratio + downsample_ratio=self.downsample_ratio, ) def _images_to_pixel_values_lst( @@ -508,8 +512,9 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): super().__init__( config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs ) - self.max_model_len = max_model_len + self._patch_size: int = getattr(config, "patch_size", 16) + self.max_model_len = max_model_len self._min_num_patches = min_num_patches self._factor_max = factor_max self._pixel_shuffle = pixel_shuffle @@ -518,47 +523,90 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): self._use_thumbnail = use_thumbnail self._thumbnail_size = thumbnail_size self._thumbnail_area_threshold = thumbnail_area_threshold + self.norm_mean = torch.tensor(self.CLIP_PIXEL_MEAN).reshape(1, 3, 1, 1) + self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1) self._transform = T.Compose( [ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)), + # T.Normalize(mean=pixel_mean, std=pixel_std), - This is done down below with input_conditioner ] ) self._apply_data_augment = apply_data_augment + reduction_factor = 1 / self.config.downsample_ratio + assert reduction_factor == 2.0, ( + "I don't understand what's going on if this isn't 4" + ) + self.downsample_ratio = int(reduction_factor) ** (pixel_shuffle + conv_merging) + assert self.downsample_ratio == 2, ( + f"I don't understand what's going on if {self.downsample_ratio=} isn't 2" + ) - self.norm_mean = torch.tensor(self.CLIP_PIXEL_MEAN).reshape(1, 3, 1, 1) - self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1) - self.downsample_ratio = 2 if pixel_shuffle else 1 + def _get_num_embeddings(self, width: int, height: int) -> int: + return num_image_token_per_tile( + width=width, + height=height, + patch_size=self._patch_size, + downsample_ratio=self.downsample_ratio, + ) + + def max_num_tokens_available(self, text_prompt_length: int) -> int: + return self.max_model_len - text_prompt_length - 4 + + def _images_to_pixel_values_lst( + self, + text_prompt_length: int, + images: list[Image.Image], + max_num_tiles: int, + ) -> tuple[list[torch.Tensor], list[int]]: + num_tokens_available = self.max_num_tokens_available(text_prompt_length) + params_per_image = self.compute_params(images, num_tokens_available) + + feature_sizes = [] + images = [] + for param in params_per_image: + for t in self.apply_params(param): + if t.ndim == 3: + t = t.unsqueeze(0) + images.append(t) + feature_sizes.append(param.num_embeddings) + print(f"{feature_sizes=}") + print(f"{params_per_image=}") + return images, feature_sizes feature_size_cache: dict[ Image.Image, int - ] = {} # TODO(nhaber): Find a less silly way of doing this... Why can't this be a class variable? + ] = {} # TODO(nhaber): Find a less silly way of doing this... Why can't this be an instance variable? - def apply_params(self, params: DynamicResolutionParams) -> torch.Tensor: + def get_cached_feature_size(self, image: Image.Image) -> int: + feature_size = self.feature_size_cache[id(image)] + del self.feature_size_cache[id(image)] + return feature_size + + def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]: resized_img = params.media.resize( ( - params.patch_size[0] * self.patch_size, - params.patch_size[1] * self.patch_size, + params.patch_size[0] * self._patch_size, + params.patch_size[1] * self._patch_size, ) ) - # processed_images = [resized_img] + processed_images = [resized_img] - # # Add thumbnail if enabled and image area is below threshold - # if self._use_thumbnail: - # # Calculate areas - # resized_area = resized_img.size[0] * resized_img.size[1] - # thumbnail_area = self._thumbnail_size * self._thumbnail_size - # area_ratio = resized_area / thumbnail_area + # Add thumbnail if enabled and image area is below threshold + if self._use_thumbnail: + # Calculate areas + resized_area = resized_img.size[0] * resized_img.size[1] + thumbnail_area = self._thumbnail_size * self._thumbnail_size + area_ratio = resized_area / thumbnail_area - # # Only add thumbnail if resized image area is less than threshold % of - # # thumbnail area - # if area_ratio < self._thumbnail_area_threshold: - # thumbnail_img = params.media.resize( - # (self._thumbnail_size, self._thumbnail_size) - # ) - # processed_images.append(thumbnail_img) + # Only add thumbnail if resized image area is less than threshold % of thumbnail area + if area_ratio < self._thumbnail_area_threshold: + thumbnail_img = params.media.resize( + (self._thumbnail_size, self._thumbnail_size) + ) + processed_images.append(thumbnail_img) - return self._transform(resized_img) + return [self._transform(img) for img in processed_images] def process_media( self, @@ -568,11 +616,11 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): tiling_augment_prob: float = 0.4, ) -> tuple[DynamicResolutionParams, int]: """Process a single media item and return its parameters. + Args: media: The media item to process num_tokens_available: Number of tokens available for this media - data_augment: Whether to apply data augmentation to the image. Defaults to - False. + data_augment: Whether to apply data augmentation to the image. Defaults to False. Returns: DynamicResolutionParams for the media """ @@ -581,11 +629,9 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): "Dynamic resolution is only supported for image media" ) orig_width, orig_height = media.width, media.height - - closest_patch_height = math.ceil( - orig_height / self.patch_size - ) # TODO(nhaber): Ask Tyler - the previous round + 0.5 code is dangerous [banker's rounding], no? If we flip this back to the round, the max_wh_fill_budget needs to do -1 for each of w;h to be safe - closest_patch_width = math.ceil(orig_width / self.patch_size) + # TODO(nhaber): Ask Tyler - the round + 0.5 code is dangerous [banker's rounding], no? + closest_patch_height = round(orig_height / self._patch_size + 0.5) + closest_patch_width = round(orig_width / self._patch_size + 0.5) patches = closest_patch_height * closest_patch_width factor = min( @@ -594,8 +640,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): target_patch_height = math.floor(factor * closest_patch_height) target_patch_width = math.floor(factor * closest_patch_width) - # We only consider self._min_num_patches if it is greater than - # current_num_tokens_available. + # We only consider self._min_num_patches if it is greater than current_num_tokens_available. if ( current_num_tokens_available > self._min_num_patches and target_patch_height * target_patch_width < self._min_num_patches @@ -608,20 +653,19 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): if ( self._min_side is not None - and min(target_patch_width, target_patch_height) * self.patch_size + and min(target_patch_width, target_patch_height) * self._patch_size < self._min_side ): if target_patch_width <= target_patch_height: - up_factor = self._min_side / (target_patch_width * self.patch_size) + up_factor = self._min_side / (target_patch_width * self._patch_size) new_patch_height = math.ceil(up_factor * target_patch_height) new_patch_width = math.ceil(up_factor * target_patch_width) if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at - # native aspect ratio while staying below max_patches + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches if ( max(current_num_tokens_available // new_patch_width, 1) - * self.patch_size + * self._patch_size < self._min_side ): up_factor = math.sqrt( @@ -640,16 +684,15 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): target_patch_height = new_patch_height target_patch_width = new_patch_width else: - up_factor = self._min_side / (target_patch_height * self.patch_size) + up_factor = self._min_side / (target_patch_height * self._patch_size) new_patch_height = math.ceil(up_factor * target_patch_height) new_patch_width = math.ceil(up_factor * target_patch_width) if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at - # native aspect ratio while staying below max_patches + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches if ( max(current_num_tokens_available // new_patch_height, 1) - * self.patch_size + * self._patch_size < self._min_side ): up_factor = math.sqrt( @@ -708,15 +751,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): target_patch_width, target_patch_height, current_num_tokens_available ) - assert isinstance(media, Image.Image), ( - "Dynamic resolution is only supported for image media" - ) - # Calculate embeddings for the main dynamic resolution image - num_embeddings_per_tile = num_image_token_per_tile( - tile_dims=Dims(width=target_patch_width, height=target_patch_height), - patch_size=self.patch_size, - downsample_ratio=self.downsample_ratio + num_embeddings = self._get_num_embeddings( + target_patch_width * self._patch_size, + target_patch_height * self._patch_size, ) token_count = target_patch_width * target_patch_height @@ -725,33 +763,30 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): num_tiles = 1 # Base dynamic resolution image if self._use_thumbnail: # Calculate areas - resized_area = (target_patch_width * self.patch_size) * ( - target_patch_height * self.patch_size + resized_area = (target_patch_width * self._patch_size) * ( + target_patch_height * self._patch_size ) thumbnail_area = self._thumbnail_size * self._thumbnail_size area_ratio = resized_area / thumbnail_area - # Only add thumbnail if resized image area is less than threshold % of - # thumbnail area + # Only add thumbnail if resized image area is less than threshold % of thumbnail area if area_ratio < self._thumbnail_area_threshold: num_tiles += 1 # Add 1 for thumbnail # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) - num_embeddings_per_tile += num_image_token_per_tile( - tile_dims=Dims(width=self._thumbnail_size, height=self._thumbnail_size), - patch_size=self.patch_size, - downsample_ratio=self.downsample_ratio + num_embeddings += self._get_num_embeddings( + self._thumbnail_size, self._thumbnail_size ) token_count += ( self._thumbnail_size - // self.patch_size + // self._patch_size * self._thumbnail_size - // self.patch_size + // self._patch_size ) return DynamicResolutionParams( media=media, num_tiles=num_tiles, - num_embeddings=num_embeddings_per_tile, + num_embeddings=num_embeddings, patch_size=(target_patch_width, target_patch_height), ), token_count @@ -805,7 +840,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): media_list: list[Image.Image], num_tokens_available: int | None = None, data_augment: bool = False, - ) -> tuple[list[DynamicResolutionParams], list[int]]: + ) -> list[DynamicResolutionParams]: """Compute parameters for all media with iterative token budgeting. Args: @@ -821,26 +856,24 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): * (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1) ) - # When the number of available token is too small, allow self._min_num_patches - # per media and let the sample be truncated. + # When the number of available token is too small, allow self._min_num_patches per media and + # let the sample be truncated. num_tokens_available = max( num_tokens_available, self._min_num_patches * len(media_list) ) - # Clip the number of tokens available per media to be between min and max - # patches. + # Clip the number of tokens available per media to be between min and max patches. num_tokens_available_per_media = [ max(num_tokens_available, self._min_num_patches) for _ in range(len(media_list)) ] - # In theory this could be a while True loop, but in case the process_media - # method slightly + # In theory this could be a while True loop, but in case the process_media method slightly # changes, I want to make sure we don't get stuck in an infinite loop. for _ in range(10): # Step 1: Process each media with current token budget - params: list[DynamicResolutionParams] = [] - token_counts: list[int] = [] + params = [] + token_counts = [] for media, tokens_for_media in zip( media_list, num_tokens_available_per_media @@ -850,18 +883,14 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): ) params.append(param) token_counts.append(token_count) + self.feature_size_cache[id(param.media)] = param.num_embeddings # Step 2: Check if total tokens is within budget total_tokens = sum(token_counts) if total_tokens <= num_tokens_available: # We're within budget, return the params - # Convert from patch count to actual token count after downsampling - divisor = (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1) - adjusted_token_counts = [tc // divisor for tc in token_counts] - for param, feature_size in zip(params, adjusted_token_counts, strict=True): - self.feature_size_cache[id(param.media)] = feature_size - return params, adjusted_token_counts + return params # Step 3: We're over budget, need to scale down # Calculate scaling factor to get under budget @@ -880,8 +909,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): for i in range(len(num_tokens_available_per_media)) ] ) - # If there was not scaling down, we're stuck just use min_num_patches per - # media, else try with the scaled down num_tokens_available_per_media. + # If there was not scaling down, we're stuck just use min_num_patches per media, else + # try with the scaled down num_tokens_available_per_media. if not scaled_down: num_tokens_available_per_media = [self._min_num_patches] * len( media_list @@ -900,15 +929,15 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): ) def rearrange_img(x): - py = x.shape[-2] // self.patch_size - px = x.shape[-1] // self.patch_size + py = x.shape[-2] // self._patch_size + px = x.shape[-1] // self._patch_size x = einops.rearrange( x, "c (py yy) (px xx) -> (py px) (c yy xx)", py=py, - yy=self.patch_size, + yy=self._patch_size, px=px, - xx=self.patch_size, + xx=self._patch_size, ) return x @@ -941,34 +970,6 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor): None, ) - def max_num_tokens_available(self, text_prompt_length: int) -> int: - return self.max_model_len - text_prompt_length - 4 - - def _images_to_pixel_values_lst( - self, - text_prompt_length: int, - images: list[Image.Image], - max_num_tiles: int, - ) -> tuple[list[torch.Tensor], list[int]]: - num_tokens_available = self.max_num_tokens_available(text_prompt_length) - params_per_image, feature_sizes = self.compute_params( - images, num_tokens_available - ) - print(f"{feature_sizes=}") - print(f"{params_per_image=}") - images = [] - for param in params_per_image: - t = self.apply_params(param) - if t.ndim == 3: - t = t.unsqueeze(0) - images.append(t) - return images, feature_sizes - - def get_cached_feature_size(self, image: Image.Image) -> int: - feature_size = self.feature_size_cache[id(image)] - del self.feature_size_cache[id(image)] - return feature_size - class NanoNemotronVLProcessor(DynamicResolutionImageTiler): """ @@ -1339,12 +1340,11 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): processor = self.get_hf_processor() # we get the CustomProcessor here max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = ( - seq_len - max_image_tokens - ) // num_image_token_per_tile( - tile_dims=Dims(width=256, height=256), - patch_size=processor.patch_size, - downsample_ratio=processor.downsample_ratio + max_total_frames = (seq_len - max_image_tokens) // num_image_token_per_tile( + width=256, + height=256, + patch_size=processor._patch_size, + downsample_ratio=processor.downsample_ratio, ) # TODO(nhaber): get 256 dynamically max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) @@ -1483,9 +1483,10 @@ class NanoNemotronVLMultiModalProcessor( def get_video_replacement_internvl(item_idx: int): feature_size = num_image_token_per_tile( - tile_dims=Dims(width=256, height=256), - patch_size=hf_processor.patch_size, - downsample_ratio=hf_processor.downsample_ratio + width=256, + height=256, + patch_size=hf_processor._patch_size, + downsample_ratio=hf_processor.downsample_ratio, ) # TODO(nhaber): get 256 dynamically video, metadata = mm_items["video"][item_idx] num_patches = video_num_patches[item_idx] @@ -1550,17 +1551,18 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): num_images = mm_counts.get("image", 0) processor = self.info.get_hf_processor() B = processor.max_num_tokens_available(text_prompt_length=num_images) - target_dims = width_and_height_for_max_num_tokens_available( + target_width, target_height = width_and_height_for_max_num_tokens_available( target_num_tokens_post_shuffle=B, - patch_size=processor.patch_size, + patch_size=processor._patch_size, + downsample_ratio=processor.downsample_ratio, ) image_overrides = mm_options.get("image") if mm_options else None return { "image": self._get_dummy_images( - width=target_dims.width, - height=target_dims.height, + width=target_width, + height=target_height, num_images=num_images, overrides=image_overrides, ) From 6f1249fdb28615972f3b89234d11f3b19a0bd72a Mon Sep 17 00:00:00 2001 From: Netanel Haber Date: Tue, 23 Dec 2025 06:27:03 -0800 Subject: [PATCH 10/10] minimize diff --- vllm/model_executor/models/nano_nemotron_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 2a41d43ab9660..23d9496ea37c0 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -386,7 +386,7 @@ class BaseNanoNemotronVLProcessor(ABC): ) -> int: target_ratios = get_internvl_target_ratios(1, max_num_tiles) - num_tiles, _, _ = calculate_internvl_targets( + num_patches, _, _ = calculate_internvl_targets( orig_width=image_width, orig_height=image_height, target_ratios=target_ratios, @@ -394,7 +394,7 @@ class BaseNanoNemotronVLProcessor(ABC): use_thumbnail=self.use_thumbnail, ) - return num_tiles * num_image_token_per_tile( + return num_patches * num_image_token_per_tile( width=image_width, height=image_height, patch_size=self.patch_size,