From 6979edb5755c320eac3e1432036c1c8242085012 Mon Sep 17 00:00:00 2001
From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date: Thu, 11 Dec 2025 13:49:27 +0200
Subject: [PATCH 01/10] import
---
image_processing.py | 508 ++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 508 insertions(+)
create mode 100644 image_processing.py
diff --git a/image_processing.py b/image_processing.py
new file mode 100644
index 0000000000000..979ff681edc87
--- /dev/null
+++ b/image_processing.py
@@ -0,0 +1,508 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
+import math
+
+import torch
+from einops import rearrange
+from torchvision import transforms as T
+from torchvision.transforms import Compose
+from torchvision.transforms.functional import InterpolationMode
+
+
+IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
+SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
+SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
+CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
+CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
+RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
+RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]
+
+
+pixel_statistics = {
+ "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+ "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
+ "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
+ "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+ "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
+ "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+ "internvit300M": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
+ "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
+}
+
+
+# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
+# Copyright (c) 2023 OpenGVLab.
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float('inf')
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+
+
+def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+ """
+ Find the best number of tiles based on the aspect ratio and the area covered by the tiles.
+ """
+ best_factor = float('-inf')
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ factor_based_on_area_n_ratio = (
+ min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6) *
+ min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio))
+ if factor_based_on_area_n_ratio > best_factor:
+ best_factor = factor_based_on_area_n_ratio
+ best_ratio = ratio
+ return best_ratio
+
+
+def process_images(sample_imgs, patch_dim, dynamic_resolution, batch_mode=False):
+ """Process a batch of images for multimodal training or evaluation.
+
+ This function handles image preprocessing with support for both static and dynamic
+ resolution processing. For dynamic resolution, it rearranges images into patches
+ and computes cumulative sequence lengths for efficient batching.
+
+ Args:
+ sample_imgs (List[torch.Tensor]): List of image tensors with shape (C, H, W).
+ patch_dim (int): Dimension of each patch (e.g., 14 for 14x14 patches).
+ dynamic_resolution (bool): Whether to use dynamic resolution processing.
+ If True, images are rearranged into patches with variable sequence lengths.
+ If False, images are simply stacked into a batch tensor.
+ batch_mode (bool, optional): Whether this is being called from training batch processing.
+ If True, wraps tensors in additional list dimension for consistency with batch format.
+ If False, returns tensors directly as used in evaluation. Defaults to False.
+
+ Returns:
+ tuple: A 4-tuple containing:
+ - images (torch.Tensor): Processed image tensor.
+ For dynamic resolution: shape (1, total_patches, patch_features) if batch_mode=False,
+ or shape (1, total_patches, patch_features) if batch_mode=True
+ For static resolution: shape (batch_size, C, H, W)
+ - imgs_sizes (torch.Tensor): Image sizes tensor with shape (N, 2)
+ containing [width, height] for each image, or [[0,0]] if no images.
+ - vision_cu_lengths (torch.Tensor or None): Cumulative sequence lengths
+ for dynamic resolution. Shape (batch_size + 1,) for evaluation mode,
+ or shape (1, batch_size + 1) for batch mode. None for static resolution.
+ - vision_max_lengths (torch.Tensor or None): Maximum sequence length
+ among all images for dynamic resolution. Scalar tensor for evaluation mode,
+ or shape (1,) for batch mode. None for static resolution.
+
+ Note:
+ This function is designed for processing one microbatch at a time for dynamic resolution.
+ """
+ vision_cu_lengths = None
+ vision_max_lengths = None
+
+ if len(sample_imgs) > 0:
+ imgs_sizes = torch.tensor([[img.shape[1], img.shape[2]] for img in sample_imgs], dtype=torch.int32)
+ if dynamic_resolution:
+ def rearrange_img(x):
+ py = x.shape[-2] // patch_dim
+ px = x.shape[-1] // patch_dim
+ x = rearrange(x, 'c (py yy) (px xx) -> (py px) (c yy xx)',
+ py=py, yy=patch_dim,
+ px=px, xx=patch_dim,
+ )
+ return x
+ imgs = [rearrange_img(img) for img in sample_imgs]
+
+ current_length = 0
+ max_length = 0
+ vision_cu_lengths = [0]
+ for img in imgs:
+ if max_length < img.shape[0]:
+ max_length = img.shape[0]
+ current_length += img.shape[0]
+ vision_cu_lengths.append(current_length)
+
+ vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
+ vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
+
+ # For batch mode, wrap in additional dimension for consistency
+ if batch_mode:
+ vision_cu_lengths = vision_cu_lengths.unsqueeze(0) # Shape: (1, batch_size + 1)
+ vision_max_lengths = vision_max_lengths.unsqueeze(0) # Shape: (1,)
+
+ imgs = torch.cat(imgs, dim=0)
+ images = imgs.unsqueeze(0)
+ else:
+ images = torch.stack(sample_imgs)
+ else:
+ imgs_sizes = torch.tensor([[0,0]], dtype=torch.int32)
+ if len(sample_imgs) == 0 and batch_mode:
+ # For batch mode when no images, use appropriate dummy tensor
+ images = torch.tensor([[0]], dtype=torch.float32)
+ else:
+ images = torch.stack(sample_imgs)
+
+ return images, imgs_sizes, vision_cu_lengths, vision_max_lengths
+
+
+class ImageTransform:
+ """Image transformation."""
+
+ def __init__(self, input_size, vision_model_type, *, dynamic_resolution=False, res_step=16, min_num_patches=1, max_num_patches=128, pixel_shuffle=False, min_side=None, conv_merging=False, match_tiling_dynamic_resolution=False, masked_tiling_dynamic_resolution=False, thumbnail_area_threshold=0.8):
+ self._transform = _build_transform(input_size, vision_model_type)
+ self._vision_model_type = vision_model_type
+ self._dynamic_resolution = dynamic_resolution
+ self._res_step = res_step
+ self._min_num_patches = min_num_patches
+ self._max_num_patches = max_num_patches
+ self._pixel_shuffle = pixel_shuffle
+ self._min_side = min_side
+ self._conv_merging = conv_merging
+ self._match_tiling_dynamic_resolution = match_tiling_dynamic_resolution
+ self._masked_tiling_dynamic_resolution = masked_tiling_dynamic_resolution
+ self._thumbnail_area_threshold = thumbnail_area_threshold
+
+ def __call__(self, img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, find_closest_aspect_ratio_fn=find_closest_aspect_ratio, is_video=False):
+ assert not augment, "Image augmentation not implemented."
+ if use_tiling:
+ assert img_h == img_w, "dynamic tiling expects equal tile height and width"
+ imgs = dynamic_preprocess(
+ img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail,
+ find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn)
+ imgs = [self._transform(img) for img in imgs]
+ elif self._masked_tiling_dynamic_resolution:
+ assert img_h == img_w, "masked tiling dynamic resolution expects equal tile height and width"
+ assert "radio" in self._vision_model_type, "Masked tiling dynamic resolution is only supported for radio models"
+
+ # Use tiling logic to determine tile grid (nx, ny)
+ orig_width, orig_height = img.size
+ aspect_ratio = orig_width / orig_height
+
+ target_ratios = set(
+ (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1)
+ if i * j <= max_num_tiles and i * j >= 1
+ )
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ tiling = find_closest_aspect_ratio_fn(
+ aspect_ratio, target_ratios, orig_width, orig_height, img_h
+ )
+
+ # Resize and split into tiles of size (img_h x img_h)
+ target_width = img_h * tiling[0]
+ target_height = img_w * tiling[1]
+ blocks = tiling[0] * tiling[1]
+
+ resized_img = img.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // img_h)) * img_h,
+ (i // (target_width // img_h)) * img_h,
+ ((i % (target_width // img_h)) + 1) * img_h,
+ ((i // (target_width // img_h)) + 1) * img_h,
+ )
+ tile_img = resized_img.crop(box)
+ processed_images.append(tile_img)
+ assert len(processed_images) == blocks
+
+ # Optional thumbnail
+ if use_thumbnail and blocks != 1:
+ thumbnail_img = img.resize((img_h, img_h))
+ processed_images.append(thumbnail_img)
+
+ pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.ToTensor(),
+ T.Normalize(mean=pixel_mean, std=pixel_std),
+ ])
+ imgs = [transform(im) for im in processed_images]
+ elif self._match_tiling_dynamic_resolution:
+ assert img_h == img_w, "match tiling dynamic resolution expects equal tile height and width"
+ assert "radio" in self._vision_model_type, "Match tiling dynamic resolution is only supported for radio models"
+
+ # Use tiling logic to determine optimal dimensions
+ orig_width, orig_height = img.size
+ aspect_ratio = orig_width / orig_height
+
+ # Calculate target ratios (same logic as tiling)
+ target_ratios = set(
+ (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+ i * j <= max_num_tiles and i * j >= 1)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # Find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio_fn(
+ aspect_ratio, target_ratios, orig_width, orig_height, img_h)
+
+ # Calculate the target width and height using tiling logic
+ target_width = img_h * target_aspect_ratio[0]
+ target_height = img_w * target_aspect_ratio[1]
+
+ # Resize image to target dimensions (same as tiling, but don't split)
+ resized_img = img.resize((target_width, target_height))
+
+ # Process as single dynamic resolution image
+ pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.ToTensor(),
+ T.Normalize(mean=pixel_mean, std=pixel_std),
+ ])
+ processed_images = [resized_img]
+
+ # Add thumbnail if use_thumbnail=True and there's more than 1 tile
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+ if use_thumbnail and blocks != 1:
+ thumbnail_img = img.resize((img_h, img_h))
+ processed_images.append(thumbnail_img)
+
+ imgs = [transform(img) for img in processed_images]
+ elif self._dynamic_resolution:
+ pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.ToTensor(),
+ T.Normalize(mean=pixel_mean, std=pixel_std),
+ ])
+ processed_img = dynamic_res_preprocess(img, min_patches=self._min_num_patches, max_patches=self._max_num_patches, res_step=self._res_step, pixel_shuffle=self._pixel_shuffle, min_side=self._min_side, conv_merging=self._conv_merging, is_video=is_video)
+ processed_images = [processed_img]
+
+ # Add thumbnail if enabled and image area is below threshold
+ if use_thumbnail:
+ # Calculate areas
+ processed_width, processed_height = processed_img.size
+ resized_area = processed_width * processed_height
+ thumbnail_area = img_h * img_h # img_h should be square thumbnail size
+ area_ratio = resized_area / thumbnail_area
+
+ # Only add thumbnail if resized image area is less than threshold % of thumbnail area
+ if area_ratio < self._thumbnail_area_threshold:
+ thumbnail_img = img.resize((img_h, img_h)) # Use square thumbnail with img_h size
+ processed_images.append(thumbnail_img)
+
+ imgs = [transform(img) for img in processed_images]
+ else:
+ imgs = [self._transform(img)]
+
+ return imgs
+
+
+# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702
+# Copyright (c) 2023 OpenGVLab.
+def dynamic_preprocess(
+ image, min_num=1, max_num=6, image_size=448, use_thumbnail=False,
+ find_closest_aspect_ratio_fn=find_closest_aspect_ratio):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+ i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio_fn(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ return processed_images
+
+
+def dynamic_res_preprocess(image, min_patches=1, max_patches=128, res_step=16, factor_max=1., pixel_shuffle=False, min_side=None, conv_merging=False, is_video=False):
+ """Preprocess an image with dynamic resolution for vision transformers.
+
+ This function resizes an image to optimize the number of patches while respecting
+ constraints on minimum/maximum patches, minimum side length, and compatibility
+ with pixel shuffle or convolution merging operations.
+
+ The algorithm works by:
+ 1. Computing the initial patch grid size based on the image dimensions and res_step
+ 2. Scaling the patch grid to fit within the max_patches constraint
+ 3. Ensuring the result has at least min_patches
+ 4. Optionally enforcing a minimum side length constraint
+ 5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility
+ 6. Resizing the image to the computed target dimensions
+
+ Args:
+ image (PIL.Image): Input image to preprocess.
+ min_patches (int, optional): Minimum number of patches required. Defaults to 1.
+ max_patches (int, optional): Maximum number of patches allowed. Defaults to 128.
+ res_step (int, optional): Resolution step size (patch dimension). Defaults to 16.
+ factor_max (float, optional): Maximum scaling factor to apply. Defaults to 1.0.
+ pixel_shuffle (bool, optional): Whether to ensure compatibility with pixel shuffle
+ operations by rounding to even patch dimensions. Defaults to False.
+ min_side (int, optional): Minimum side length in pixels. If specified, ensures
+ at least one side meets this constraint. Defaults to None.
+ conv_merging (bool, optional): Whether to ensure compatibility with convolution
+ merging by rounding to even patch dimensions. Defaults to False.
+
+ Returns:
+ PIL.Image: Resized image with dimensions optimized for patch-based processing.
+ The output dimensions will be (target_patch_width * res_step, target_patch_height * res_step).
+
+ Note:
+ The function preserves aspect ratio as much as possible while satisfying all constraints.
+ When constraints conflict (e.g., min_side vs max_patches), the function prioritizes
+ staying within max_patches while maximizing the image size.
+
+ Example:
+ >>> from PIL import Image
+ >>> img = Image.open("example.jpg") # 800x600 image
+ >>> resized_img = dynamic_res_preprocess(img, min_patches=4, max_patches=64, res_step=14)
+ >>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14
+ """
+ orig_width, orig_height = image.size
+
+ closest_patch_height = round(orig_height / res_step + 0.5)
+ closest_patch_width = round(orig_width / res_step + 0.5)
+ patches = closest_patch_height * closest_patch_width
+
+ factor = min(math.sqrt(max_patches / patches), factor_max)
+ target_patch_height = math.floor(factor * closest_patch_height)
+ target_patch_width = math.floor(factor * closest_patch_width)
+
+ if target_patch_height * target_patch_width < min_patches:
+ up_factor = math.sqrt(min_patches / (target_patch_height * target_patch_width))
+ target_patch_height = math.ceil(up_factor * target_patch_height)
+ target_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if min_side is not None and min(target_patch_width, target_patch_height) * res_step < min_side:
+ if target_patch_width <= target_patch_height:
+ up_factor = min_side / (target_patch_width * res_step)
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if new_patch_height * new_patch_width > max_patches:
+ # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
+ if max(max_patches // new_patch_width, 1) * res_step < min_side:
+ up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
+ target_patch_height = math.floor(up_factor * target_patch_height)
+ target_patch_width = math.floor(up_factor * target_patch_width)
+ target_patch_width = new_patch_width
+ target_patch_height = max(max_patches // new_patch_width, 1)
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = new_patch_width
+ else:
+ up_factor = min_side / (target_patch_height * res_step)
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if new_patch_height * new_patch_width > max_patches:
+ # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
+ if max(max_patches // new_patch_height, 1) * res_step < min_side:
+ up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
+ target_patch_height = math.floor(up_factor * target_patch_height)
+ target_patch_width = math.floor(up_factor * target_patch_width)
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = max(max_patches // new_patch_height, 1)
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = new_patch_width
+
+ # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
+ # or by 4 when BOTH are enabled (two successive 2x reductions)
+ if pixel_shuffle or conv_merging:
+ required_divisor = 4 if (pixel_shuffle and conv_merging) else 2
+
+ rem_h = target_patch_height % required_divisor
+ if rem_h != 0:
+ inc_h = required_divisor - rem_h
+ if (target_patch_height + inc_h) * target_patch_width <= max_patches:
+ target_patch_height += inc_h
+ else:
+ target_patch_height = max(1, target_patch_height - rem_h)
+
+ rem_w = target_patch_width % required_divisor
+ if rem_w != 0:
+ inc_w = required_divisor - rem_w
+ if target_patch_height * (target_patch_width + inc_w) <= max_patches:
+ target_patch_width += inc_w
+ else:
+ target_patch_width = max(1, target_patch_width - rem_w)
+ assert target_patch_height * target_patch_width <= max_patches
+
+ #TEMP: hacky way to process video same as in training
+ if is_video:
+ # max_patches = 1024
+ # min_patches = 512
+ target_patch_width = 32
+ target_patch_height = 32
+
+
+ # resize the image
+ resized_img = image.resize((target_patch_width * res_step, target_patch_height * res_step))
+
+ return resized_img
+
+
+
+# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
+# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
+def _build_transform(input_size, vision_model_type):
+ if vision_model_type in ("siglip", "internvit", "internvit300M", "radio", "radio-g", "cradio-g"):
+ pixel_mean, pixel_std = pixel_statistics[vision_model_type]
+
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=pixel_mean, std=pixel_std)
+ ])
+ elif vision_model_type == "clip":
+ pixel_mean, pixel_std = pixel_statistics[vision_model_type]
+
+ transform = Compose([
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.ToTensor(),
+ T.Normalize(mean=pixel_mean, std=pixel_std),
+ ])
+ elif vision_model_type.startswith("hf://"):
+ from megatron.core.models.huggingface.module import get_hf_model_type
+
+ model_type = get_hf_model_type(vision_model_type)
+ if "siglip" in model_type:
+ from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
+
+ processor = SiglipImageProcessor(size={"height": input_size, "width": input_size})
+
+ def transform(x):
+ x = x.convert("RGB") if x.mode != "RGB" else x
+ x = processor(x, return_tensors="pt")
+ return x["pixel_values"][0]
+ else:
+ raise NotImplementedError(f"image processing not defined for huggingface model {vision_model_type}")
+ else:
+ raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}")
+
+ return transform
From 50ffea98266c3a7910306ce185e198171bf68bb6 Mon Sep 17 00:00:00 2001
From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date: Thu, 11 Dec 2025 14:12:03 +0200
Subject: [PATCH 02/10] update
---
image_processing.py | 2126 +++++++++++++++++++++++++++++++++++--------
1 file changed, 1723 insertions(+), 403 deletions(-)
diff --git a/image_processing.py b/image_processing.py
index 979ff681edc87..5d43e764acadd 100644
--- a/image_processing.py
+++ b/image_processing.py
@@ -1,12 +1,23 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
import math
+from typing import Callable, Optional
+import numpy as np
+import random
+from PIL import Image
+import albumentations as A
+import einops
import torch
-from einops import rearrange
from torchvision import transforms as T
from torchvision.transforms import Compose
from torchvision.transforms.functional import InterpolationMode
+from data_loading.conversation_sample import (
+ ImageMedia,
+ VideoFrameMedia,
+)
IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
@@ -24,16 +35,23 @@ pixel_statistics = {
"internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
"radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
- "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "internvit300M": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
"huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
+ "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+ "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+ "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
}
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
# Copyright (c) 2023 OpenGVLab.
-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
- best_ratio_diff = float('inf')
+def find_closest_aspect_ratio(
+ aspect_ratio: float,
+ target_ratios: list[tuple[int, int]],
+ width: int,
+ height: int,
+ image_size: int,
+) -> tuple[int, int]:
+ best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
@@ -48,301 +66,548 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_
return best_ratio
-def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+def find_closest_area_weighted_aspect_ratio(
+ aspect_ratio: float,
+ target_ratios: list[tuple[int, int]],
+ width: int,
+ height: int,
+ image_size: int,
+):
"""
Find the best number of tiles based on the aspect ratio and the area covered by the tiles.
"""
- best_factor = float('-inf')
+ best_factor = float("-inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
- factor_based_on_area_n_ratio = (
- min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6) *
- min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio))
+ factor_based_on_area_n_ratio = min(
+ (ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
+ ) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
if factor_based_on_area_n_ratio > best_factor:
best_factor = factor_based_on_area_n_ratio
best_ratio = ratio
return best_ratio
-def process_images(sample_imgs, patch_dim, dynamic_resolution, batch_mode=False):
- """Process a batch of images for multimodal training or evaluation.
-
- This function handles image preprocessing with support for both static and dynamic
- resolution processing. For dynamic resolution, it rearranges images into patches
- and computes cumulative sequence lengths for efficient batching.
-
- Args:
- sample_imgs (List[torch.Tensor]): List of image tensors with shape (C, H, W).
- patch_dim (int): Dimension of each patch (e.g., 14 for 14x14 patches).
- dynamic_resolution (bool): Whether to use dynamic resolution processing.
- If True, images are rearranged into patches with variable sequence lengths.
- If False, images are simply stacked into a batch tensor.
- batch_mode (bool, optional): Whether this is being called from training batch processing.
- If True, wraps tensors in additional list dimension for consistency with batch format.
- If False, returns tensors directly as used in evaluation. Defaults to False.
-
- Returns:
- tuple: A 4-tuple containing:
- - images (torch.Tensor): Processed image tensor.
- For dynamic resolution: shape (1, total_patches, patch_features) if batch_mode=False,
- or shape (1, total_patches, patch_features) if batch_mode=True
- For static resolution: shape (batch_size, C, H, W)
- - imgs_sizes (torch.Tensor): Image sizes tensor with shape (N, 2)
- containing [width, height] for each image, or [[0,0]] if no images.
- - vision_cu_lengths (torch.Tensor or None): Cumulative sequence lengths
- for dynamic resolution. Shape (batch_size + 1,) for evaluation mode,
- or shape (1, batch_size + 1) for batch mode. None for static resolution.
- - vision_max_lengths (torch.Tensor or None): Maximum sequence length
- among all images for dynamic resolution. Scalar tensor for evaluation mode,
- or shape (1,) for batch mode. None for static resolution.
-
- Note:
- This function is designed for processing one microbatch at a time for dynamic resolution.
+# Optimized ToTensor replacement: PIL image/ndarray -> contiguous float32 CHW tensor scaled to [0, 1].
+def _fast_to_tensor(pic) -> torch.Tensor:
+ np_img = np.array(pic, copy=False)
+ img = torch.from_numpy(np_img)
+ img = img.permute(2, 0, 1) # HWC to CHW
+ fp_img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+ fp_img.div_(255)
+ return fp_img
+
+
+@dataclass
+class ImageTilingParams:
+ media: ImageMedia | VideoFrameMedia
+ num_tiles: int
+ num_embeddings: int
+
+
+class ImageTilingStrategy(ABC):
"""
- vision_cu_lengths = None
- vision_max_lengths = None
-
- if len(sample_imgs) > 0:
- imgs_sizes = torch.tensor([[img.shape[1], img.shape[2]] for img in sample_imgs], dtype=torch.int32)
- if dynamic_resolution:
- def rearrange_img(x):
- py = x.shape[-2] // patch_dim
- px = x.shape[-1] // patch_dim
- x = rearrange(x, 'c (py yy) (px xx) -> (py px) (c yy xx)',
- py=py, yy=patch_dim,
- px=px, xx=patch_dim,
- )
- return x
- imgs = [rearrange_img(img) for img in sample_imgs]
+ Base class for image tiling strategies.
+    A tiling strategy takes a list of media and computes a list of image tiling parameters.
+ These can then be used to apply the tiling to the media.
- current_length = 0
- max_length = 0
- vision_cu_lengths = [0]
- for img in imgs:
- if max_length < img.shape[0]:
- max_length = img.shape[0]
- current_length += img.shape[0]
- vision_cu_lengths.append(current_length)
-
- vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
- vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
-
- # For batch mode, wrap in additional dimension for consistency
- if batch_mode:
- vision_cu_lengths = vision_cu_lengths.unsqueeze(0) # Shape: (1, batch_size + 1)
- vision_max_lengths = vision_max_lengths.unsqueeze(0) # Shape: (1,)
-
- imgs = torch.cat(imgs, dim=0)
- images = imgs.unsqueeze(0)
- else:
- images = torch.stack(sample_imgs)
- else:
- imgs_sizes = torch.tensor([[0,0]], dtype=torch.int32)
- if len(sample_imgs) == 0 and batch_mode:
- # For batch mode when no images, use appropriate dummy tensor
- images = torch.tensor([[0]], dtype=torch.float32)
- else:
- images = torch.stack(sample_imgs)
+ Subclasses must implement the `compute_params` and `apply_params` methods.
- return images, imgs_sizes, vision_cu_lengths, vision_max_lengths
+ The `transform` method is a convenience method that computes the transformation parameters and applies the transformation to the media.
+
+ """
+
+ def transform(
+ self,
+ media_list: list[ImageMedia | VideoFrameMedia],
+ num_tokens_available: int | None = None,
+ ) -> list[torch.Tensor]:
+ """
+ Transform the media and compute the transformation parameters.
+ """
+ transform_media_list = self.compute_params(media_list, num_tokens_available)
+ return [
+ self.apply_params(transform_media, **kwargs)
+ for transform_media in transform_media_list
+ ]
+
+ @abstractmethod
+ def compute_params(
+ self, media_list: list[ImageMedia | VideoFrameMedia], num_tokens_available: int, max_num_tiles: int | None = None, **kwargs
+ ) -> list[ImageTilingParams]:
+ """
+ Compute the transformation parameters and the number of tokens to use for the media.
+
+ Args:
+ media_list: List of media to transform
+ num_tokens_available: Number of tokens available for all media
+ max_num_tiles: Maximum number of tiles allowed (optional, defaults to instance's max_num_tiles if not provided)
+
+ Returns:
+            A list of `ImageTilingParams`, each bundling a media item with its tiling parameters
+ """
+ ...
+
+ @abstractmethod
+ def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
+ """
+ Apply the transformation parameters to the media.
+
+ Args:
+            transform_media: Tiling parameters (bundled with the media) to apply
+
+ Returns:
+ list of transformed media tensors
+ """
+ ...
+
+ @abstractmethod
+ def stack(
+ self, images: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
+ """
+ Stack the images into a single tensor.
+
+ Args:
+            images: List of images to stack
+
+ Returns:
+ tuple of (stacked media, image sizes, vision cu lengths, vision max lengths)
+ """
+ ...
-class ImageTransform:
- """Image transformation."""
+class _FixedSizeStrategy(ImageTilingStrategy):
+ """
+ Base class for fixed size image tiling strategies.
+ """
- def __init__(self, input_size, vision_model_type, *, dynamic_resolution=False, res_step=16, min_num_patches=1, max_num_patches=128, pixel_shuffle=False, min_side=None, conv_merging=False, match_tiling_dynamic_resolution=False, masked_tiling_dynamic_resolution=False, thumbnail_area_threshold=0.8):
- self._transform = _build_transform(input_size, vision_model_type)
+ def __init__(
+ self,
+ vision_model_type: str,
+ target_width: int,
+ target_height: int,
+ embeddings_per_image: int,
+ ):
self._vision_model_type = vision_model_type
- self._dynamic_resolution = dynamic_resolution
- self._res_step = res_step
- self._min_num_patches = min_num_patches
- self._max_num_patches = max_num_patches
- self._pixel_shuffle = pixel_shuffle
- self._min_side = min_side
- self._conv_merging = conv_merging
- self._match_tiling_dynamic_resolution = match_tiling_dynamic_resolution
- self._masked_tiling_dynamic_resolution = masked_tiling_dynamic_resolution
- self._thumbnail_area_threshold = thumbnail_area_threshold
-
- def __call__(self, img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, find_closest_aspect_ratio_fn=find_closest_aspect_ratio, is_video=False):
- assert not augment, "Image augmentation not implemented."
- if use_tiling:
- assert img_h == img_w, "dynamic tiling expects equal tile height and width"
- imgs = dynamic_preprocess(
- img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail,
- find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn)
- imgs = [self._transform(img) for img in imgs]
- elif self._masked_tiling_dynamic_resolution:
- assert img_h == img_w, "masked tiling dynamic resolution expects equal tile height and width"
- assert "radio" in self._vision_model_type, "Masked tiling dynamic resolution is only supported for radio models"
-
- # Use tiling logic to determine tile grid (nx, ny)
- orig_width, orig_height = img.size
- aspect_ratio = orig_width / orig_height
-
- target_ratios = set(
- (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1)
- if i * j <= max_num_tiles and i * j >= 1
- )
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
- tiling = find_closest_aspect_ratio_fn(
- aspect_ratio, target_ratios, orig_width, orig_height, img_h
- )
-
- # Resize and split into tiles of size (img_h x img_h)
- target_width = img_h * tiling[0]
- target_height = img_w * tiling[1]
- blocks = tiling[0] * tiling[1]
-
- resized_img = img.resize((target_width, target_height))
- processed_images = []
- for i in range(blocks):
- box = (
- (i % (target_width // img_h)) * img_h,
- (i // (target_width // img_h)) * img_h,
- ((i % (target_width // img_h)) + 1) * img_h,
- ((i // (target_width // img_h)) + 1) * img_h,
- )
- tile_img = resized_img.crop(box)
- processed_images.append(tile_img)
- assert len(processed_images) == blocks
-
- # Optional thumbnail
- if use_thumbnail and blocks != 1:
- thumbnail_img = img.resize((img_h, img_h))
- processed_images.append(thumbnail_img)
-
- pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
- transform = T.Compose([
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
- T.ToTensor(),
- T.Normalize(mean=pixel_mean, std=pixel_std),
- ])
- imgs = [transform(im) for im in processed_images]
- elif self._match_tiling_dynamic_resolution:
- assert img_h == img_w, "match tiling dynamic resolution expects equal tile height and width"
- assert "radio" in self._vision_model_type, "Match tiling dynamic resolution is only supported for radio models"
-
- # Use tiling logic to determine optimal dimensions
- orig_width, orig_height = img.size
- aspect_ratio = orig_width / orig_height
-
- # Calculate target ratios (same logic as tiling)
- target_ratios = set(
- (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
- i * j <= max_num_tiles and i * j >= 1)
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
- # Find the closest aspect ratio to the target
- target_aspect_ratio = find_closest_aspect_ratio_fn(
- aspect_ratio, target_ratios, orig_width, orig_height, img_h)
-
- # Calculate the target width and height using tiling logic
- target_width = img_h * target_aspect_ratio[0]
- target_height = img_w * target_aspect_ratio[1]
-
- # Resize image to target dimensions (same as tiling, but don't split)
- resized_img = img.resize((target_width, target_height))
-
- # Process as single dynamic resolution image
- pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
- transform = T.Compose([
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
- T.ToTensor(),
- T.Normalize(mean=pixel_mean, std=pixel_std),
- ])
- processed_images = [resized_img]
-
- # Add thumbnail if use_thumbnail=True and there's more than 1 tile
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
- if use_thumbnail and blocks != 1:
- thumbnail_img = img.resize((img_h, img_h))
- processed_images.append(thumbnail_img)
-
- imgs = [transform(img) for img in processed_images]
- elif self._dynamic_resolution:
- pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
- transform = T.Compose([
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
- T.ToTensor(),
- T.Normalize(mean=pixel_mean, std=pixel_std),
- ])
- processed_img = dynamic_res_preprocess(img, min_patches=self._min_num_patches, max_patches=self._max_num_patches, res_step=self._res_step, pixel_shuffle=self._pixel_shuffle, min_side=self._min_side, conv_merging=self._conv_merging, is_video=is_video)
- processed_images = [processed_img]
-
- # Add thumbnail if enabled and image area is below threshold
- if use_thumbnail:
- # Calculate areas
- processed_width, processed_height = processed_img.size
- resized_area = processed_width * processed_height
- thumbnail_area = img_h * img_h # img_h should be square thumbnail size
- area_ratio = resized_area / thumbnail_area
-
- # Only add thumbnail if resized image area is less than threshold % of thumbnail area
- if area_ratio < self._thumbnail_area_threshold:
- thumbnail_img = img.resize((img_h, img_h)) # Use square thumbnail with img_h size
- processed_images.append(thumbnail_img)
-
- imgs = [transform(img) for img in processed_images]
- else:
- imgs = [self._transform(img)]
-
- return imgs
-
-
-# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702
-# Copyright (c) 2023 OpenGVLab.
-def dynamic_preprocess(
- image, min_num=1, max_num=6, image_size=448, use_thumbnail=False,
- find_closest_aspect_ratio_fn=find_closest_aspect_ratio):
- orig_width, orig_height = image.size
- aspect_ratio = orig_width / orig_height
-
- # calculate the existing image aspect ratio
- target_ratios = set(
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
- i * j <= max_num and i * j >= min_num)
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
- # find the closest aspect ratio to the target
- target_aspect_ratio = find_closest_aspect_ratio_fn(
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
-
- # calculate the target width and height
- target_width = image_size * target_aspect_ratio[0]
- target_height = image_size * target_aspect_ratio[1]
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
- # resize the image
- resized_img = image.resize((target_width, target_height))
- processed_images = []
- for i in range(blocks):
- box = (
- (i % (target_width // image_size)) * image_size,
- (i // (target_width // image_size)) * image_size,
- ((i % (target_width // image_size)) + 1) * image_size,
- ((i // (target_width // image_size)) + 1) * image_size
+ self._target_width = target_width
+ self._target_height = target_height
+ self._embeddings_per_image = embeddings_per_image
+ self._transform = self._build_transform(
+ (target_width, target_height), vision_model_type
+ )
+
+ # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
+ # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
+ @staticmethod
+ def _build_transform(target_size: tuple[int, int], vision_model_type: str):
+ """
+ Build a transform for a given vision model type and target size.
+ """
+ if vision_model_type in ("siglip", "internvit", "radio", "radio-g", "cradio-g"):
+ pixel_mean, pixel_std = pixel_statistics[vision_model_type]
+
+ transform = T.Compose(
+ [
+ T.Lambda(
+ lambda img: img.convert("RGB") if img.mode != "RGB" else img
+ ),
+ T.Resize(
+ (target_size[1], target_size[0]),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)),
+ T.Normalize(mean=pixel_mean, std=pixel_std),
+ ]
+ )
+ # From the official CLIP repo.
+ elif vision_model_type == "clip":
+ pixel_mean, pixel_std = pixel_statistics[vision_model_type]
+
+ transform = Compose(
+ [
+ T.Resize(
+ (target_size[1], target_size[0]),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ T.Lambda(
+ lambda img: img.convert("RGB") if img.mode != "RGB" else img
+ ),
+ T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)),
+ T.Normalize(mean=pixel_mean, std=pixel_std),
+ ]
+ )
+ elif vision_model_type.startswith("hf://"):
+ from megatron.core.models.huggingface.module import get_hf_model_type
+
+ model_type = get_hf_model_type(vision_model_type)
+ if "siglip" in model_type:
+ from transformers.models.siglip.image_processing_siglip import (
+ SiglipImageProcessor,
+ )
+
+ processor = SiglipImageProcessor(
+ size={"height": target_size[1], "width": target_size[0]}
+ )
+
+ def transform(x):
+ x = x.convert("RGB") if x.mode != "RGB" else x
+ x = processor(x, return_tensors="pt")
+ return x["pixel_values"][0]
+ else:
+ raise NotImplementedError(
+ f"image processing not defined for huggingface model {vision_model_type}"
+ )
+ else:
+ raise NotImplementedError(
+ f"image processing not defined for vision model {vision_model_type}"
+ )
+
+ return transform
+
+ def stack(
+ self, images: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
+ return (
+ torch.stack(images) if len(images) > 0 else None,
+ torch.tensor(
+ [(img.shape[1], img.shape[2]) for img in images], dtype=torch.int32
+ ) if len(images) > 0 else None,
+ None,
+ None,
)
- # split the image
- split_img = resized_img.crop(box)
- processed_images.append(split_img)
- assert len(processed_images) == blocks
- if use_thumbnail and len(processed_images) != 1:
- thumbnail_img = image.resize((image_size, image_size))
- processed_images.append(thumbnail_img)
- return processed_images
-def dynamic_res_preprocess(image, min_patches=1, max_patches=128, res_step=16, factor_max=1., pixel_shuffle=False, min_side=None, conv_merging=False, is_video=False):
+class NoTilingStrategy(_FixedSizeStrategy):
+ """
+ A simple image transformation that resizes the image to the target width and height.
+ """
+
+ def __init__(
+ self,
+ vision_model_type: str,
+ target_width: int,
+ target_height: int,
+ embeddings_per_image: int,
+ ):
+ super().__init__(
+ vision_model_type=vision_model_type,
+ target_width=target_width,
+ target_height=target_height,
+ embeddings_per_image=embeddings_per_image,
+ )
+
+ def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
+ return [self._transform(transform_media.media.value)]
+
+ def compute_params(
+ self,
+ media_list: list[ImageMedia | VideoFrameMedia],
+ num_tokens_available: Optional[int] = None,
+ max_num_tiles: int | None = None,
+ **kwargs,
+ ) -> list[ImageTilingParams]:
+ return [
+ ImageTilingParams(
+ media=media, num_tiles=1, num_embeddings=self._embeddings_per_image
+ )
+ for media in media_list
+ ]
+
+ def __str__(self):
+ return f"SimpleImageTransform(vision_model_type={self._vision_model_type}, num_tokens_per_image={self._embeddings_per_image})"
+
+
+@dataclass
+class ImageTilingParamsV1(ImageTilingParams):
+ tiling: tuple[int, int]
+
+
+class ImageTilingStrategyV1(_FixedSizeStrategy):
+ """Tiling image transformation.
+
+ This transformation splits the image into a grid of tiles and applies the transformation to each tile.
+ """
+
+ # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
+ # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
+
+ def __init__(
+ self,
+ vision_model_type: str,
+ tile_size: int,
+ use_thumbnail: bool,
+ min_num_tiles: int,
+ max_num_tiles: int,
+ embeddings_per_tile: int,
+ find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
+ ):
+ super().__init__(
+ vision_model_type=vision_model_type,
+ target_width=tile_size,
+ target_height=tile_size,
+ embeddings_per_image=embeddings_per_tile,
+ )
+
+ # print(f"Transformation params: {vision_model_type=}, {use_tiling=}, {tile_size=}, {use_thumbnail=}, {augment=}, {min_num_tiles=}, {max_num_tiles=}, {find_closest_aspect_ratio_fn=}")
+ self._tile_size = tile_size
+ self._use_thumbnail = use_thumbnail
+ self._min_num_tiles = min_num_tiles
+ self._max_num_tiles = max_num_tiles
+ self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
+
+ # Calculate all possible aspect ratios for each max_num_tiles.
+ self.target_ratios = {
+ max_num_tiles: sorted(
+ set(
+ (x, y)
+ for n in range(self._min_num_tiles, max_num_tiles + 1)
+ for x in range(1, n + 1)
+ for y in range(1, n + 1)
+ if x * y <= max_num_tiles and x * y >= self._min_num_tiles
+ ),
+ key=lambda x: x[0] * x[1],
+ )
+ for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1)
+ }
+
+ self.transform = A.Compose([
+ A.OneOf([
+ A.GaussNoise(var_limit=(5.0, 30.0)),
+ A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5)),
+ ], p=0.3),
+ A.OneOf([
+ A.MedianBlur(blur_limit=5),
+ A.GaussianBlur(blur_limit=5),
+ ], p=0.2),
+ A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.5),
+ A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=15, val_shift_limit=15, p=0.3),
+ A.ImageCompression(quality_lower=70, quality_upper=100, p=0.3),
+ ])
+
+ def apply_params(self, transform_media: ImageTilingParams, data_augment: bool = False, **kwargs) -> list[torch.Tensor]:
+ assert isinstance(transform_media, ImageTilingParamsV1)
+ image = transform_media.media.value
+
+ if data_augment:
+ image = self.transform(image=np.asarray(image))["image"]
+ image = Image.fromarray(image)
+
+ # calculate the target width and height
+ target_width = self._tile_size * transform_media.tiling[0]
+ target_height = self._tile_size * transform_media.tiling[1]
+ blocks = transform_media.tiling[0] * transform_media.tiling[1]
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // self._tile_size)) * self._tile_size,
+ (i // (target_width // self._tile_size)) * self._tile_size,
+ ((i % (target_width // self._tile_size)) + 1) * self._tile_size,
+ ((i // (target_width // self._tile_size)) + 1) * self._tile_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if self._use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((self._tile_size, self._tile_size))
+ processed_images.append(thumbnail_img)
+
+ return [self._transform(img) for img in processed_images]
+
+ def compute_params(
+ self,
+ media_list: list[ImageMedia | VideoFrameMedia],
+ num_tokens_available: Optional[int] = None,
+ max_num_tiles: int | None = None,
+ data_augment: bool = False,
+ tiling_augment_prob: float = 0.4,
+ **kwargs,
+ ) -> list[ImageTilingParamsV1]:
+ # Use provided max_num_tiles or fall back to instance's max_num_tiles
+ # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value
+ effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
+ effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)
+
+ max_num_tiles_to_use = min(
+ num_tokens_available // self._embeddings_per_image, effective_max_num_tiles
+ )
+
+ # calculate the existing image aspect ratio
+ target_ratios = self.target_ratios[max_num_tiles_to_use]
+
+ params = []
+ for media in media_list:
+ if isinstance(media, ImageMedia):
+ img_size = (media.width, media.height)
+ elif isinstance(media, VideoFrameMedia):
+ img_size = (media.video_width, media.video_height)
+ else:
+ raise ValueError(f"Unsupported media type: {type(media)}")
+
+ aspect_ratio = img_size[0] / img_size[1]
+
+ # find the closest aspect ratio to the target
+ tiling = self._find_closest_aspect_ratio_fn(
+ aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
+ )
+ if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob:
+ tiling = self.augment_tiling(tiling)
+ num_tiles = tiling[0] * tiling[1]
+ if self._use_thumbnail and num_tiles != 1:
+ num_tiles += 1
+
+ params.append(
+ ImageTilingParamsV1(
+ media=media,
+ num_tiles=num_tiles,
+ num_embeddings=num_tiles * self._embeddings_per_image,
+ tiling=tiling,
+ )
+ )
+
+ return params
+
+ def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]:
+ def num_tiles(tiling: tuple[int, int]) -> int:
+ return tiling[0] * tiling[1]
+
+ def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]:
+ if random.random() < minus_prob:
+ # Minus one
+ if tiling[0] == 1 and tiling[1] == 1:
+ return tiling
+ elif tiling[0] == 1:
+ return (tiling[0], tiling[1] - 1)
+ elif tiling[1] == 1:
+ return (tiling[0] - 1, tiling[1])
+ else:
+ if random.random() < 0.5:
+ return (tiling[0] - 1, tiling[1])
+ else:
+ return (tiling[0], tiling[1] - 1)
+ else:
+ # Plus one
+ if num_tiles(tiling) < self._max_num_tiles:
+ tiling0 = (tiling[0] + 1, tiling[1])
+ tiling1 = (tiling[0], tiling[1] + 1)
+ if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles:
+ return tiling
+ elif num_tiles(tiling0) > self._max_num_tiles:
+ return tiling1
+ elif num_tiles(tiling1) > self._max_num_tiles:
+ return tiling0
+ else:
+ if random.random() < 0.5:
+ return tiling0
+ else:
+ return tiling1
+ return tiling
+
+ new_tiling = plus_minus_one(tiling)
+ return new_tiling
+
+ def __str__(self):
+ return f"TilingImageTransform(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, embeddings_per_tile={self._embeddings_per_image}, find_closest_aspect_ratio_fn={self._find_closest_aspect_ratio_fn})"
+
+
+class TileDegradationStrategy(ImageTilingStrategy):
+ """Strategy for tiling images and video frames, each with their own tiling strategy, while trying to match the
+ number of tokens left in the sample by reducing the number of tiles if needed.
+ """
+
+ # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
+ # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
+
+ def __init__(
+ self,
+ image_strategy: ImageTilingStrategy,
+ video_frame_strategy: ImageTilingStrategy,
+ embeddings_per_tile: int,
+ max_num_tiles: int,
+ tile_degradation_map: dict[int, int] = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1},
+ ):
+ self._image_strategy = image_strategy
+ self._video_frame_strategy = video_frame_strategy
+ self._embeddings_per_tile = embeddings_per_tile
+ self._max_num_tiles = max_num_tiles
+ self._tile_degradation_map = tile_degradation_map
+
+ def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
+ if isinstance(transform_media.media, ImageMedia):
+ return self._image_strategy.apply_params(transform_media, **kwargs)
+ elif isinstance(transform_media.media, VideoFrameMedia):
+ return self._video_frame_strategy.apply_params(transform_media, **kwargs)
+ else:
+ raise ValueError(f"Unsupported media type: {type(transform_media.media)}")
+
+ def compute_params(
+ self,
+ media_list: list[ImageMedia | VideoFrameMedia],
+ num_tokens_available: int | None = None,
+ max_num_tiles: int | None = None,
+ **kwargs,
+ ) -> list[ImageTilingParams]:
+ # Use provided max_num_tiles or fall back to instance's max_num_tiles
+ effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
+ max_num_tiles_to_use = effective_max_num_tiles
+ degradation_map = self._tile_degradation_map
+
+ while True:
+ params = []
+ img_num_tiles = []
+ for media in media_list:
+ if isinstance(media, ImageMedia):
+ media_params = self._image_strategy.compute_params(
+ [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs
+ )[0]
+ elif isinstance(media, VideoFrameMedia):
+ max_num_tiles_to_use = 1
+ media_params = self._video_frame_strategy.compute_params(
+ [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs
+ )[0]
+ else:
+ raise ValueError(f"Unsupported media type: {type(media)}")
+ img_num_tiles.append(media_params.num_tiles)
+ params.append(media_params)
+ if max_num_tiles_to_use == 1 or num_tokens_available is None:
+ break
+ if sum(img_num_tiles) * self._embeddings_per_tile > num_tokens_available:
+ if max_num_tiles_to_use in degradation_map:
+ max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
+ else:
+ # End of degradation
+ break
+ else:
+ break
+ return params
+
+ def stack(
+ self, images: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
+ return self._image_strategy.stack(images)
+
+ def __str__(self):
+ return f"TileDegradationImageTransform(max_num_tiles={self._max_num_tiles}, image_transform={self._image_strategy}, video_frame_transform={self._video_frame_strategy})"
+
+
+@dataclass
+class DynamicResolutionParams(ImageTilingParams):
+ patch_size: tuple[int, int]
+
+
+class DynamicResolutionImageTilingStrategy(ImageTilingStrategy):
"""Preprocess an image with dynamic resolution for vision transformers.
-
+
This function resizes an image to optimize the number of patches while respecting
constraints on minimum/maximum patches, minimum side length, and compatibility
with pixel shuffle or convolution merging operations.
-
+
The algorithm works by:
1. Computing the initial patch grid size based on the image dimensions and res_step
2. Scaling the patch grid to fit within the max_patches constraint
@@ -350,159 +615,1214 @@ def dynamic_res_preprocess(image, min_patches=1, max_patches=128, res_step=16, f
4. Optionally enforcing a minimum side length constraint
5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility
6. Resizing the image to the computed target dimensions
-
- Args:
- image (PIL.Image): Input image to preprocess.
- min_patches (int, optional): Minimum number of patches required. Defaults to 1.
- max_patches (int, optional): Maximum number of patches allowed. Defaults to 128.
- res_step (int, optional): Resolution step size (patch dimension). Defaults to 16.
- factor_max (float, optional): Maximum scaling factor to apply. Defaults to 1.0.
- pixel_shuffle (bool, optional): Whether to ensure compatibility with pixel shuffle
- operations by rounding to even patch dimensions. Defaults to False.
- min_side (int, optional): Minimum side length in pixels. If specified, ensures
- at least one side meets this constraint. Defaults to None.
- conv_merging (bool, optional): Whether to ensure compatibility with convolution
- merging by rounding to even patch dimensions. Defaults to False.
-
- Returns:
- PIL.Image: Resized image with dimensions optimized for patch-based processing.
- The output dimensions will be (target_patch_width * res_step, target_patch_height * res_step).
-
+
Note:
The function preserves aspect ratio as much as possible while satisfying all constraints.
When constraints conflict (e.g., min_side vs max_patches), the function prioritizes
staying within max_patches while maximizing the image size.
-
+
Example:
>>> from PIL import Image
>>> img = Image.open("example.jpg") # 800x600 image
- >>> resized_img = dynamic_res_preprocess(img, min_patches=4, max_patches=64, res_step=14)
+    >>> strategy = DynamicResolutionImageTilingStrategy(vision_model_type="radio", min_num_patches=4, max_num_patches=64, patch_size=14, get_num_embeddings=lambda x, y: x * y * 2)
+ >>> params = strategy.compute_params([img])
+ >>> img_tensor = strategy.apply_params(params[0])
>>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14
"""
- orig_width, orig_height = image.size
- closest_patch_height = round(orig_height / res_step + 0.5)
- closest_patch_width = round(orig_width / res_step + 0.5)
- patches = closest_patch_height * closest_patch_width
+ def __init__(
+ self,
+ vision_model_type: str,
+ min_num_patches: int,
+ patch_size: int,
+ get_num_embeddings: Callable[[int, int], int],
+ factor_max: float = 1.0,
+ pixel_shuffle: bool = False,
+ min_side: int | None = None,
+ conv_merging: bool = False,
+ use_thumbnail: bool = False,
+ thumbnail_size: int = 448,
+ thumbnail_area_threshold: float = 0.8,
+ max_num_patches: int = 0,
+ apply_data_augment: bool = False,
+ ):
+ """
+ Args:
+ vision_model_type: Vision model type.
+            min_num_patches: Minimum number of patches required (no default; must be provided).
+ max_num_patches: Maximum number of patches allowed. Defaults to 0 (no maximum).
+            patch_size: Resolution step size in pixels (patch dimension; no default, must be provided).
+ get_num_embeddings: Function to get the number of embeddings from the patch size (width, height).
+ factor_max: Maximum scaling factor to apply. Defaults to 1.0.
+ pixel_shuffle: Whether to ensure compatibility with pixel shuffle operations by rounding to even patch
+ dimensions. Defaults to False.
+ min_side: Minimum side length in pixels. If specified, ensures at least one side meets this constraint.
+ Defaults to None.
+ conv_merging: Whether to ensure compatibility with convolution merging by rounding to even patch dimensions.
+ Defaults to False.
+ use_thumbnail: Whether to add a thumbnail image when processing. Defaults to False.
+ thumbnail_size: Size of the thumbnail image (width and height). Defaults to 448.
+ thumbnail_area_threshold: Maximum area percentage (0.0-1.0) of the resized image relative to thumbnail area
+ for which to add a thumbnail. If the resized image area is larger than this threshold of the thumbnail
+ area, no thumbnail will be added. Defaults to 0.8 (80%).
+ apply_data_augment: Whether to apply data augmentation to the image. Defaults to False.
+ """
+ assert "radio" in vision_model_type, (
+ "Dynamic resolution is only supported for radio models"
+ )
+ self._vision_model_type = vision_model_type
+ self._min_num_patches = min_num_patches
+ self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
+ self._patch_size = patch_size
+ self._get_num_embeddings = get_num_embeddings
+ self._factor_max = factor_max
+ self._pixel_shuffle = pixel_shuffle
+ self._min_side = min_side
+ self._conv_merging = conv_merging
+ self._use_thumbnail = use_thumbnail
+ self._thumbnail_size = thumbnail_size
+ self._thumbnail_area_threshold = thumbnail_area_threshold
+ pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
+ self._transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)),
+ T.Normalize(mean=pixel_mean, std=pixel_std),
+ ]
+ )
+ self._apply_data_augment = apply_data_augment
- factor = min(math.sqrt(max_patches / patches), factor_max)
- target_patch_height = math.floor(factor * closest_patch_height)
- target_patch_width = math.floor(factor * closest_patch_width)
+ def apply_params(self, params: DynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
+ # resize the image
+ resized_img = params.media.value.resize(
+ (
+ params.patch_size[0] * self._patch_size,
+ params.patch_size[1] * self._patch_size,
+ )
+ )
+ processed_images = [resized_img]
+
+ # Add thumbnail if enabled and image area is below threshold
+ if self._use_thumbnail:
+ # Calculate areas
+ resized_area = resized_img.size[0] * resized_img.size[1]
+ thumbnail_area = self._thumbnail_size * self._thumbnail_size
+ area_ratio = resized_area / thumbnail_area
+
+ # Only add thumbnail if resized image area is less than threshold % of thumbnail area
+ if area_ratio < self._thumbnail_area_threshold:
+ thumbnail_img = params.media.value.resize((self._thumbnail_size, self._thumbnail_size))
+ processed_images.append(thumbnail_img)
+
+ return [self._transform(img) for img in processed_images]
- if target_patch_height * target_patch_width < min_patches:
- up_factor = math.sqrt(min_patches / (target_patch_height * target_patch_width))
- target_patch_height = math.ceil(up_factor * target_patch_height)
- target_patch_width = math.ceil(up_factor * target_patch_width)
-
- if min_side is not None and min(target_patch_width, target_patch_height) * res_step < min_side:
- if target_patch_width <= target_patch_height:
- up_factor = min_side / (target_patch_width * res_step)
- new_patch_height = math.ceil(up_factor * target_patch_height)
- new_patch_width = math.ceil(up_factor * target_patch_width)
-
- if new_patch_height * new_patch_width > max_patches:
- # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
- if max(max_patches // new_patch_width, 1) * res_step < min_side:
- up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
- target_patch_height = math.floor(up_factor * target_patch_height)
- target_patch_width = math.floor(up_factor * target_patch_width)
- target_patch_width = new_patch_width
- target_patch_height = max(max_patches // new_patch_width, 1)
- else:
- target_patch_height = new_patch_height
- target_patch_width = new_patch_width
+ def process_media(
+ self,
+ media: ImageMedia | VideoFrameMedia,
+ num_tokens_available: int,
+ data_augment: bool = False,
+ tiling_augment_prob: float = 0.4,
+ ) -> DynamicResolutionParams:
+ """Process a single media item and return its parameters.
+
+ Args:
+ media: The media item to process
+ num_tokens_available: Number of tokens available for this media
+ data_augment: Whether to apply data augmentation to the image. Defaults to False.
+ Returns:
+ DynamicResolutionParams for the media
+ """
+ current_num_tokens_available = num_tokens_available
+ if isinstance(media, ImageMedia):
+ orig_width, orig_height = media.width, media.height
+ elif isinstance(media, VideoFrameMedia):
+ orig_width, orig_height = media.video_width, media.video_height
+ # current_num_tokens_available = 1024 #TEMP: hack for video
else:
- up_factor = min_side / (target_patch_height * res_step)
- new_patch_height = math.ceil(up_factor * target_patch_height)
- new_patch_width = math.ceil(up_factor * target_patch_width)
+ raise ValueError(f"Unsupported media type: {type(media)}")
- if new_patch_height * new_patch_width > max_patches:
- # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
- if max(max_patches // new_patch_height, 1) * res_step < min_side:
- up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
- target_patch_height = math.floor(up_factor * target_patch_height)
- target_patch_width = math.floor(up_factor * target_patch_width)
+ closest_patch_height = round(orig_height / self._patch_size + 0.5)
+ closest_patch_width = round(orig_width / self._patch_size + 0.5)
+ patches = closest_patch_height * closest_patch_width
+
+ factor = min(math.sqrt(current_num_tokens_available / patches), self._factor_max)
+ target_patch_height = math.floor(factor * closest_patch_height)
+ target_patch_width = math.floor(factor * closest_patch_width)
+
+ # We only consider self._min_num_patches if it is greater than current_num_tokens_available.
+ if current_num_tokens_available > self._min_num_patches and target_patch_height * target_patch_width < self._min_num_patches:
+ up_factor = math.sqrt(
+ self._min_num_patches / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.ceil(up_factor * target_patch_height)
+ target_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if (
+ self._min_side is not None
+ and min(target_patch_width, target_patch_height) * self._patch_size
+ < self._min_side
+ ):
+ if target_patch_width <= target_patch_height:
+ up_factor = self._min_side / (target_patch_width * self._patch_size)
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if new_patch_height * new_patch_width > current_num_tokens_available:
+ # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
+ if (
+ max(current_num_tokens_available // new_patch_width, 1)
+ * self._patch_size
+ < self._min_side
+ ):
+ up_factor = math.sqrt(
+ current_num_tokens_available
+ / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.floor(
+ up_factor * target_patch_height
+ )
+ target_patch_width = math.floor(
+ up_factor * target_patch_width
+ )
+ target_patch_width = new_patch_width
+ target_patch_height = max(
+ current_num_tokens_available // new_patch_width, 1
+ )
else:
target_patch_height = new_patch_height
- target_patch_width = max(max_patches // new_patch_height, 1)
+ target_patch_width = new_patch_width
else:
- target_patch_height = new_patch_height
- target_patch_width = new_patch_width
+ up_factor = self._min_side / (
+ target_patch_height * self._patch_size
+ )
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
- # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
- # or by 4 when BOTH are enabled (two successive 2x reductions)
- if pixel_shuffle or conv_merging:
- required_divisor = 4 if (pixel_shuffle and conv_merging) else 2
+ if new_patch_height * new_patch_width > current_num_tokens_available:
+ # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
+ if (
+ max(current_num_tokens_available // new_patch_height, 1)
+ * self._patch_size
+ < self._min_side
+ ):
+ up_factor = math.sqrt(
+ current_num_tokens_available
+ / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.floor(
+ up_factor * target_patch_height
+ )
+ target_patch_width = math.floor(
+ up_factor * target_patch_width
+ )
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = max(
+ current_num_tokens_available // new_patch_height, 1
+ )
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = new_patch_width
- rem_h = target_patch_height % required_divisor
- if rem_h != 0:
- inc_h = required_divisor - rem_h
- if (target_patch_height + inc_h) * target_patch_width <= max_patches:
- target_patch_height += inc_h
- else:
- target_patch_height = max(1, target_patch_height - rem_h)
+ # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
+ # or by 4 when BOTH are enabled (two successive 2x reductions)
+ if self._pixel_shuffle or self._conv_merging:
+ required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
- rem_w = target_patch_width % required_divisor
- if rem_w != 0:
- inc_w = required_divisor - rem_w
- if target_patch_height * (target_patch_width + inc_w) <= max_patches:
- target_patch_width += inc_w
- else:
- target_patch_width = max(1, target_patch_width - rem_w)
- assert target_patch_height * target_patch_width <= max_patches
+ rem_h = target_patch_height % required_divisor
+ if rem_h != 0:
+ inc_h = required_divisor - rem_h
+ if (target_patch_height + inc_h) * target_patch_width <= current_num_tokens_available:
+ target_patch_height += inc_h
+ else:
+ target_patch_height = max(required_divisor, target_patch_height - rem_h)
- #TEMP: hacky way to process video same as in training
- if is_video:
- # max_patches = 1024
- # min_patches = 512
- target_patch_width = 32
- target_patch_height = 32
+ rem_w = target_patch_width % required_divisor
+ if rem_w != 0:
+ inc_w = required_divisor - rem_w
+ if target_patch_height * (target_patch_width + inc_w) <= current_num_tokens_available:
+ target_patch_width += inc_w
+ else:
+ target_patch_width = max(required_divisor, target_patch_width - rem_w)
+
+ if data_augment and self._apply_data_augment and random.random() < tiling_augment_prob:
+ target_patch_width, target_patch_height = self.augment_resolution(target_patch_width, target_patch_height, current_num_tokens_available)
+
+ #TEMP: hack for video
+ if isinstance(media, VideoFrameMedia):
+ target_patch_width = 32
+ target_patch_height = 32
+
+ # Calculate embeddings for the main dynamic resolution image
+ num_embeddings = self._get_num_embeddings(
+ target_patch_width * self._patch_size,
+ target_patch_height * self._patch_size,
+ )
+
+ token_count = target_patch_width * target_patch_height
+ # Add thumbnail embeddings if enabled and image area is below threshold
+ num_tiles = 1 # Base dynamic resolution image
+ if self._use_thumbnail:
+ # Calculate areas
+ resized_area = (target_patch_width * self._patch_size) * (target_patch_height * self._patch_size)
+ thumbnail_area = self._thumbnail_size * self._thumbnail_size
+ area_ratio = resized_area / thumbnail_area
+
+ # Only add thumbnail if resized image area is less than threshold % of thumbnail area
+ if area_ratio < self._thumbnail_area_threshold:
+ num_tiles += 1 # Add 1 for thumbnail
+ # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
+ num_embeddings += self._get_num_embeddings(self._thumbnail_size, self._thumbnail_size)
+ token_count += self._thumbnail_size // self._patch_size * self._thumbnail_size // self._patch_size
- # resize the image
- resized_img = image.resize((target_patch_width * res_step, target_patch_height * res_step))
+ return DynamicResolutionParams(
+ media=media,
+ num_tiles=num_tiles,
+ num_embeddings=num_embeddings,
+ patch_size=(target_patch_width, target_patch_height),
+ ), token_count
+
+ def augment_resolution(self, target_patch_width: int, target_patch_height: int, current_num_tokens_available: int) -> tuple[int, int]:
- return resized_img
+ min_num_patch_one_side = 32
-
-
-# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
-# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
-def _build_transform(input_size, vision_model_type):
- if vision_model_type in ("siglip", "internvit", "internvit300M", "radio", "radio-g", "cradio-g"):
- pixel_mean, pixel_std = pixel_statistics[vision_model_type]
-
- transform = T.Compose([
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
- T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
- T.ToTensor(),
- T.Normalize(mean=pixel_mean, std=pixel_std)
- ])
- elif vision_model_type == "clip":
- pixel_mean, pixel_std = pixel_statistics[vision_model_type]
-
- transform = Compose([
- T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
- T.ToTensor(),
- T.Normalize(mean=pixel_mean, std=pixel_std),
- ])
- elif vision_model_type.startswith("hf://"):
- from megatron.core.models.huggingface.module import get_hf_model_type
-
- model_type = get_hf_model_type(vision_model_type)
- if "siglip" in model_type:
- from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
-
- processor = SiglipImageProcessor(size={"height": input_size, "width": input_size})
-
- def transform(x):
- x = x.convert("RGB") if x.mode != "RGB" else x
- x = processor(x, return_tensors="pt")
- return x["pixel_values"][0]
+ if random.random() < 0.5:
+ # Minus one
+ if target_patch_width <= min_num_patch_one_side and target_patch_height <= min_num_patch_one_side:
+ return target_patch_width, target_patch_height
+ elif target_patch_width <= min_num_patch_one_side:
+ return target_patch_width, target_patch_height - min_num_patch_one_side
+ elif target_patch_height <= min_num_patch_one_side:
+ return target_patch_width - min_num_patch_one_side, target_patch_height
+ else:
+ if random.random() < 0.5:
+ return target_patch_width - min_num_patch_one_side, target_patch_height
+ else:
+ return target_patch_width, target_patch_height - min_num_patch_one_side
else:
- raise NotImplementedError(f"image processing not defined for huggingface model {vision_model_type}")
- else:
- raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}")
+ # Plus one
+ if target_patch_width * target_patch_height < current_num_tokens_available:
+ if random.random() < 0.5:
+ return target_patch_width + min_num_patch_one_side, target_patch_height
+ else:
+ return target_patch_width, target_patch_height + min_num_patch_one_side
+ return target_patch_width, target_patch_height
- return transform
    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        data_augment: bool = False,
        **kwargs,
    ) -> list[ImageTilingParams]:
        """Compute parameters for all media with iterative token budgeting.

        Each media item first gets a per-item token budget clipped to
        [min_num_patches, max_num_patches]; if the resulting total exceeds the
        overall budget, per-item budgets are scaled down proportionally and the
        process repeats (at most 10 rounds) until the total fits or no further
        reduction is possible.

        Args:
            media_list: List of media items to process
            num_tokens_available: Total number of tokens available across all media.
                NOTE(review): despite the `| None` annotation, passing None raises a
                TypeError on the first multiplication below — callers presumably
                always supply a value; confirm.
            max_num_tiles: Maximum number of tiles (unused in this implementation)
            data_augment: Whether to apply data augmentation to the image. Defaults to False.
        Returns:
            List of ImageTilingParams for each media item
        """
        # Pixel-shuffle and conv-merging each reduce tokens 4x downstream, so the
        # pre-reduction patch budget is correspondingly larger.
        num_tokens_available = num_tokens_available * (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1)
        # When the number of available token is too small, allow self._min_num_patches per media and
        # let the sample be truncated.
        num_tokens_available = max(num_tokens_available, self._min_num_patches * len(media_list))

        # Clip the number of tokens available per media to be between min and max patches.
        num_tokens_available_per_media = [
            max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
            for _ in range(len(media_list))]

        # In theory this could be a while True loop, but in case the process_media method slightly
        # changes, I want to make sure we don't get stuck in an infinite loop.
        for _ in range(10):
            # Step 1: Process each media with current token budget
            params = []
            token_counts = []

            for media, tokens_for_media in zip(media_list, num_tokens_available_per_media):
                param, token_count = self.process_media(media, tokens_for_media, data_augment=data_augment)
                params.append(param)
                token_counts.append(token_count)

            # Step 2: Check if total tokens is within budget
            total_tokens = sum(token_counts)

            if total_tokens <= num_tokens_available:
                # We're within budget, return the params
                return params

            # Step 3: We're over budget, need to scale down
            # Calculate scaling factor to get under budget
            scaling_factor = num_tokens_available / total_tokens

            # Recalculate token budgets for each media based on scaling
            # Each media gets a proportional share of the total budget
            scaled_down_num_tokens_available_per_media = [
                max(self._min_num_patches, int(token_count * scaling_factor))
                for token_count in token_counts
            ]
            # Progress check: budgets are compared against the previous round's
            # budgets, not against the actual token counts.
            scaled_down = any([
                scaled_down_num_tokens_available_per_media[i] < num_tokens_available_per_media[i]
                for i in range(len(num_tokens_available_per_media))])
            # If there was not scaling down, we're stuck just use min_num_patches per media, else
            # try with the scaled down num_tokens_available_per_media.
            if not scaled_down:
                num_tokens_available_per_media = [self._min_num_patches] * len(media_list)
            else:
                num_tokens_available_per_media = scaled_down_num_tokens_available_per_media
        # Fell through all rounds: return the last attempt, possibly over budget.
        return params
+
+ def stack(
+ self, images: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
+ imgs_sizes = torch.tensor(
+ [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
+ )
+
+ def rearrange_img(x):
+ py = x.shape[-2] // self._patch_size
+ px = x.shape[-1] // self._patch_size
+ x = einops.rearrange(
+ x,
+ "c (py yy) (px xx) -> (py px) (c yy xx)",
+ py=py,
+ yy=self._patch_size,
+ px=px,
+ xx=self._patch_size,
+ )
+ return x
+
+ if len(images) > 0:
+ imgs = [rearrange_img(img) for img in images]
+
+ current_length = 0
+ max_length = 0
+ vision_cu_lengths = [0]
+ for img in imgs:
+ if max_length < img.shape[0]:
+ max_length = img.shape[0]
+ current_length += img.shape[0]
+ vision_cu_lengths.append(current_length)
+
+ vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
+ vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
+
+ return (
+ torch.cat(imgs, dim=0).unsqueeze(0),
+ imgs_sizes,
+ vision_cu_lengths,
+ vision_max_lengths,
+ )
+ else:
+ return (
+ torch.tensor([[0]], dtype=torch.float32),
+ torch.tensor([[0,0]], dtype=torch.int32),
+ None,
+ None,
+ )
+
+ def __str__(self):
+ return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})"
+
+
@dataclass
class MatchTilingDynamicResolutionParams(ImageTilingParams):
    """Per-media parameters produced by MatchTilingDynamicResolutionStrategy."""

    # (tiles_x, tiles_y) grid selected by the aspect-ratio matcher; the image is
    # resized to tile_size * tiles_x by tile_size * tiles_y.
    tiling: tuple[int, int]
+
+
class MatchTilingDynamicResolutionStrategy(ImageTilingStrategy):
    """
    Strategy that uses tiling logic to determine optimal image dimensions but processes
    the image as a single dynamic resolution image instead of splitting into tiles.

    This combines the aspect ratio optimization from ImageTilingStrategyV1 with the
    dynamic resolution processing from DynamicResolutionImageTilingStrategy.

    Also includes tile degradation logic similar to TileDegradationStrategy.
    """

    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        patch_size: int,
        get_num_embeddings: Callable[[int, int], int],
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
        pixel_shuffle: bool = False,
        conv_merging: bool = False,
        tile_degradation_map: dict[int, int] | None = None,
        video_frame_strategy: ImageTilingStrategy | None = None,
        enable_tile_degradation: bool = True,
    ):
        """
        Args:
            vision_model_type: Vision model type (should support dynamic resolution)
            tile_size: Size of each tile for tiling calculation
            use_thumbnail: Whether tiling logic should include thumbnail
            min_num_tiles: Minimum number of tiles for tiling calculation
            max_num_tiles: Maximum number of tiles for tiling calculation
            embeddings_per_tile: Embeddings per tile for tiling calculation
            patch_size: Patch size for dynamic resolution processing
            get_num_embeddings: Function to get number of embeddings from dimensions
            find_closest_aspect_ratio_fn: Function to find closest aspect ratio
            pixel_shuffle: Whether to ensure compatibility with pixel shuffle
            conv_merging: Whether to ensure compatibility with convolution merging
            tile_degradation_map: Map for degrading tiles when tokens are insufficient
            video_frame_strategy: Strategy for processing video frames
            enable_tile_degradation: Whether to enable tile degradation (default: True)
        """
        assert "radio" in vision_model_type, (
            "MatchTilingDynamicResolution is only supported for radio models"
        )

        self._vision_model_type = vision_model_type
        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._embeddings_per_tile = embeddings_per_tile
        self._patch_size = patch_size
        self._get_num_embeddings = get_num_embeddings
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
        self._pixel_shuffle = pixel_shuffle
        self._conv_merging = conv_merging
        self._enable_tile_degradation = enable_tile_degradation

        # Tile degradation logic (similar to TileDegradationStrategy)
        if tile_degradation_map is None:
            self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
        else:
            self._tile_degradation_map = tile_degradation_map

        # Video frame strategy (similar to TileDegradationStrategy)
        if video_frame_strategy is None:
            self._video_frame_strategy = NoTilingStrategy(
                vision_model_type=vision_model_type,
                target_width=tile_size,
                target_height=tile_size,
                embeddings_per_image=embeddings_per_tile,
            )
        else:
            self._video_frame_strategy = video_frame_strategy

        # Pre-compute all candidate tile grids for each allowed tile budget
        # (borrowed from ImageTilingStrategyV1).
        self.target_ratios = {
            n: self._build_target_ratios(n)
            for n in range(self._min_num_tiles, self._max_num_tiles + 1)
        }

        # Set up transform for dynamic resolution processing
        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ]
        )

    def _build_target_ratios(self, max_num_tiles: int) -> list[tuple[int, int]]:
        """Enumerate all (cols, rows) grids with min_num_tiles <= cols*rows <= max_num_tiles,
        sorted by total tile count. Shared by __init__ and compute_params (previously duplicated)."""
        return sorted(
            set(
                (x, y)
                for n in range(self._min_num_tiles, max_num_tiles + 1)
                for x in range(1, n + 1)
                for y in range(1, n + 1)
                if x * y <= max_num_tiles and x * y >= self._min_num_tiles
            ),
            key=lambda ratio: ratio[0] * ratio[1],
        )

    def apply_params(self, params: MatchTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
        """Resize the media per its tiling and return normalized tensors (plus optional thumbnail)."""
        # Handle video frames using the video frame strategy
        if isinstance(params.media, VideoFrameMedia):
            return self._video_frame_strategy.apply_params(params, **kwargs)

        # Handle images with dynamic resolution processing
        image = params.media.value
        # Calculate the target width and height (same logic as ImageTilingStrategyV1)
        target_width = self._tile_size * params.tiling[0]
        target_height = self._tile_size * params.tiling[1]

        # Resize the image to the target dimensions (same as ImageTilingStrategyV1)
        resized_img = image.resize((target_width, target_height))

        # Process as single dynamic resolution image
        processed_images = [resized_img]

        # Add thumbnail if use_thumbnail=True and there's more than 1 tile (same as ImageTilingStrategyV1)
        blocks = params.tiling[0] * params.tiling[1]
        if self._use_thumbnail and blocks != 1:
            thumbnail_img = image.resize((self._tile_size, self._tile_size))
            processed_images.append(thumbnail_img)

        return [self._transform(img) for img in processed_images]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        data_augment: bool = False,
        **kwargs,
    ) -> list[MatchTilingDynamicResolutionParams]:
        """Choose a tile grid per media, degrading the tile budget until the
        total embedding count fits num_tokens_available (when degradation is enabled)."""
        # Use provided max_num_tiles or fall back to instance's max_num_tiles
        # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)
        max_num_tiles_to_use = effective_max_num_tiles
        degradation_map = self._tile_degradation_map

        while True:
            params = []
            total_embeddings_needed = 0

            for media in media_list:
                if isinstance(media, ImageMedia):
                    # Use tiling logic for images
                    img_size = (media.width, media.height)
                    aspect_ratio = img_size[0] / img_size[1]

                    # Find the closest aspect ratio to the target
                    target_ratios = self.target_ratios[max_num_tiles_to_use]
                    tiling = self._find_closest_aspect_ratio_fn(
                        aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
                    )

                    # Calculate target dimensions for dynamic resolution processing
                    target_width = self._tile_size * tiling[0]
                    target_height = self._tile_size * tiling[1]
                    num_embeddings = self._get_num_embeddings(target_width, target_height)

                    # Account for thumbnail (same logic as ImageTilingStrategyV1)
                    num_tiles = 1  # Base dynamic resolution image
                    blocks = tiling[0] * tiling[1]
                    if self._use_thumbnail and blocks != 1:
                        num_tiles += 1  # Add 1 for thumbnail
                        # Add embeddings for thumbnail (tile_size x tile_size)
                        num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size)

                    media_params = MatchTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=num_tiles,
                        num_embeddings=num_embeddings,
                        tiling=tiling,
                    )
                elif isinstance(media, VideoFrameMedia):
                    # Use video frame strategy for video frames (always 1 tile)
                    video_params = self._video_frame_strategy.compute_params(
                        [media], 1 * self._embeddings_per_tile
                    )[0]
                    media_params = MatchTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=video_params.num_tiles,
                        num_embeddings=video_params.num_embeddings,
                        tiling=(1, 1),  # Video frames always use 1x1 tiling
                    )
                else:
                    raise ValueError(f"Unsupported media type: {type(media)}")

                params.append(media_params)
                total_embeddings_needed += media_params.num_embeddings

            # Check if we need to degrade (only if degradation is enabled)
            if not self._enable_tile_degradation:
                break
            if max_num_tiles_to_use == 1 or num_tokens_available is None:
                break
            if total_embeddings_needed > num_tokens_available:
                if max_num_tiles_to_use in degradation_map:
                    max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
                    # Cache target ratios for the degraded budget on first use.
                    if max_num_tiles_to_use not in self.target_ratios:
                        self.target_ratios[max_num_tiles_to_use] = self._build_target_ratios(
                            max_num_tiles_to_use
                        )
                else:
                    # End of degradation
                    break
            else:
                break

        return params

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
        """Stack images using dynamic resolution approach with sequence packing"""
        imgs_sizes = torch.tensor(
            [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
        )

        def rearrange_img(x):
            # Equivalent of einops "c (py yy) (px xx) -> (py px) (c yy xx)".
            # Uses plain torch ops (the file imports `rearrange`, not `einops`).
            c = x.shape[0]
            py = x.shape[-2] // self._patch_size
            px = x.shape[-1] // self._patch_size
            x = x.reshape(c, py, self._patch_size, px, self._patch_size)
            x = x.permute(1, 3, 0, 2, 4)
            return x.reshape(py * px, c * self._patch_size * self._patch_size)

        if len(images) > 0:
            imgs = [rearrange_img(img) for img in images]

            current_length = 0
            max_length = 0
            vision_cu_lengths = [0]
            for img in imgs:
                if max_length < img.shape[0]:
                    max_length = img.shape[0]
                current_length += img.shape[0]
                vision_cu_lengths.append(current_length)

            vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
            vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)

            return (
                torch.cat(imgs, dim=0).unsqueeze(0),
                imgs_sizes,
                vision_cu_lengths,
                vision_max_lengths,
            )
        else:
            # Placeholder outputs for the empty-input case.
            return (
                torch.tensor([[0]], dtype=torch.float32),
                torch.tensor([[0, 0]], dtype=torch.int32),
                None,
                None,
            )

    def __str__(self):
        return f"MatchTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})"
+
+
@dataclass
class MaskedTilingDynamicResolutionParams(ImageTilingParams):
    """Per-media parameters produced by MaskedTilingDynamicResolutionStrategy."""

    # (tiles_x, tiles_y) grid selected by the aspect-ratio matcher; each tile is
    # emitted as an independent tile_size x tile_size image.
    tiling: tuple[int, int]
+
+
class MaskedTilingDynamicResolutionStrategy(ImageTilingStrategy):
    """
    Like MatchTilingDynamicResolutionStrategy, but ensures tiles are isolated in the
    vision encoder by emitting per-tile packed samples (block-diagonal attention across tiles).
    """

    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        patch_size: int,
        get_num_embeddings: Callable[[int, int], int],
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
        pixel_shuffle: bool = False,
        conv_merging: bool = False,
        tile_degradation_map: dict[int, int] | None = None,
        video_frame_strategy: ImageTilingStrategy | None = None,
        enable_tile_degradation: bool = True,
    ):
        """
        Args:
            vision_model_type: Vision model type (must contain "radio")
            tile_size: Size of each emitted tile in pixels
            use_thumbnail: Whether to append a thumbnail tile for multi-tile images
            min_num_tiles: Minimum number of tiles for tiling calculation
            max_num_tiles: Maximum number of tiles for tiling calculation
            embeddings_per_tile: Embeddings per tile for the video-frame fallback
            patch_size: Patch size used when packing tiles in stack()
            get_num_embeddings: Function mapping (width, height) -> embedding count
            find_closest_aspect_ratio_fn: Function to find closest aspect ratio
            pixel_shuffle: Whether pixel shuffle is enabled downstream
            conv_merging: Whether convolution merging is enabled downstream
            tile_degradation_map: Map for degrading tiles when tokens are insufficient
            video_frame_strategy: Strategy for processing video frames
            enable_tile_degradation: Whether to enable tile degradation (default: True)
        """
        assert "radio" in vision_model_type, (
            "MaskedTilingDynamicResolution is only supported for radio models"
        )

        self._vision_model_type = vision_model_type
        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._embeddings_per_tile = embeddings_per_tile
        self._patch_size = patch_size
        self._get_num_embeddings = get_num_embeddings
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
        self._pixel_shuffle = pixel_shuffle
        self._conv_merging = conv_merging
        self._enable_tile_degradation = enable_tile_degradation

        if tile_degradation_map is None:
            self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
        else:
            self._tile_degradation_map = tile_degradation_map

        if video_frame_strategy is None:
            self._video_frame_strategy = NoTilingStrategy(
                vision_model_type=vision_model_type,
                target_width=tile_size,
                target_height=tile_size,
                embeddings_per_image=embeddings_per_tile,
            )
        else:
            self._video_frame_strategy = video_frame_strategy

        # Pre-compute all candidate tile grids for each allowed tile budget.
        self.target_ratios = {
            n: self._build_target_ratios(n)
            for n in range(self._min_num_tiles, self._max_num_tiles + 1)
        }

        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ]
        )

    def _build_target_ratios(self, max_num_tiles: int) -> list[tuple[int, int]]:
        """Enumerate all (cols, rows) grids with min_num_tiles <= cols*rows <= max_num_tiles,
        sorted by total tile count. Shared by __init__ and compute_params (previously duplicated)."""
        return sorted(
            set(
                (x, y)
                for n in range(self._min_num_tiles, max_num_tiles + 1)
                for x in range(1, n + 1)
                for y in range(1, n + 1)
                if x * y <= max_num_tiles and x * y >= self._min_num_tiles
            ),
            key=lambda ratio: ratio[0] * ratio[1],
        )

    def apply_params(self, params: MaskedTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
        """Resize the image per its tiling, crop it into individual tiles, and
        return each tile (plus optional thumbnail) as a normalized tensor."""
        # Handle video frames using the video frame strategy
        if isinstance(params.media, VideoFrameMedia):
            return self._video_frame_strategy.apply_params(params, **kwargs)

        image = params.media.value
        nx, ny = params.tiling
        target_width = self._tile_size * nx
        target_height = self._tile_size * ny

        resized_img = image.resize((target_width, target_height))

        processed_images = []
        # Emit per-tile images (each becomes an isolated packed sample later)
        for j in range(ny):
            for i in range(nx):
                box = (
                    i * self._tile_size,
                    j * self._tile_size,
                    (i + 1) * self._tile_size,
                    (j + 1) * self._tile_size,
                )
                tile_img = resized_img.crop(box)
                processed_images.append(tile_img)

        if self._use_thumbnail and (nx * ny) != 1:
            thumbnail_img = image.resize((self._tile_size, self._tile_size))
            processed_images.append(thumbnail_img)

        return [self._transform(img) for img in processed_images]

    def compute_params(
        self,
        media_list: list[ImageMedia | VideoFrameMedia],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        data_augment: bool = False,
        tiling_augment_prob: float = 0.4,
        **kwargs,
    ) -> list[MaskedTilingDynamicResolutionParams]:
        """Choose a tile grid per media, degrading the tile budget until the
        total embedding count fits num_tokens_available (when degradation is enabled)."""
        # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value.
        effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
        effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)
        max_num_tiles_to_use = effective_max_num_tiles
        degradation_map = self._tile_degradation_map

        while True:
            params = []
            total_embeddings_needed = 0

            for media in media_list:
                if isinstance(media, ImageMedia):
                    img_size = (media.width, media.height)
                    aspect_ratio = img_size[0] / img_size[1]

                    target_ratios = self.target_ratios[max_num_tiles_to_use]
                    tiling = self._find_closest_aspect_ratio_fn(
                        aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
                    )

                    # Apply tiling augmentation if enabled (media is an ImageMedia
                    # in this branch, so no further isinstance check is needed).
                    if data_augment and random.random() < tiling_augment_prob:
                        tiling = self.augment_tiling(tiling)

                    blocks = tiling[0] * tiling[1]
                    # Each tile is tile_size x tile_size
                    per_tile_emb = self._get_num_embeddings(self._tile_size, self._tile_size)
                    num_embeddings = blocks * per_tile_emb

                    num_tiles = blocks
                    if self._use_thumbnail and blocks != 1:
                        num_tiles += 1
                        num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size)

                    media_params = MaskedTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=num_tiles,
                        num_embeddings=num_embeddings,
                        tiling=tiling,
                    )
                elif isinstance(media, VideoFrameMedia):
                    video_params = self._video_frame_strategy.compute_params(
                        [media], 1 * self._embeddings_per_tile
                    )[0]
                    media_params = MaskedTilingDynamicResolutionParams(
                        media=media,
                        num_tiles=video_params.num_tiles,
                        num_embeddings=video_params.num_embeddings,
                        tiling=(1, 1),
                    )
                else:
                    raise ValueError(f"Unsupported media type: {type(media)}")

                params.append(media_params)
                total_embeddings_needed += media_params.num_embeddings

            if not self._enable_tile_degradation:
                break
            if max_num_tiles_to_use == 1 or num_tokens_available is None:
                break
            if total_embeddings_needed > num_tokens_available:
                if max_num_tiles_to_use in degradation_map:
                    max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
                    # Cache target ratios for the degraded budget on first use.
                    if max_num_tiles_to_use not in self.target_ratios:
                        self.target_ratios[max_num_tiles_to_use] = self._build_target_ratios(
                            max_num_tiles_to_use
                        )
                else:
                    break
            else:
                break

        return params

    def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]:
        """Randomly add or remove one tile row/column (removal is more likely),
        respecting the 1-tile lower bound and self._max_num_tiles upper bound."""
        def num_tiles(tiling: tuple[int, int]) -> int:
            return tiling[0] * tiling[1]

        def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]:
            if random.random() < minus_prob:
                # Minus one
                if tiling[0] == 1 and tiling[1] == 1:
                    return tiling
                elif tiling[0] == 1:
                    return (tiling[0], tiling[1] - 1)
                elif tiling[1] == 1:
                    return (tiling[0] - 1, tiling[1])
                else:
                    if random.random() < 0.5:
                        return (tiling[0] - 1, tiling[1])
                    else:
                        return (tiling[0], tiling[1] - 1)
            else:
                # Plus one
                if num_tiles(tiling) < self._max_num_tiles:
                    tiling0 = (tiling[0] + 1, tiling[1])
                    tiling1 = (tiling[0], tiling[1] + 1)
                    if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles:
                        return tiling
                    elif num_tiles(tiling0) > self._max_num_tiles:
                        return tiling1
                    elif num_tiles(tiling1) > self._max_num_tiles:
                        return tiling0
                    else:
                        if random.random() < 0.5:
                            return tiling0
                        else:
                            return tiling1
                return tiling

        new_tiling = plus_minus_one(tiling)
        return new_tiling

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
        # Identical to dynamic resolution packing; each tile is already an independent image sample
        imgs_sizes = torch.tensor(
            [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
        )

        def rearrange_img(x):
            # Equivalent of einops "c (py yy) (px xx) -> (py px) (c yy xx)".
            # Uses plain torch ops (the file imports `rearrange`, not `einops`).
            c = x.shape[0]
            py = x.shape[-2] // self._patch_size
            px = x.shape[-1] // self._patch_size
            x = x.reshape(c, py, self._patch_size, px, self._patch_size)
            x = x.permute(1, 3, 0, 2, 4)
            return x.reshape(py * px, c * self._patch_size * self._patch_size)

        if len(images) > 0:
            imgs = [rearrange_img(img) for img in images]

            current_length = 0
            max_length = 0
            vision_cu_lengths = [0]
            for img in imgs:
                if max_length < img.shape[0]:
                    max_length = img.shape[0]
                current_length += img.shape[0]
                vision_cu_lengths.append(current_length)

            vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
            vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)

            return (
                torch.cat(imgs, dim=0).unsqueeze(0),
                imgs_sizes,
                vision_cu_lengths,
                vision_max_lengths,
            )
        else:
            # Placeholder outputs for the empty-input case.
            return (
                torch.tensor([[0]], dtype=torch.float32),
                torch.tensor([[0, 0]], dtype=torch.int32),
                None,
                None,
            )

    def __str__(self):
        return f"MaskedTilingDynamicResolutionStrategy(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, enable_tile_degradation={self._enable_tile_degradation}, video_frame_strategy={self._video_frame_strategy})"
+
+def create_image_tiling_strategy(args):
+ """
+ Create an image tiling strategy based on the provided arguments.
+
+ This function encapsulates the logic for creating the appropriate image tiling strategy
+ based on the training/evaluation configuration. It can be used by both training (task_encoder)
+ and evaluation code outside of data_loading/.
+
+ Args:
+ args: Arguments object with the following relevant attributes:
+ - img_h, img_w: Image height and width
+ - patch_dim: Patch dimension
+ - vision_model_type: Vision model type (e.g., 'radio', 'clip', 'siglip')
+ - disable_vision_class_token: Whether to disable vision class token
+ - pixel_shuffle: Whether to use pixel shuffle
+ - use_tile_tags: Whether to use tile tags
+ - max_num_tiles: Maximum number of tiles
+ - tokenizer_prompt_format: Tokenizer prompt format
+ - image_break_token: Image break token (optional)
+ - conv_merging: Whether to use convolution merging
+ - dynamic_resolution: Whether to use dynamic resolution
+ - match_tiling_dynamic_resolution: Whether to match tiling with dynamic resolution
+ - use_area_weighted_aspect_ratio: Whether to use area-weighted aspect ratio
+ - use_thumbnail: Whether to use thumbnail
+ - dynamic_resolution_min_patches: Minimum number of patches for dynamic resolution
+ - dynamic_resolution_min_side: Minimum side length for dynamic resolution (optional)
+ - thumbnail_area_threshold: Thumbnail area threshold (optional)
+ - use_tiling: Whether to use tiling
+
+ Returns:
+ ImageTilingStrategy: The created image tiling strategy
+ """
+ from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
+
+ assert args.img_h == args.img_w, "img_h and img_w must be the same"
+
+ match_tiling_dynamic_resolution = args.match_tiling_dynamic_resolution
+ masked_tiling_dynamic_resolution = getattr(args, "masked_tiling_dynamic_resolution", False)
+ dynamic_resolution = args.dynamic_resolution
+ use_tiling = args.use_tiling
+ use_area_weighted_aspect_ratio = args.use_area_weighted_aspect_ratio
+
+ if match_tiling_dynamic_resolution:
+ assert dynamic_resolution, "must enable --dynamic-resolution if using --match-tiling-dynamic-resolution"
+ assert not use_tiling, "cannot use --use-tiling and --match-tiling-dynamic-resolution together"
+ if masked_tiling_dynamic_resolution:
+ assert dynamic_resolution, "must enable --dynamic-resolution if using --masked-tiling-dynamic-resolution"
+ assert not use_tiling, "cannot use --use-tiling and --masked-tiling-dynamic-resolution together"
+ assert not match_tiling_dynamic_resolution, "cannot combine --masked-tiling-dynamic-resolution with --match-tiling-dynamic-resolution"
+
+ if dynamic_resolution:
+ if masked_tiling_dynamic_resolution:
+ num_image_embeddings_per_tile = get_num_image_embeddings(
+ img_h=args.img_h,
+ img_w=args.img_w,
+ patch_dim=args.patch_dim,
+ vision_model_type=args.vision_model_type,
+ disable_vision_class_token=args.disable_vision_class_token,
+ class_token_len=1,
+ pixel_shuffle=args.pixel_shuffle,
+ use_tile_tags=args.use_tile_tags,
+ max_num_tiles=args.max_num_tiles,
+ tokenizer_type=args.tokenizer_prompt_format,
+ use_image_break_token=args.image_break_token is not None,
+ conv_merging=args.conv_merging,
+ )
+ image_tiling_strategy = MaskedTilingDynamicResolutionStrategy(
+ vision_model_type=args.vision_model_type,
+ tile_size=args.img_h,
+ use_thumbnail=args.use_thumbnail,
+ min_num_tiles=1,
+ max_num_tiles=args.max_num_tiles,
+ embeddings_per_tile=num_image_embeddings_per_tile,
+ patch_size=args.patch_dim,
+ get_num_embeddings=lambda width, height: get_num_image_embeddings(
+ img_h=height,
+ img_w=width,
+ patch_dim=args.patch_dim,
+ vision_model_type=args.vision_model_type,
+ disable_vision_class_token=args.disable_vision_class_token,
+ class_token_len=1,
+ pixel_shuffle=args.pixel_shuffle,
+ use_tile_tags=args.use_tile_tags,
+ max_num_tiles=args.max_num_tiles,
+ tokenizer_type=args.tokenizer_prompt_format,
+ use_image_break_token=args.image_break_token is not None,
+ conv_merging=args.conv_merging,
+ ),
+ find_closest_aspect_ratio_fn=(
+ find_closest_area_weighted_aspect_ratio
+ if use_area_weighted_aspect_ratio
+ else find_closest_aspect_ratio
+ ),
+ pixel_shuffle=args.pixel_shuffle,
+ conv_merging=args.conv_merging,
+ )
+ elif match_tiling_dynamic_resolution:
+ num_image_embeddings_per_tile = get_num_image_embeddings(
+ img_h=args.img_h,
+ img_w=args.img_w,
+ patch_dim=args.patch_dim,
+ vision_model_type=args.vision_model_type,
+ disable_vision_class_token=args.disable_vision_class_token,
+ class_token_len=1,
+ pixel_shuffle=args.pixel_shuffle,
+ use_tile_tags=args.use_tile_tags,
+ max_num_tiles=args.max_num_tiles,
+ tokenizer_type=args.tokenizer_prompt_format,
+ use_image_break_token=args.image_break_token is not None,
+ conv_merging=args.conv_merging,
+ )
+ image_tiling_strategy = MatchTilingDynamicResolutionStrategy(
+ vision_model_type=args.vision_model_type,
+ tile_size=args.img_h,
+ use_thumbnail=args.use_thumbnail,
+ min_num_tiles=1,
+ max_num_tiles=args.max_num_tiles,
+ embeddings_per_tile=num_image_embeddings_per_tile,
+ patch_size=args.patch_dim,
+ get_num_embeddings=lambda width, height: get_num_image_embeddings(
+ img_h=height,
+ img_w=width,
+ patch_dim=args.patch_dim,
+ vision_model_type=args.vision_model_type,
+ disable_vision_class_token=args.disable_vision_class_token,
+ class_token_len=1,
+ pixel_shuffle=args.pixel_shuffle,
+ use_tile_tags=args.use_tile_tags,
+ max_num_tiles=args.max_num_tiles,
+ tokenizer_type=args.tokenizer_prompt_format,
+ use_image_break_token=args.image_break_token is not None,
+ conv_merging=args.conv_merging,
+ ),
+ find_closest_aspect_ratio_fn=(
+ find_closest_area_weighted_aspect_ratio
+ if use_area_weighted_aspect_ratio
+ else find_closest_aspect_ratio
+ ),
+ pixel_shuffle=args.pixel_shuffle,
+ conv_merging=args.conv_merging,
+ )
+ else:
+ image_tiling_strategy = DynamicResolutionImageTilingStrategy(
+ vision_model_type=args.vision_model_type,
+ min_num_patches=args.dynamic_resolution_min_patches,
+ patch_size=args.patch_dim,
+ get_num_embeddings=lambda width, height: get_num_image_embeddings(
+ img_h=height,
+ img_w=width,
+ patch_dim=args.patch_dim,
+ vision_model_type=args.vision_model_type,
+ disable_vision_class_token=args.disable_vision_class_token,
+ class_token_len=1,
+ pixel_shuffle=args.pixel_shuffle,
+ use_tile_tags=args.use_tile_tags,
+ max_num_tiles=args.max_num_tiles,
+ tokenizer_type=args.tokenizer_prompt_format,
+ use_image_break_token=args.image_break_token is not None,
+ conv_merging=args.conv_merging,
+ ),
+ pixel_shuffle=args.pixel_shuffle,
+ min_side=args.dynamic_resolution_min_side,
+ conv_merging=args.conv_merging,
+ use_thumbnail=args.use_thumbnail,
+ thumbnail_size=args.img_h,
+ thumbnail_area_threshold=args.thumbnail_area_threshold,
+ max_num_patches=args.dynamic_resolution_max_patches,
+ apply_data_augment=args.apply_data_augment,
+ )
+ else:
+ num_image_embeddings_per_tile = get_num_image_embeddings(
+ img_h=args.img_h,
+ img_w=args.img_w,
+ patch_dim=args.patch_dim,
+ vision_model_type=args.vision_model_type,
+ disable_vision_class_token=args.disable_vision_class_token,
+ class_token_len=1,
+ pixel_shuffle=args.pixel_shuffle,
+ use_tile_tags=args.use_tile_tags,
+ max_num_tiles=args.max_num_tiles,
+ tokenizer_type=args.tokenizer_prompt_format,
+ use_image_break_token=args.image_break_token is not None,
+ conv_merging=args.conv_merging,
+ )
+ if use_tiling:
+ image_strategy = ImageTilingStrategyV1(
+ vision_model_type=args.vision_model_type,
+ tile_size=args.img_h,
+ use_thumbnail=args.use_thumbnail,
+ min_num_tiles=1,
+ max_num_tiles=args.max_num_tiles,
+ embeddings_per_tile=num_image_embeddings_per_tile,
+ find_closest_aspect_ratio_fn=(
+ find_closest_area_weighted_aspect_ratio
+ if use_area_weighted_aspect_ratio
+ else find_closest_aspect_ratio
+ ),
+ )
+ else:
+ image_strategy = NoTilingStrategy(
+ vision_model_type=args.vision_model_type,
+ embeddings_per_image=num_image_embeddings_per_tile,
+ target_width=args.img_w,
+ target_height=args.img_h,
+ )
+ image_tiling_strategy = TileDegradationStrategy(
+ image_strategy=image_strategy,
+ video_frame_strategy=NoTilingStrategy(
+ vision_model_type=args.vision_model_type,
+ embeddings_per_image=num_image_embeddings_per_tile,
+ target_width=args.img_w,
+ target_height=args.img_h,
+ ),
+ embeddings_per_tile=num_image_embeddings_per_tile,
+ max_num_tiles=args.max_num_tiles,
+ )
+
+ return image_tiling_strategy
From 1bceb28678fcf6f7a5c39cd082222023b393e5e5 Mon Sep 17 00:00:00 2001
From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date: Thu, 11 Dec 2025 17:03:35 +0200
Subject: [PATCH 03/10] import image processing into
model_executor/models/nano_nemotron_vl.py
---
.../model_executor/models/nano_nemotron_vl.py | 489 +++++++++++++++++-
1 file changed, 488 insertions(+), 1 deletion(-)
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 6dfab595e5b92..1fbbf06d0a695 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -7,10 +7,14 @@
# LICENSE is in root directory.
# --------------------------------------------------------
+import math
+import random
+from dataclasses import dataclass
+
import copy
import warnings
from abc import ABC, abstractmethod
-from collections.abc import Iterable, Mapping, Sequence
+from collections.abc import Iterable, Mapping, Sequence, Callable
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy.typing as npt
@@ -20,6 +24,7 @@ import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
+import einops
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@@ -84,6 +89,488 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
+IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
+SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
+SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
+CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
+CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
+RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
+RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]
+
+
+pixel_statistics = {
+ "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+ "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
+ "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
+ "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+ "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
+ "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
+ "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+ "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+ "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
+}
+
+
+@dataclass
+class DynamicResolutionParams:
+ image: Image.Image
+ num_tiles: int
+ num_embeddings: int
+ patch_size: tuple[int, int]
+
+
+class DynamicResolutionImageTilingStrategy:
+ def __init__(
+ self,
+ vision_model_type: str,
+ min_num_patches: int,
+ patch_size: int,
+ get_num_embeddings: Callable[[int, int], int],
+ factor_max: float = 1.0,
+ pixel_shuffle: bool = False,
+ min_side: int | None = None,
+ conv_merging: bool = False,
+ use_thumbnail: bool = False,
+ thumbnail_size: int = 448,
+ thumbnail_area_threshold: float = 0.8,
+ max_num_patches: int = 0,
+ apply_data_augment: bool = False,
+ ):
+ assert "radio" in vision_model_type, (
+ "Dynamic resolution is only supported for radio models"
+ )
+ self._vision_model_type = vision_model_type
+ self._min_num_patches = min_num_patches
+ self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
+ self._patch_size = patch_size
+ self._get_num_embeddings = get_num_embeddings
+ self._factor_max = factor_max
+ self._pixel_shuffle = pixel_shuffle
+ self._min_side = min_side
+ self._conv_merging = conv_merging
+ self._use_thumbnail = use_thumbnail
+ self._thumbnail_size = thumbnail_size
+ self._thumbnail_area_threshold = thumbnail_area_threshold
+ pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
+ self._transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)),
+ T.Normalize(mean=pixel_mean, std=pixel_std),
+ ]
+ )
+ self._apply_data_augment = apply_data_augment
+
+ def apply_params(
+ self, params: DynamicResolutionParams, **kwargs
+ ) -> list[torch.Tensor]:
+ # resize the image
+ resized_img = params.image.resize(
+ (
+ params.patch_size[0] * self._patch_size,
+ params.patch_size[1] * self._patch_size,
+ )
+ )
+ processed_images = [resized_img]
+
+ # Add thumbnail if enabled and image area is below threshold
+ if self._use_thumbnail:
+ # Calculate areas
+ resized_area = resized_img.size[0] * resized_img.size[1]
+ thumbnail_area = self._thumbnail_size * self._thumbnail_size
+ area_ratio = resized_area / thumbnail_area
+
+ # Only add thumbnail if resized image area is less than threshold % of thumbnail area
+ if area_ratio < self._thumbnail_area_threshold:
+ thumbnail_img = params.image.resize(
+ (self._thumbnail_size, self._thumbnail_size)
+ )
+ processed_images.append(thumbnail_img)
+
+ return [self._transform(img) for img in processed_images]
+
+ def process_media(
+ self,
+ image: Image.Image,
+ num_tokens_available: int,
+ data_augment: bool = False,
+ tiling_augment_prob: float = 0.4,
+ ) -> DynamicResolutionParams:
+ """Process a single media item and return its parameters.
+
+ Args:
+ media: The media item to process
+ num_tokens_available: Number of tokens available for this media
+ data_augment: Whether to apply data augmentation to the image. Defaults to False.
+ Returns:
+ DynamicResolutionParams for the media
+ """
+ current_num_tokens_available = num_tokens_available
+ assert isinstance(image, Image.Image), (
+ "Dynamic resolution is only supported for image media"
+ )
+ orig_width, orig_height = image.width, image.height
+
+ closest_patch_height = round(orig_height / self._patch_size + 0.5)
+ closest_patch_width = round(orig_width / self._patch_size + 0.5)
+ patches = closest_patch_height * closest_patch_width
+
+ factor = min(
+ math.sqrt(current_num_tokens_available / patches), self._factor_max
+ )
+ target_patch_height = math.floor(factor * closest_patch_height)
+ target_patch_width = math.floor(factor * closest_patch_width)
+
+ # We only consider self._min_num_patches if it is greater than current_num_tokens_available.
+ if (
+ current_num_tokens_available > self._min_num_patches
+ and target_patch_height * target_patch_width < self._min_num_patches
+ ):
+ up_factor = math.sqrt(
+ self._min_num_patches / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.ceil(up_factor * target_patch_height)
+ target_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if (
+ self._min_side is not None
+ and min(target_patch_width, target_patch_height) * self._patch_size
+ < self._min_side
+ ):
+ if target_patch_width <= target_patch_height:
+ up_factor = self._min_side / (target_patch_width * self._patch_size)
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if new_patch_height * new_patch_width > current_num_tokens_available:
+ # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
+ if (
+ max(current_num_tokens_available // new_patch_width, 1)
+ * self._patch_size
+ < self._min_side
+ ):
+ up_factor = math.sqrt(
+ current_num_tokens_available
+ / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.floor(
+ up_factor * target_patch_height
+ )
+ target_patch_width = math.floor(up_factor * target_patch_width)
+ target_patch_width = new_patch_width
+ target_patch_height = max(
+ current_num_tokens_available // new_patch_width, 1
+ )
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = new_patch_width
+ else:
+ up_factor = self._min_side / (target_patch_height * self._patch_size)
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if new_patch_height * new_patch_width > current_num_tokens_available:
+ # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
+ if (
+ max(current_num_tokens_available // new_patch_height, 1)
+ * self._patch_size
+ < self._min_side
+ ):
+ up_factor = math.sqrt(
+ current_num_tokens_available
+ / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.floor(
+ up_factor * target_patch_height
+ )
+ target_patch_width = math.floor(up_factor * target_patch_width)
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = max(
+ current_num_tokens_available // new_patch_height, 1
+ )
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = new_patch_width
+
+ # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
+ # or by 4 when BOTH are enabled (two successive 2x reductions)
+ if self._pixel_shuffle or self._conv_merging:
+ required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
+
+ rem_h = target_patch_height % required_divisor
+ if rem_h != 0:
+ inc_h = required_divisor - rem_h
+ if (
+ target_patch_height + inc_h
+ ) * target_patch_width <= current_num_tokens_available:
+ target_patch_height += inc_h
+ else:
+ target_patch_height = max(
+ required_divisor, target_patch_height - rem_h
+ )
+
+ rem_w = target_patch_width % required_divisor
+ if rem_w != 0:
+ inc_w = required_divisor - rem_w
+ if (
+ target_patch_height * (target_patch_width + inc_w)
+ <= current_num_tokens_available
+ ):
+ target_patch_width += inc_w
+ else:
+ target_patch_width = max(
+ required_divisor, target_patch_width - rem_w
+ )
+
+ if (
+ data_augment
+ and self._apply_data_augment
+ and random.random() < tiling_augment_prob
+ ):
+ target_patch_width, target_patch_height = self.augment_resolution(
+ target_patch_width, target_patch_height, current_num_tokens_available
+ )
+
+ assert isinstance(image, Image.Image), (
+ "Dynamic resolution is only supported for image media"
+ )
+
+ # Calculate embeddings for the main dynamic resolution image
+ num_embeddings = self._get_num_embeddings(
+ target_patch_width * self._patch_size,
+ target_patch_height * self._patch_size,
+ )
+
+ token_count = target_patch_width * target_patch_height
+
+ # Add thumbnail embeddings if enabled and image area is below threshold
+ num_tiles = 1 # Base dynamic resolution image
+ if self._use_thumbnail:
+ # Calculate areas
+ resized_area = (target_patch_width * self._patch_size) * (
+ target_patch_height * self._patch_size
+ )
+ thumbnail_area = self._thumbnail_size * self._thumbnail_size
+ area_ratio = resized_area / thumbnail_area
+
+ # Only add thumbnail if resized image area is less than threshold % of thumbnail area
+ if area_ratio < self._thumbnail_area_threshold:
+ num_tiles += 1 # Add 1 for thumbnail
+ # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
+ num_embeddings += self._get_num_embeddings(
+ self._thumbnail_size, self._thumbnail_size
+ )
+ token_count += (
+ self._thumbnail_size
+ // self._patch_size
+ * self._thumbnail_size
+ // self._patch_size
+ )
+
+ return DynamicResolutionParams(
+ image=image,
+ num_tiles=num_tiles,
+ num_embeddings=num_embeddings,
+ patch_size=(target_patch_width, target_patch_height),
+ ), token_count
+
+ def augment_resolution(
+ self,
+ target_patch_width: int,
+ target_patch_height: int,
+ current_num_tokens_available: int,
+ ) -> tuple[int, int]:
+ min_num_patch_one_side = 32
+
+ if random.random() < 0.5:
+            # Shrink: subtract min_num_patch_one_side patches from one side
+ if (
+ target_patch_width <= min_num_patch_one_side
+ and target_patch_height <= min_num_patch_one_side
+ ):
+ return target_patch_width, target_patch_height
+ elif target_patch_width <= min_num_patch_one_side:
+ return target_patch_width, target_patch_height - min_num_patch_one_side
+ elif target_patch_height <= min_num_patch_one_side:
+ return target_patch_width - min_num_patch_one_side, target_patch_height
+ else:
+ if random.random() < 0.5:
+ return (
+ target_patch_width - min_num_patch_one_side,
+ target_patch_height,
+ )
+ else:
+ return (
+ target_patch_width,
+ target_patch_height - min_num_patch_one_side,
+ )
+ else:
+            # Grow: add min_num_patch_one_side patches to one side
+ if target_patch_width * target_patch_height < current_num_tokens_available:
+ if random.random() < 0.5:
+ return (
+ target_patch_width + min_num_patch_one_side,
+ target_patch_height,
+ )
+ else:
+ return (
+ target_patch_width,
+ target_patch_height + min_num_patch_one_side,
+ )
+ return target_patch_width, target_patch_height
+
+ def compute_params(
+ self,
+ media_list: list[Image.Image],
+ num_tokens_available: int | None = None,
+ max_num_tiles: int | None = None,
+ data_augment: bool = False,
+ **kwargs,
+ ) -> list[DynamicResolutionParams]:
+ """Compute parameters for all media with iterative token budgeting.
+
+ Args:
+ media_list: List of media items to process
+ num_tokens_available: Total number of tokens available across all media
+ max_num_tiles: Maximum number of tiles (unused in this implementation)
+ data_augment: Whether to apply data augmentation to the image. Defaults to False.
+ Returns:
+ List of ImageTilingParams for each media item
+ """
+ num_tokens_available = (
+ num_tokens_available
+ * (4 if self._pixel_shuffle else 1)
+ * (4 if self._conv_merging else 1)
+ )
+ # When the number of available token is too small, allow self._min_num_patches per media and
+ # let the sample be truncated.
+ num_tokens_available = max(
+ num_tokens_available, self._min_num_patches * len(media_list)
+ )
+
+ # Clip the number of tokens available per media to be between min and max patches.
+ num_tokens_available_per_media = [
+ max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
+ for _ in range(len(media_list))
+ ]
+
+ # In theory this could be a while True loop, but in case the process_media method slightly
+ # changes, I want to make sure we don't get stuck in an infinite loop.
+ for _ in range(10):
+ # Step 1: Process each media with current token budget
+ params = []
+ token_counts = []
+
+ for media, tokens_for_media in zip(
+ media_list, num_tokens_available_per_media
+ ):
+ param, token_count = self.process_media(
+ media, tokens_for_media, data_augment=data_augment
+ )
+ params.append(param)
+ token_counts.append(token_count)
+
+ # Step 2: Check if total tokens is within budget
+ total_tokens = sum(token_counts)
+
+ if total_tokens <= num_tokens_available:
+ # We're within budget, return the params
+ return params
+
+ # Step 3: We're over budget, need to scale down
+ # Calculate scaling factor to get under budget
+ scaling_factor = num_tokens_available / total_tokens
+
+ # Recalculate token budgets for each media based on scaling
+ # Each media gets a proportional share of the total budget
+ scaled_down_num_tokens_available_per_media = [
+ max(self._min_num_patches, int(token_count * scaling_factor))
+ for token_count in token_counts
+ ]
+ scaled_down = any(
+ [
+ scaled_down_num_tokens_available_per_media[i]
+ < num_tokens_available_per_media[i]
+ for i in range(len(num_tokens_available_per_media))
+ ]
+ )
+ # If there was not scaling down, we're stuck just use min_num_patches per media, else
+ # try with the scaled down num_tokens_available_per_media.
+ if not scaled_down:
+ num_tokens_available_per_media = [self._min_num_patches] * len(
+ media_list
+ )
+ else:
+ num_tokens_available_per_media = (
+ scaled_down_num_tokens_available_per_media
+ )
+ return params
+
+ def stack(
+ self, images: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
+ imgs_sizes = torch.tensor(
+ [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
+ )
+
+ def rearrange_img(x):
+ py = x.shape[-2] // self._patch_size
+ px = x.shape[-1] // self._patch_size
+ x = einops.rearrange(
+ x,
+ "c (py yy) (px xx) -> (py px) (c yy xx)",
+ py=py,
+ yy=self._patch_size,
+ px=px,
+ xx=self._patch_size,
+ )
+ return x
+
+ if len(images) > 0:
+ imgs = [rearrange_img(img) for img in images]
+
+ current_length = 0
+ max_length = 0
+ vision_cu_lengths = [0]
+ for img in imgs:
+ if max_length < img.shape[0]:
+ max_length = img.shape[0]
+ current_length += img.shape[0]
+ vision_cu_lengths.append(current_length)
+
+ vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
+ vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
+
+ return (
+ torch.cat(imgs, dim=0).unsqueeze(0),
+ imgs_sizes,
+ vision_cu_lengths,
+ vision_max_lengths,
+ )
+ else:
+ return (
+ torch.tensor([[0]], dtype=torch.float32),
+ torch.tensor([[0, 0]], dtype=torch.int32),
+ None,
+ None,
+ )
+
+ def __str__(self):
+ return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})"
+
+
+
+image_tiling_strategy = DynamicResolutionImageTilingStrategy(
+ vision_model_type="radio",
+ min_num_patches=4,
+ patch_size=16,
+ get_num_embeddings=lambda x, y: x * y * 2,
+ max_num_patches=64,
+)
+
+
IMG_START = "
"
IMG_END = ""
IMG_CONTEXT = ""
From 4e558858b8d0ffeb1d783e6ea4bc75432bb57bc5 Mon Sep 17 00:00:00 2001
From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date: Thu, 11 Dec 2025 17:32:08 +0200
Subject: [PATCH 04/10] reformat commits
---
.../model_executor/models/nano_nemotron_vl.py | 56 ++++++++++++-------
1 file changed, 35 insertions(+), 21 deletions(-)
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 1fbbf06d0a695..82423891c11c7 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -7,16 +7,16 @@
# LICENSE is in root directory.
# --------------------------------------------------------
+import copy
import math
import random
-from dataclasses import dataclass
-
-import copy
import warnings
from abc import ABC, abstractmethod
-from collections.abc import Iterable, Mapping, Sequence, Callable
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from dataclasses import dataclass
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
+import einops
import numpy.typing as npt
import regex as re
import torch
@@ -24,7 +24,6 @@ import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
-import einops
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@@ -181,7 +180,8 @@ class DynamicResolutionImageTilingStrategy:
thumbnail_area = self._thumbnail_size * self._thumbnail_size
area_ratio = resized_area / thumbnail_area
- # Only add thumbnail if resized image area is less than threshold % of thumbnail area
+ # Only add thumbnail if resized image area is less than threshold % of
+ # thumbnail area
if area_ratio < self._thumbnail_area_threshold:
thumbnail_img = params.image.resize(
(self._thumbnail_size, self._thumbnail_size)
@@ -198,11 +198,11 @@ class DynamicResolutionImageTilingStrategy:
tiling_augment_prob: float = 0.4,
) -> DynamicResolutionParams:
"""Process a single media item and return its parameters.
-
Args:
media: The media item to process
num_tokens_available: Number of tokens available for this media
- data_augment: Whether to apply data augmentation to the image. Defaults to False.
+ data_augment: Whether to apply data augmentation to the image. Defaults to
+ False.
Returns:
DynamicResolutionParams for the media
"""
@@ -222,7 +222,8 @@ class DynamicResolutionImageTilingStrategy:
target_patch_height = math.floor(factor * closest_patch_height)
target_patch_width = math.floor(factor * closest_patch_width)
- # We only consider self._min_num_patches if it is greater than current_num_tokens_available.
+    # We only enforce self._min_num_patches when it is smaller than
+    # current_num_tokens_available (otherwise the sample may be truncated).
if (
current_num_tokens_available > self._min_num_patches
and target_patch_height * target_patch_width < self._min_num_patches
@@ -244,7 +245,8 @@ class DynamicResolutionImageTilingStrategy:
new_patch_width = math.ceil(up_factor * target_patch_width)
if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
+ # If only one side can be min_side, make as big as possible at
+ # native aspect ratio while staying below max_patches
if (
max(current_num_tokens_available // new_patch_width, 1)
* self._patch_size
@@ -271,7 +273,8 @@ class DynamicResolutionImageTilingStrategy:
new_patch_width = math.ceil(up_factor * target_patch_width)
if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
+ # If only one side can be min_side, make as big as possible at
+ # native aspect ratio while staying below max_patches
if (
max(current_num_tokens_available // new_patch_height, 1)
* self._patch_size
@@ -355,7 +358,8 @@ class DynamicResolutionImageTilingStrategy:
thumbnail_area = self._thumbnail_size * self._thumbnail_size
area_ratio = resized_area / thumbnail_area
- # Only add thumbnail if resized image area is less than threshold % of thumbnail area
+ # Only add thumbnail if resized image area is less than threshold % of
+ # thumbnail area
if area_ratio < self._thumbnail_area_threshold:
num_tiles += 1 # Add 1 for thumbnail
# Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
@@ -435,7 +439,8 @@ class DynamicResolutionImageTilingStrategy:
media_list: List of media items to process
num_tokens_available: Total number of tokens available across all media
max_num_tiles: Maximum number of tiles (unused in this implementation)
- data_augment: Whether to apply data augmentation to the image. Defaults to False.
+ data_augment: Whether to apply data augmentation to the image. Defaults to
+ False.
Returns:
List of ImageTilingParams for each media item
"""
@@ -444,19 +449,21 @@ class DynamicResolutionImageTilingStrategy:
* (4 if self._pixel_shuffle else 1)
* (4 if self._conv_merging else 1)
)
- # When the number of available token is too small, allow self._min_num_patches per media and
- # let the sample be truncated.
+        # When the number of available tokens is too small, allow
+        # self._min_num_patches per media and let the sample be truncated.
num_tokens_available = max(
num_tokens_available, self._min_num_patches * len(media_list)
)
- # Clip the number of tokens available per media to be between min and max patches.
+ # Clip the number of tokens available per media to be between min and max
+ # patches.
num_tokens_available_per_media = [
max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
for _ in range(len(media_list))
]
- # In theory this could be a while True loop, but in case the process_media method slightly
+ # In theory this could be a while True loop, but in case the process_media
+ # method slightly
# changes, I want to make sure we don't get stuck in an infinite loop.
for _ in range(10):
# Step 1: Process each media with current token budget
@@ -496,8 +503,8 @@ class DynamicResolutionImageTilingStrategy:
for i in range(len(num_tokens_available_per_media))
]
)
- # If there was not scaling down, we're stuck just use min_num_patches per media, else
- # try with the scaled down num_tokens_available_per_media.
+            # If there was no scaling down, we're stuck; just use min_num_patches
+            # per media, else try with the scaled-down num_tokens_available_per_media.
if not scaled_down:
num_tokens_available_per_media = [self._min_num_patches] * len(
media_list
@@ -558,8 +565,15 @@ class DynamicResolutionImageTilingStrategy:
)
def __str__(self):
- return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})"
-
+        return (
+            f"DynamicResolutionImageTransform("
+            f"vision_model_type={self._vision_model_type}, "
+            f"min_num_patches={self._min_num_patches}, "
+            f"patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, "
+            f"conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, "
+            f"thumbnail_size={self._thumbnail_size}, "
+            f"thumbnail_area_threshold={self._thumbnail_area_threshold})"
+        )
image_tiling_strategy = DynamicResolutionImageTilingStrategy(
From 7b55a619e4dff16dba7f6e6c3e1b2c278fd68290 Mon Sep 17 00:00:00 2001
From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date: Thu, 11 Dec 2025 18:04:01 +0200
Subject: [PATCH 05/10] rewire
---
image_processing.py | 1828 -----------------
.../model_executor/models/nano_nemotron_vl.py | 16 +-
2 files changed, 8 insertions(+), 1836 deletions(-)
delete mode 100644 image_processing.py
diff --git a/image_processing.py b/image_processing.py
deleted file mode 100644
index 5d43e764acadd..0000000000000
--- a/image_processing.py
+++ /dev/null
@@ -1,1828 +0,0 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-import math
-from typing import Callable, Optional
-import numpy as np
-import random
-from PIL import Image
-import albumentations as A
-
-import einops
-import torch
-from torchvision import transforms as T
-from torchvision.transforms import Compose
-from torchvision.transforms.functional import InterpolationMode
-
-from data_loading.conversation_sample import (
- ImageMedia,
- VideoFrameMedia,
-)
-
-IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
-IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
-SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
-SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
-CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
-CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
-RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
-RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]
-
-
-pixel_statistics = {
- "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
- "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
- "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
- "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
- "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
-}
-
-
-# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
-# Copyright (c) 2023 OpenGVLab.
-def find_closest_aspect_ratio(
- aspect_ratio: float,
- target_ratios: list[tuple[int, int]],
- width: int,
- height: int,
- image_size: int,
-) -> tuple[int, int]:
- best_ratio_diff = float("inf")
- best_ratio = (1, 1)
- area = width * height
- for ratio in target_ratios:
- target_aspect_ratio = ratio[0] / ratio[1]
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
- if ratio_diff < best_ratio_diff:
- best_ratio_diff = ratio_diff
- best_ratio = ratio
- elif ratio_diff == best_ratio_diff:
- if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
- best_ratio = ratio
- return best_ratio
-
-
-def find_closest_area_weighted_aspect_ratio(
- aspect_ratio: float,
- target_ratios: list[tuple[int, int]],
- width: int,
- height: int,
- image_size: int,
-):
- """
- Find the best number of tiles based on the aspect ratio and the area covered by the tiles.
- """
- best_factor = float("-inf")
- best_ratio = (1, 1)
- area = width * height
- for ratio in target_ratios:
- target_aspect_ratio = ratio[0] / ratio[1]
- factor_based_on_area_n_ratio = min(
- (ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
- ) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
- if factor_based_on_area_n_ratio > best_factor:
- best_factor = factor_based_on_area_n_ratio
- best_ratio = ratio
- return best_ratio
-
-
-# Mike's optimized ToTensor.
-def _fast_to_tensor(pic) -> torch.Tensor:
- np_img = np.array(pic, copy=False)
- img = torch.from_numpy(np_img)
- img = img.permute(2, 0, 1) # HWC to CHW
- fp_img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
- fp_img.div_(255)
- return fp_img
-
-
-@dataclass
-class ImageTilingParams:
- media: ImageMedia | VideoFrameMedia
- num_tiles: int
- num_embeddings: int
-
-
-class ImageTilingStrategy(ABC):
- """
- Base class for image tiling strategies.
- A tiling strategy is a function that takes a list of media and returns a list of image tiling parameters.
- These can then be used to apply the tiling to the media.
-
- Subclasses must implement the `compute_params` and `apply_params` methods.
-
- The `transform` method is a convenience method that computes the transformation parameters and applies the transformation to the media.
-
- """
-
- def transform(
- self,
- media_list: list[ImageMedia | VideoFrameMedia],
- num_tokens_available: int | None = None,
- ) -> list[torch.Tensor]:
- """
- Transform the media and compute the transformation parameters.
- """
- transform_media_list = self.compute_params(media_list, num_tokens_available)
- return [
- self.apply_params(transform_media, **kwargs)
- for transform_media in transform_media_list
- ]
-
- @abstractmethod
- def compute_params(
- self, media_list: list[ImageMedia | VideoFrameMedia], num_tokens_available: int, max_num_tiles: int | None = None, **kwargs
- ) -> list[ImageTilingParams]:
- """
- Compute the transformation parameters and the number of tokens to use for the media.
-
- Args:
- media_list: List of media to transform
- num_tokens_available: Number of tokens available for all media
- max_num_tiles: Maximum number of tiles allowed (optional, defaults to instance's max_num_tiles if not provided)
-
- Returns:
- list of transformation parameters with the media
- """
- ...
-
- @abstractmethod
- def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
- """
- Apply the transformation parameters to the media.
-
- Args:
- transform_media: The media to apply the transformation to
-
- Returns:
- list of transformed media tensors
- """
- ...
-
- @abstractmethod
- def stack(
- self, images: list[torch.Tensor]
- ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
- """
- Stack the images into a single tensor.
-
- Args:
- media_list: List of images to stack
-
- Returns:
- tuple of (stacked media, image sizes, vision cu lengths, vision max lengths)
- """
- ...
-
-
-class _FixedSizeStrategy(ImageTilingStrategy):
- """
- Base class for fixed size image tiling strategies.
- """
-
- def __init__(
- self,
- vision_model_type: str,
- target_width: int,
- target_height: int,
- embeddings_per_image: int,
- ):
- self._vision_model_type = vision_model_type
- self._target_width = target_width
- self._target_height = target_height
- self._embeddings_per_image = embeddings_per_image
- self._transform = self._build_transform(
- (target_width, target_height), vision_model_type
- )
-
- # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
- # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
- @staticmethod
- def _build_transform(target_size: tuple[int, int], vision_model_type: str):
- """
- Build a transform for a given vision model type and target size.
- """
- if vision_model_type in ("siglip", "internvit", "radio", "radio-g", "cradio-g"):
- pixel_mean, pixel_std = pixel_statistics[vision_model_type]
-
- transform = T.Compose(
- [
- T.Lambda(
- lambda img: img.convert("RGB") if img.mode != "RGB" else img
- ),
- T.Resize(
- (target_size[1], target_size[0]),
- interpolation=InterpolationMode.BICUBIC,
- ),
- T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)),
- T.Normalize(mean=pixel_mean, std=pixel_std),
- ]
- )
- # From the official CLIP repo.
- elif vision_model_type == "clip":
- pixel_mean, pixel_std = pixel_statistics[vision_model_type]
-
- transform = Compose(
- [
- T.Resize(
- (target_size[1], target_size[0]),
- interpolation=InterpolationMode.BICUBIC,
- ),
- T.Lambda(
- lambda img: img.convert("RGB") if img.mode != "RGB" else img
- ),
- T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)),
- T.Normalize(mean=pixel_mean, std=pixel_std),
- ]
- )
- elif vision_model_type.startswith("hf://"):
- from megatron.core.models.huggingface.module import get_hf_model_type
-
- model_type = get_hf_model_type(vision_model_type)
- if "siglip" in model_type:
- from transformers.models.siglip.image_processing_siglip import (
- SiglipImageProcessor,
- )
-
- processor = SiglipImageProcessor(
- size={"height": target_size[1], "width": target_size[0]}
- )
-
- def transform(x):
- x = x.convert("RGB") if x.mode != "RGB" else x
- x = processor(x, return_tensors="pt")
- return x["pixel_values"][0]
- else:
- raise NotImplementedError(
- f"image processing not defined for huggingface model {vision_model_type}"
- )
- else:
- raise NotImplementedError(
- f"image processing not defined for vision model {vision_model_type}"
- )
-
- return transform
-
- def stack(
- self, images: list[torch.Tensor]
- ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
- return (
- torch.stack(images) if len(images) > 0 else None,
- torch.tensor(
- [(img.shape[1], img.shape[2]) for img in images], dtype=torch.int32
- ) if len(images) > 0 else None,
- None,
- None,
- )
-
-
-class NoTilingStrategy(_FixedSizeStrategy):
- """
- A simple image transformation that resizes the image to the target width and height.
- """
-
- def __init__(
- self,
- vision_model_type: str,
- target_width: int,
- target_height: int,
- embeddings_per_image: int,
- ):
- super().__init__(
- vision_model_type=vision_model_type,
- target_width=target_width,
- target_height=target_height,
- embeddings_per_image=embeddings_per_image,
- )
-
- def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
- return [self._transform(transform_media.media.value)]
-
- def compute_params(
- self,
- media_list: list[ImageMedia | VideoFrameMedia],
- num_tokens_available: Optional[int] = None,
- max_num_tiles: int | None = None,
- **kwargs,
- ) -> list[ImageTilingParams]:
- return [
- ImageTilingParams(
- media=media, num_tiles=1, num_embeddings=self._embeddings_per_image
- )
- for media in media_list
- ]
-
- def __str__(self):
- return f"SimpleImageTransform(vision_model_type={self._vision_model_type}, num_tokens_per_image={self._embeddings_per_image})"
-
-
-@dataclass
-class ImageTilingParamsV1(ImageTilingParams):
- tiling: tuple[int, int]
-
-
-class ImageTilingStrategyV1(_FixedSizeStrategy):
- """Tiling image transformation.
-
- This transformation splits the image into a grid of tiles and applies the transformation to each tile.
- """
-
- # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
- # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
-
- def __init__(
- self,
- vision_model_type: str,
- tile_size: int,
- use_thumbnail: bool,
- min_num_tiles: int,
- max_num_tiles: int,
- embeddings_per_tile: int,
- find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
- ):
- super().__init__(
- vision_model_type=vision_model_type,
- target_width=tile_size,
- target_height=tile_size,
- embeddings_per_image=embeddings_per_tile,
- )
-
- # print(f"Transformation params: {vision_model_type=}, {use_tiling=}, {tile_size=}, {use_thumbnail=}, {augment=}, {min_num_tiles=}, {max_num_tiles=}, {find_closest_aspect_ratio_fn=}")
- self._tile_size = tile_size
- self._use_thumbnail = use_thumbnail
- self._min_num_tiles = min_num_tiles
- self._max_num_tiles = max_num_tiles
- self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
-
- # Calculate all possible aspect ratios for each max_num_tiles.
- self.target_ratios = {
- max_num_tiles: sorted(
- set(
- (x, y)
- for n in range(self._min_num_tiles, max_num_tiles + 1)
- for x in range(1, n + 1)
- for y in range(1, n + 1)
- if x * y <= max_num_tiles and x * y >= self._min_num_tiles
- ),
- key=lambda x: x[0] * x[1],
- )
- for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1)
- }
-
- self.transform = A.Compose([
- A.OneOf([
- A.GaussNoise(var_limit=(5.0, 30.0)),
- A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5)),
- ], p=0.3),
- A.OneOf([
- A.MedianBlur(blur_limit=5),
- A.GaussianBlur(blur_limit=5),
- ], p=0.2),
- A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.5),
- A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=15, val_shift_limit=15, p=0.3),
- A.ImageCompression(quality_lower=70, quality_upper=100, p=0.3),
- ])
-
- def apply_params(self, transform_media: ImageTilingParams, data_augment: bool = False, **kwargs) -> list[torch.Tensor]:
- assert isinstance(transform_media, ImageTilingParamsV1)
- image = transform_media.media.value
-
- if data_augment:
- image = self.transform(image=np.asarray(image))["image"]
- image = Image.fromarray(image)
-
- # calculate the target width and height
- target_width = self._tile_size * transform_media.tiling[0]
- target_height = self._tile_size * transform_media.tiling[1]
- blocks = transform_media.tiling[0] * transform_media.tiling[1]
-
- # resize the image
- resized_img = image.resize((target_width, target_height))
- processed_images = []
- for i in range(blocks):
- box = (
- (i % (target_width // self._tile_size)) * self._tile_size,
- (i // (target_width // self._tile_size)) * self._tile_size,
- ((i % (target_width // self._tile_size)) + 1) * self._tile_size,
- ((i // (target_width // self._tile_size)) + 1) * self._tile_size,
- )
- # split the image
- split_img = resized_img.crop(box)
- processed_images.append(split_img)
- assert len(processed_images) == blocks
- if self._use_thumbnail and len(processed_images) != 1:
- thumbnail_img = image.resize((self._tile_size, self._tile_size))
- processed_images.append(thumbnail_img)
-
- return [self._transform(img) for img in processed_images]
-
- def compute_params(
- self,
- media_list: list[ImageMedia | VideoFrameMedia],
- num_tokens_available: Optional[int] = None,
- max_num_tiles: int | None = None,
- data_augment: bool = False,
- tiling_augment_prob: float = 0.4,
- **kwargs,
- ) -> list[ImageTilingParamsV1]:
- # Use provided max_num_tiles or fall back to instance's max_num_tiles
- # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value
- effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
- effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)
-
- max_num_tiles_to_use = min(
- num_tokens_available // self._embeddings_per_image, effective_max_num_tiles
- )
-
- # calculate the existing image aspect ratio
- target_ratios = self.target_ratios[max_num_tiles_to_use]
-
- params = []
- for media in media_list:
- if isinstance(media, ImageMedia):
- img_size = (media.width, media.height)
- elif isinstance(media, VideoFrameMedia):
- img_size = (media.video_width, media.video_height)
- else:
- raise ValueError(f"Unsupported media type: {type(media)}")
-
- aspect_ratio = img_size[0] / img_size[1]
-
- # find the closest aspect ratio to the target
- tiling = self._find_closest_aspect_ratio_fn(
- aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
- )
- if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob:
- tiling = self.augment_tiling(tiling)
- num_tiles = tiling[0] * tiling[1]
- if self._use_thumbnail and num_tiles != 1:
- num_tiles += 1
-
- params.append(
- ImageTilingParamsV1(
- media=media,
- num_tiles=num_tiles,
- num_embeddings=num_tiles * self._embeddings_per_image,
- tiling=tiling,
- )
- )
-
- return params
-
- def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]:
- def num_tiles(tiling: tuple[int, int]) -> int:
- return tiling[0] * tiling[1]
-
- def plus_minus_one(tiling: tuple[int, int], minus_prob: float = 0.65) -> tuple[int, int]:
- if random.random() < minus_prob:
- # Minus one
- if tiling[0] == 1 and tiling[1] == 1:
- return tiling
- elif tiling[0] == 1:
- return (tiling[0], tiling[1] - 1)
- elif tiling[1] == 1:
- return (tiling[0] - 1, tiling[1])
- else:
- if random.random() < 0.5:
- return (tiling[0] - 1, tiling[1])
- else:
- return (tiling[0], tiling[1] - 1)
- else:
- # Plus one
- if num_tiles(tiling) < self._max_num_tiles:
- tiling0 = (tiling[0] + 1, tiling[1])
- tiling1 = (tiling[0], tiling[1] + 1)
- if num_tiles(tiling0) > self._max_num_tiles and num_tiles(tiling1) > self._max_num_tiles:
- return tiling
- elif num_tiles(tiling0) > self._max_num_tiles:
- return tiling1
- elif num_tiles(tiling1) > self._max_num_tiles:
- return tiling0
- else:
- if random.random() < 0.5:
- return tiling0
- else:
- return tiling1
- return tiling
-
- new_tiling = plus_minus_one(tiling)
- return new_tiling
-
- def __str__(self):
- return f"TilingImageTransform(vision_model_type={self._vision_model_type}, tile_size={self._tile_size}, use_thumbnail={self._use_thumbnail}, min_num_tiles={self._min_num_tiles}, max_num_tiles={self._max_num_tiles}, embeddings_per_tile={self._embeddings_per_image}, find_closest_aspect_ratio_fn={self._find_closest_aspect_ratio_fn})"
-
-
-class TileDegradationStrategy(ImageTilingStrategy):
- """Strategy for tiling images and video frames, each with their own tiling strategy, while trying to match the
- number of tokens left in the sample by reducing the number of tiles if needed.
- """
-
- # Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
- # and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
-
- def __init__(
- self,
- image_strategy: ImageTilingStrategy,
- video_frame_strategy: ImageTilingStrategy,
- embeddings_per_tile: int,
- max_num_tiles: int,
- tile_degradation_map: dict[int, int] = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1},
- ):
- self._image_strategy = image_strategy
- self._video_frame_strategy = video_frame_strategy
- self._embeddings_per_tile = embeddings_per_tile
- self._max_num_tiles = max_num_tiles
- self._tile_degradation_map = tile_degradation_map
-
- def apply_params(self, transform_media: ImageTilingParams, **kwargs) -> list[torch.Tensor]:
- if isinstance(transform_media.media, ImageMedia):
- return self._image_strategy.apply_params(transform_media, **kwargs)
- elif isinstance(transform_media.media, VideoFrameMedia):
- return self._video_frame_strategy.apply_params(transform_media, **kwargs)
- else:
- raise ValueError(f"Unsupported media type: {type(transform_media.media)}")
-
- def compute_params(
- self,
- media_list: list[ImageMedia | VideoFrameMedia],
- num_tokens_available: int | None = None,
- max_num_tiles: int | None = None,
- **kwargs,
- ) -> list[ImageTilingParams]:
- # Use provided max_num_tiles or fall back to instance's max_num_tiles
- effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
- max_num_tiles_to_use = effective_max_num_tiles
- degradation_map = self._tile_degradation_map
-
- while True:
- params = []
- img_num_tiles = []
- for media in media_list:
- if isinstance(media, ImageMedia):
- media_params = self._image_strategy.compute_params(
- [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs
- )[0]
- elif isinstance(media, VideoFrameMedia):
- max_num_tiles_to_use = 1
- media_params = self._video_frame_strategy.compute_params(
- [media], max_num_tiles_to_use * self._embeddings_per_tile, max_num_tiles_to_use, **kwargs
- )[0]
- else:
- raise ValueError(f"Unsupported media type: {type(media)}")
- img_num_tiles.append(media_params.num_tiles)
- params.append(media_params)
- if max_num_tiles_to_use == 1 or num_tokens_available is None:
- break
- if sum(img_num_tiles) * self._embeddings_per_tile > num_tokens_available:
- if max_num_tiles_to_use in degradation_map:
- max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
- else:
- # End of degradation
- break
- else:
- break
- return params
-
- def stack(
- self, images: list[torch.Tensor]
- ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
- return self._image_strategy.stack(images)
-
- def __str__(self):
- return f"TileDegradationImageTransform(max_num_tiles={self._max_num_tiles}, image_transform={self._image_strategy}, video_frame_transform={self._video_frame_strategy})"
-
-
-@dataclass
-class DynamicResolutionParams(ImageTilingParams):
- patch_size: tuple[int, int]
-
-
-class DynamicResolutionImageTilingStrategy(ImageTilingStrategy):
- """Preprocess an image with dynamic resolution for vision transformers.
-
- This function resizes an image to optimize the number of patches while respecting
- constraints on minimum/maximum patches, minimum side length, and compatibility
- with pixel shuffle or convolution merging operations.
-
- The algorithm works by:
- 1. Computing the initial patch grid size based on the image dimensions and res_step
- 2. Scaling the patch grid to fit within the max_patches constraint
- 3. Ensuring the result has at least min_patches
- 4. Optionally enforcing a minimum side length constraint
- 5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility
- 6. Resizing the image to the computed target dimensions
-
- Note:
- The function preserves aspect ratio as much as possible while satisfying all constraints.
- When constraints conflict (e.g., min_side vs max_patches), the function prioritizes
- staying within max_patches while maximizing the image size.
-
- Example:
- >>> from PIL import Image
- >>> img = Image.open("example.jpg") # 800x600 image
- >>> strategy = DynamicResolutionImageTilingStrategy(vision_model_type="radio", min_patches=4, max_patches=64, res_step=14, get_num_embeddings=lambda x, y: x * y * 2)
- >>> params = strategy.compute_params([img])
- >>> img_tensor = strategy.apply_params(params[0])
- >>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14
- """
-
- def __init__(
- self,
- vision_model_type: str,
- min_num_patches: int,
- patch_size: int,
- get_num_embeddings: Callable[[int, int], int],
- factor_max: float = 1.0,
- pixel_shuffle: bool = False,
- min_side: int | None = None,
- conv_merging: bool = False,
- use_thumbnail: bool = False,
- thumbnail_size: int = 448,
- thumbnail_area_threshold: float = 0.8,
- max_num_patches: int = 0,
- apply_data_augment: bool = False,
- ):
- """
- Args:
- vision_model_type: Vision model type.
- min_num_patches: Minimum number of patches required. Defaults to 1.
- max_num_patches: Maximum number of patches allowed. Defaults to 0 (no maximum).
- patch_size: Resolution step size (patch dimension). Defaults to 16.
- get_num_embeddings: Function to get the number of embeddings from the patch size (width, height).
- factor_max: Maximum scaling factor to apply. Defaults to 1.0.
- pixel_shuffle: Whether to ensure compatibility with pixel shuffle operations by rounding to even patch
- dimensions. Defaults to False.
- min_side: Minimum side length in pixels. If specified, ensures at least one side meets this constraint.
- Defaults to None.
- conv_merging: Whether to ensure compatibility with convolution merging by rounding to even patch dimensions.
- Defaults to False.
- use_thumbnail: Whether to add a thumbnail image when processing. Defaults to False.
- thumbnail_size: Size of the thumbnail image (width and height). Defaults to 448.
- thumbnail_area_threshold: Maximum area percentage (0.0-1.0) of the resized image relative to thumbnail area
- for which to add a thumbnail. If the resized image area is larger than this threshold of the thumbnail
- area, no thumbnail will be added. Defaults to 0.8 (80%).
- apply_data_augment: Whether to apply data augmentation to the image. Defaults to False.
- """
- assert "radio" in vision_model_type, (
- "Dynamic resolution is only supported for radio models"
- )
- self._vision_model_type = vision_model_type
- self._min_num_patches = min_num_patches
- self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
- self._patch_size = patch_size
- self._get_num_embeddings = get_num_embeddings
- self._factor_max = factor_max
- self._pixel_shuffle = pixel_shuffle
- self._min_side = min_side
- self._conv_merging = conv_merging
- self._use_thumbnail = use_thumbnail
- self._thumbnail_size = thumbnail_size
- self._thumbnail_area_threshold = thumbnail_area_threshold
- pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
- self._transform = T.Compose(
- [
- T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
- T.ToTensor(), #T.Lambda(lambda img: _fast_to_tensor(img)),
- T.Normalize(mean=pixel_mean, std=pixel_std),
- ]
- )
- self._apply_data_augment = apply_data_augment
-
- def apply_params(self, params: DynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
- # resize the image
- resized_img = params.media.value.resize(
- (
- params.patch_size[0] * self._patch_size,
- params.patch_size[1] * self._patch_size,
- )
- )
- processed_images = [resized_img]
-
- # Add thumbnail if enabled and image area is below threshold
- if self._use_thumbnail:
- # Calculate areas
- resized_area = resized_img.size[0] * resized_img.size[1]
- thumbnail_area = self._thumbnail_size * self._thumbnail_size
- area_ratio = resized_area / thumbnail_area
-
- # Only add thumbnail if resized image area is less than threshold % of thumbnail area
- if area_ratio < self._thumbnail_area_threshold:
- thumbnail_img = params.media.value.resize((self._thumbnail_size, self._thumbnail_size))
- processed_images.append(thumbnail_img)
-
- return [self._transform(img) for img in processed_images]
-
- def process_media(
- self,
- media: ImageMedia | VideoFrameMedia,
- num_tokens_available: int,
- data_augment: bool = False,
- tiling_augment_prob: float = 0.4,
- ) -> DynamicResolutionParams:
- """Process a single media item and return its parameters.
-
- Args:
- media: The media item to process
- num_tokens_available: Number of tokens available for this media
- data_augment: Whether to apply data augmentation to the image. Defaults to False.
- Returns:
- DynamicResolutionParams for the media
- """
- current_num_tokens_available = num_tokens_available
- if isinstance(media, ImageMedia):
- orig_width, orig_height = media.width, media.height
- elif isinstance(media, VideoFrameMedia):
- orig_width, orig_height = media.video_width, media.video_height
- # current_num_tokens_available = 1024 #TEMP: hack for video
- else:
- raise ValueError(f"Unsupported media type: {type(media)}")
-
- closest_patch_height = round(orig_height / self._patch_size + 0.5)
- closest_patch_width = round(orig_width / self._patch_size + 0.5)
- patches = closest_patch_height * closest_patch_width
-
- factor = min(math.sqrt(current_num_tokens_available / patches), self._factor_max)
- target_patch_height = math.floor(factor * closest_patch_height)
- target_patch_width = math.floor(factor * closest_patch_width)
-
- # We only consider self._min_num_patches if it is greater than current_num_tokens_available.
- if current_num_tokens_available > self._min_num_patches and target_patch_height * target_patch_width < self._min_num_patches:
- up_factor = math.sqrt(
- self._min_num_patches / (target_patch_height * target_patch_width)
- )
- target_patch_height = math.ceil(up_factor * target_patch_height)
- target_patch_width = math.ceil(up_factor * target_patch_width)
-
- if (
- self._min_side is not None
- and min(target_patch_width, target_patch_height) * self._patch_size
- < self._min_side
- ):
- if target_patch_width <= target_patch_height:
- up_factor = self._min_side / (target_patch_width * self._patch_size)
- new_patch_height = math.ceil(up_factor * target_patch_height)
- new_patch_width = math.ceil(up_factor * target_patch_width)
-
- if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
- if (
- max(current_num_tokens_available // new_patch_width, 1)
- * self._patch_size
- < self._min_side
- ):
- up_factor = math.sqrt(
- current_num_tokens_available
- / (target_patch_height * target_patch_width)
- )
- target_patch_height = math.floor(
- up_factor * target_patch_height
- )
- target_patch_width = math.floor(
- up_factor * target_patch_width
- )
- target_patch_width = new_patch_width
- target_patch_height = max(
- current_num_tokens_available // new_patch_width, 1
- )
- else:
- target_patch_height = new_patch_height
- target_patch_width = new_patch_width
- else:
- up_factor = self._min_side / (
- target_patch_height * self._patch_size
- )
- new_patch_height = math.ceil(up_factor * target_patch_height)
- new_patch_width = math.ceil(up_factor * target_patch_width)
-
- if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
- if (
- max(current_num_tokens_available // new_patch_height, 1)
- * self._patch_size
- < self._min_side
- ):
- up_factor = math.sqrt(
- current_num_tokens_available
- / (target_patch_height * target_patch_width)
- )
- target_patch_height = math.floor(
- up_factor * target_patch_height
- )
- target_patch_width = math.floor(
- up_factor * target_patch_width
- )
- else:
- target_patch_height = new_patch_height
- target_patch_width = max(
- current_num_tokens_available // new_patch_height, 1
- )
- else:
- target_patch_height = new_patch_height
- target_patch_width = new_patch_width
-
- # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
- # or by 4 when BOTH are enabled (two successive 2x reductions)
- if self._pixel_shuffle or self._conv_merging:
- required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
-
- rem_h = target_patch_height % required_divisor
- if rem_h != 0:
- inc_h = required_divisor - rem_h
- if (target_patch_height + inc_h) * target_patch_width <= current_num_tokens_available:
- target_patch_height += inc_h
- else:
- target_patch_height = max(required_divisor, target_patch_height - rem_h)
-
- rem_w = target_patch_width % required_divisor
- if rem_w != 0:
- inc_w = required_divisor - rem_w
- if target_patch_height * (target_patch_width + inc_w) <= current_num_tokens_available:
- target_patch_width += inc_w
- else:
- target_patch_width = max(required_divisor, target_patch_width - rem_w)
-
- if data_augment and self._apply_data_augment and random.random() < tiling_augment_prob:
- target_patch_width, target_patch_height = self.augment_resolution(target_patch_width, target_patch_height, current_num_tokens_available)
-
- #TEMP: hack for video
- if isinstance(media, VideoFrameMedia):
- target_patch_width = 32
- target_patch_height = 32
-
- # Calculate embeddings for the main dynamic resolution image
- num_embeddings = self._get_num_embeddings(
- target_patch_width * self._patch_size,
- target_patch_height * self._patch_size,
- )
-
- token_count = target_patch_width * target_patch_height
-
- # Add thumbnail embeddings if enabled and image area is below threshold
- num_tiles = 1 # Base dynamic resolution image
- if self._use_thumbnail:
- # Calculate areas
- resized_area = (target_patch_width * self._patch_size) * (target_patch_height * self._patch_size)
- thumbnail_area = self._thumbnail_size * self._thumbnail_size
- area_ratio = resized_area / thumbnail_area
-
- # Only add thumbnail if resized image area is less than threshold % of thumbnail area
- if area_ratio < self._thumbnail_area_threshold:
- num_tiles += 1 # Add 1 for thumbnail
- # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
- num_embeddings += self._get_num_embeddings(self._thumbnail_size, self._thumbnail_size)
- token_count += self._thumbnail_size // self._patch_size * self._thumbnail_size // self._patch_size
-
- return DynamicResolutionParams(
- media=media,
- num_tiles=num_tiles,
- num_embeddings=num_embeddings,
- patch_size=(target_patch_width, target_patch_height),
- ), token_count
-
def augment_resolution(
    self,
    target_patch_width: int,
    target_patch_height: int,
    current_num_tokens_available: int,
) -> tuple[int, int]:
    """Randomly nudge the patch grid by one 32-patch step along one axis.

    With probability 0.5 the grid shrinks by 32 patches on one axis (an
    axis already at or below 32 is never shrunk); otherwise it grows by 32
    patches on one axis, but only while the current grid is still under
    the token budget. When both axes are eligible the axis is chosen
    uniformly at random.

    Args:
        target_patch_width: Current patch-grid width.
        target_patch_height: Current patch-grid height.
        current_num_tokens_available: Token budget; growth is skipped once
            width * height reaches it.

    Returns:
        The (possibly unchanged) (width, height) patch grid.
    """
    step = 32  # one augmentation step; also the per-axis shrink floor
    width, height = target_patch_width, target_patch_height

    if random.random() < 0.5:
        # Shrink one axis by a step, respecting the per-axis floor.
        width_at_floor = width <= step
        height_at_floor = height <= step
        if width_at_floor and height_at_floor:
            return width, height
        if width_at_floor:
            return width, height - step
        if height_at_floor:
            return width - step, height
        # Both axes eligible: pick one at random.
        if random.random() < 0.5:
            return width - step, height
        return width, height - step

    # Grow one axis by a step, only while under the token budget.
    if width * height < current_num_tokens_available:
        if random.random() < 0.5:
            return width + step, height
        return width, height + step
    return width, height
-
def compute_params(
    self,
    media_list: list[ImageMedia | VideoFrameMedia],
    num_tokens_available: int | None = None,
    max_num_tiles: int | None = None,
    data_augment: bool = False,
    **kwargs,
) -> list[ImageTilingParams]:
    """Compute parameters for all media with iterative token budgeting.

    Processes every media item with a per-item token budget, and if the
    combined token count overshoots the overall budget, proportionally
    shrinks the per-item budgets and retries (at most 10 passes).

    Args:
        media_list: List of media items to process
        num_tokens_available: Total number of tokens available across all media
        max_num_tiles: Maximum number of tiles (unused in this implementation)
        data_augment: Whether to apply data augmentation to the image. Defaults to False.
    Returns:
        List of ImageTilingParams for each media item
    """
    # NOTE(review): despite the `int | None = None` annotation, passing None
    # crashes on the multiplication below — confirm callers always pass an int.
    # Pixel-shuffle and conv-merging each reduce tokens 4x downstream, so the
    # patch-level budget is scaled up accordingly.
    num_tokens_available = num_tokens_available * (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1)
    # When the number of available token is too small, allow self._min_num_patches per media and
    # let the sample be truncated.
    num_tokens_available = max(num_tokens_available, self._min_num_patches * len(media_list))

    # Clip the number of tokens available per media to be between min and max patches.
    num_tokens_available_per_media = [
        max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
        for _ in range(len(media_list))]

    # In theory this could be a while True loop, but in case the process_media method slightly
    # changes, I want to make sure we don't get stuck in an infinite loop.
    for _ in range(10):
        # Step 1: Process each media with current token budget
        params = []
        token_counts = []

        for media, tokens_for_media in zip(media_list, num_tokens_available_per_media):
            param, token_count = self.process_media(media, tokens_for_media, data_augment=data_augment)
            params.append(param)
            token_counts.append(token_count)

        # Step 2: Check if total tokens is within budget
        total_tokens = sum(token_counts)

        if total_tokens <= num_tokens_available:
            # We're within budget, return the params
            return params

        # Step 3: We're over budget, need to scale down
        # Calculate scaling factor to get under budget
        scaling_factor = num_tokens_available / total_tokens

        # Recalculate token budgets for each media based on scaling
        # Each media gets a proportional share of the total budget
        scaled_down_num_tokens_available_per_media = [
            max(self._min_num_patches, int(token_count * scaling_factor))
            for token_count in token_counts
        ]
        scaled_down = any([
            scaled_down_num_tokens_available_per_media[i] < num_tokens_available_per_media[i]
            for i in range(len(num_tokens_available_per_media))])
        # If there was no scaling down (every budget is already at the floor),
        # we're stuck: just use min_num_patches per media; else retry with the
        # scaled-down num_tokens_available_per_media.
        if not scaled_down:
            num_tokens_available_per_media = [self._min_num_patches] * len(media_list)
        else:
            num_tokens_available_per_media = scaled_down_num_tokens_available_per_media
    # Fell through after 10 passes: return the last (possibly over-budget) params.
    return params
-
- def stack(
- self, images: list[torch.Tensor]
- ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
- imgs_sizes = torch.tensor(
- [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
- )
-
- def rearrange_img(x):
- py = x.shape[-2] // self._patch_size
- px = x.shape[-1] // self._patch_size
- x = einops.rearrange(
- x,
- "c (py yy) (px xx) -> (py px) (c yy xx)",
- py=py,
- yy=self._patch_size,
- px=px,
- xx=self._patch_size,
- )
- return x
-
- if len(images) > 0:
- imgs = [rearrange_img(img) for img in images]
-
- current_length = 0
- max_length = 0
- vision_cu_lengths = [0]
- for img in imgs:
- if max_length < img.shape[0]:
- max_length = img.shape[0]
- current_length += img.shape[0]
- vision_cu_lengths.append(current_length)
-
- vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
- vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
-
- return (
- torch.cat(imgs, dim=0).unsqueeze(0),
- imgs_sizes,
- vision_cu_lengths,
- vision_max_lengths,
- )
- else:
- return (
- torch.tensor([[0]], dtype=torch.float32),
- torch.tensor([[0,0]], dtype=torch.int32),
- None,
- None,
- )
-
def __str__(self):
    """Human-readable summary of this transform's configuration."""
    fields = (
        ("vision_model_type", self._vision_model_type),
        ("min_num_patches", self._min_num_patches),
        ("patch_size", self._patch_size),
        ("pixel_shuffle", self._pixel_shuffle),
        ("conv_merging", self._conv_merging),
        ("use_thumbnail", self._use_thumbnail),
        ("thumbnail_size", self._thumbnail_size),
        ("thumbnail_area_threshold", self._thumbnail_area_threshold),
    )
    body = ", ".join(f"{name}={value}" for name, value in fields)
    return f"DynamicResolutionImageTransform({body})"
-
-
@dataclass
class MatchTilingDynamicResolutionParams(ImageTilingParams):
    """Params for MatchTilingDynamicResolutionStrategy: base tiling params
    plus the chosen tile grid, from which apply_params derives the target
    resolution."""

    # (tiles_x, tiles_y) grid selected by the closest-aspect-ratio search.
    tiling: tuple[int, int]
-
class MatchTilingDynamicResolutionStrategy(ImageTilingStrategy):
    """
    Strategy that uses tiling logic to determine optimal image dimensions but processes
    the image as a single dynamic resolution image instead of splitting into tiles.

    This combines the aspect ratio optimization from ImageTilingStrategyV1 with the
    dynamic resolution processing from DynamicResolutionImageTilingStrategy.

    Also includes tile degradation logic similar to TileDegradationStrategy.
    """

    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        patch_size: int,
        get_num_embeddings: Callable[[int, int], int],
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
        pixel_shuffle: bool = False,
        conv_merging: bool = False,
        tile_degradation_map: dict[int, int] | None = None,
        video_frame_strategy: ImageTilingStrategy | None = None,
        enable_tile_degradation: bool = True,
    ):
        """
        Args:
            vision_model_type: Vision model type (should support dynamic resolution)
            tile_size: Size of each tile for tiling calculation
            use_thumbnail: Whether tiling logic should include thumbnail
            min_num_tiles: Minimum number of tiles for tiling calculation
            max_num_tiles: Maximum number of tiles for tiling calculation
            embeddings_per_tile: Embeddings per tile for tiling calculation
            patch_size: Patch size for dynamic resolution processing
            get_num_embeddings: Function to get number of embeddings from dimensions
            find_closest_aspect_ratio_fn: Function to find closest aspect ratio
            pixel_shuffle: Whether to ensure compatibility with pixel shuffle
            conv_merging: Whether to ensure compatibility with convolution merging
            tile_degradation_map: Map for degrading tiles when tokens are insufficient
            video_frame_strategy: Strategy for processing video frames
            enable_tile_degradation: Whether to enable tile degradation (default: True)
        """
        assert "radio" in vision_model_type, (
            "MatchTilingDynamicResolution is only supported for radio models"
        )

        self._vision_model_type = vision_model_type
        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._embeddings_per_tile = embeddings_per_tile
        self._patch_size = patch_size
        self._get_num_embeddings = get_num_embeddings
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
        self._pixel_shuffle = pixel_shuffle
        self._conv_merging = conv_merging
        self._enable_tile_degradation = enable_tile_degradation

        # Tile degradation logic (similar to TileDegradationStrategy)
        if tile_degradation_map is None:
            self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
        else:
            self._tile_degradation_map = tile_degradation_map

        # Video frame strategy (similar to TileDegradationStrategy)
        if video_frame_strategy is None:
            # Video frames are processed as a single fixed-size image.
            self._video_frame_strategy = NoTilingStrategy(
                vision_model_type=vision_model_type,
                target_width=tile_size,
                target_height=tile_size,
                embeddings_per_image=embeddings_per_tile,
            )
        else:
            self._video_frame_strategy = video_frame_strategy

        # Calculate all possible aspect ratios for each max_num_tiles (borrowed from ImageTilingStrategyV1)
        # Keyed by tile cap so degradation can reuse precomputed tables;
        # each entry is the (x, y) grids with min <= x*y <= cap, sorted by area.
        self.target_ratios = {
            max_num_tiles: sorted(
                set(
                    (x, y)
                    for n in range(self._min_num_tiles, max_num_tiles + 1)
                    for x in range(1, n + 1)
                    for y in range(1, n + 1)
                    if x * y <= max_num_tiles and x * y >= self._min_num_tiles
                ),
                key=lambda x: x[0] * x[1],
            )
            for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1)
        }

        # Set up transform for dynamic resolution processing
        # (RGB conversion -> tensor -> per-model mean/std normalization).
        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ]
        )
-
def apply_params(self, params: MatchTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
    """Resize the media to its tiling-derived resolution and normalize it.

    Video frames are delegated to the configured video-frame strategy.
    For images, the tiling chosen in compute_params fixes the target
    resolution (tile_size per axis times the tile count); the image is
    processed as one dynamic-resolution input, plus an optional square
    thumbnail when thumbnails are enabled and more than one tile was chosen.
    """
    if isinstance(params.media, VideoFrameMedia):
        # Video frames bypass the tiling path entirely.
        return self._video_frame_strategy.apply_params(params, **kwargs)

    source = params.media.value
    tiles_x, tiles_y = params.tiling

    # Single resize to the tiling-matched resolution; no tile splitting.
    outputs = [source.resize((self._tile_size * tiles_x, self._tile_size * tiles_y))]

    if self._use_thumbnail and tiles_x * tiles_y != 1:
        # Add a global low-resolution view alongside the full image.
        outputs.append(source.resize((self._tile_size, self._tile_size)))

    return [self._transform(img) for img in outputs]
-
def compute_params(
    self,
    media_list: list[ImageMedia | VideoFrameMedia],
    num_tokens_available: int | None = None,
    max_num_tiles: int | None = None,
    **kwargs,
) -> list[MatchTilingDynamicResolutionParams]:
    """Choose a tile grid per media item, degrading the tile cap to fit the budget.

    For images the closest-aspect-ratio search picks a (tiles_x, tiles_y)
    grid; video frames are delegated to the video-frame strategy and always
    use a 1x1 grid. If the total embeddings across all media exceed
    ``num_tokens_available``, the tile cap is stepped down through
    ``self._tile_degradation_map`` and the whole pass is retried until the
    budget is met or no further degradation is possible.

    Args:
        media_list: Media items to compute params for.
        num_tokens_available: Total embedding budget across all media;
            None disables the budget check (no degradation).
        max_num_tiles: Optional per-call tile cap; clamped to the instance
            cap since target_ratios are only precomputed up to it.

    Returns:
        One MatchTilingDynamicResolutionParams per media item.
    """
    # Implement tile degradation logic similar to TileDegradationStrategy
    # Use provided max_num_tiles or fall back to instance's max_num_tiles
    # Clamp to self._max_num_tiles since target_ratios are only pre-computed up to that value
    effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
    effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)
    max_num_tiles_to_use = effective_max_num_tiles
    degradation_map = self._tile_degradation_map

    while True:
        params = []
        total_embeddings_needed = 0

        for media in media_list:
            if isinstance(media, ImageMedia):
                # Use tiling logic for images
                img_size = (media.width, media.height)
                aspect_ratio = img_size[0] / img_size[1]

                # Find the closest aspect ratio to the target
                target_ratios = self.target_ratios[max_num_tiles_to_use]
                tiling = self._find_closest_aspect_ratio_fn(
                    aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
                )

                # Calculate target dimensions for dynamic resolution processing
                target_width = self._tile_size * tiling[0]
                target_height = self._tile_size * tiling[1]
                num_embeddings = self._get_num_embeddings(target_width, target_height)

                # Account for thumbnail (same logic as ImageTilingStrategyV1)
                num_tiles = 1  # Base dynamic resolution image
                blocks = tiling[0] * tiling[1]
                if self._use_thumbnail and blocks != 1:
                    num_tiles += 1  # Add 1 for thumbnail
                    # Add embeddings for thumbnail (tile_size x tile_size)
                    num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size)

                media_params = MatchTilingDynamicResolutionParams(
                    media=media,
                    num_tiles=num_tiles,
                    num_embeddings=num_embeddings,
                    tiling=tiling,
                )
            elif isinstance(media, VideoFrameMedia):
                # Use video frame strategy for video frames (always 1 tile)
                video_params = self._video_frame_strategy.compute_params(
                    [media], 1 * self._embeddings_per_tile
                )[0]
                media_params = MatchTilingDynamicResolutionParams(
                    media=media,
                    num_tiles=video_params.num_tiles,
                    num_embeddings=video_params.num_embeddings,
                    tiling=(1, 1),  # Video frames always use 1x1 tiling
                )
            else:
                raise ValueError(f"Unsupported media type: {type(media)}")

            params.append(media_params)
            total_embeddings_needed += media_params.num_embeddings

        # Check if we need to degrade (only if degradation is enabled)
        if not self._enable_tile_degradation:
            break
        if max_num_tiles_to_use == 1 or num_tokens_available is None:
            break
        if total_embeddings_needed > num_tokens_available:
            if max_num_tiles_to_use in degradation_map:
                max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
                # Recalculate target ratios for the new max_num_tiles_to_use
                # (degradation targets may fall outside the precomputed range).
                if max_num_tiles_to_use not in self.target_ratios:
                    self.target_ratios[max_num_tiles_to_use] = sorted(
                        set(
                            (x, y)
                            for n in range(self._min_num_tiles, max_num_tiles_to_use + 1)
                            for x in range(1, n + 1)
                            for y in range(1, n + 1)
                            if x * y <= max_num_tiles_to_use and x * y >= self._min_num_tiles
                        ),
                        key=lambda x: x[0] * x[1],
                    )
            else:
                # End of degradation
                break
        else:
            break

    return params
-
- def stack(
- self, images: list[torch.Tensor]
- ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
- """Stack images using dynamic resolution approach with sequence packing"""
- imgs_sizes = torch.tensor(
- [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
- )
-
- def rearrange_img(x):
- py = x.shape[-2] // self._patch_size
- px = x.shape[-1] // self._patch_size
- x = einops.rearrange(
- x,
- "c (py yy) (px xx) -> (py px) (c yy xx)",
- py=py,
- yy=self._patch_size,
- px=px,
- xx=self._patch_size,
- )
- return x
-
- if len(images) > 0:
- imgs = [rearrange_img(img) for img in images]
-
- current_length = 0
- max_length = 0
- vision_cu_lengths = [0]
- for img in imgs:
- if max_length < img.shape[0]:
- max_length = img.shape[0]
- current_length += img.shape[0]
- vision_cu_lengths.append(current_length)
-
- vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
- vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
-
- return (
- torch.cat(imgs, dim=0).unsqueeze(0),
- imgs_sizes,
- vision_cu_lengths,
- vision_max_lengths,
- )
- else:
- return (
- torch.tensor([[0]], dtype=torch.float32),
- torch.tensor([[0,0]], dtype=torch.int32),
- None,
- None,
- )
-
def __str__(self):
    """Human-readable summary of this strategy's configuration."""
    fields = (
        ("vision_model_type", self._vision_model_type),
        ("tile_size", self._tile_size),
        ("use_thumbnail", self._use_thumbnail),
        ("min_num_tiles", self._min_num_tiles),
        ("max_num_tiles", self._max_num_tiles),
        ("patch_size", self._patch_size),
        ("pixel_shuffle", self._pixel_shuffle),
        ("conv_merging", self._conv_merging),
        ("enable_tile_degradation", self._enable_tile_degradation),
        ("video_frame_strategy", self._video_frame_strategy),
    )
    body = ", ".join(f"{name}={value}" for name, value in fields)
    return f"MatchTilingDynamicResolutionStrategy({body})"
-
-
@dataclass
class MaskedTilingDynamicResolutionParams(ImageTilingParams):
    """Params for MaskedTilingDynamicResolutionStrategy: base tiling params
    plus the chosen tile grid, from which apply_params derives the per-tile
    crops."""

    # (tiles_x, tiles_y) grid selected by the closest-aspect-ratio search.
    tiling: tuple[int, int]
-
class MaskedTilingDynamicResolutionStrategy(ImageTilingStrategy):
    """
    Like MatchTilingDynamicResolutionStrategy, but ensures tiles are isolated in the
    vision encoder by emitting per-tile packed samples (block-diagonal attention across tiles).
    """

    def __init__(
        self,
        vision_model_type: str,
        tile_size: int,
        use_thumbnail: bool,
        min_num_tiles: int,
        max_num_tiles: int,
        embeddings_per_tile: int,
        patch_size: int,
        get_num_embeddings: Callable[[int, int], int],
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
        pixel_shuffle: bool = False,
        conv_merging: bool = False,
        tile_degradation_map: dict[int, int] | None = None,
        video_frame_strategy: ImageTilingStrategy | None = None,
        enable_tile_degradation: bool = True,
    ):
        """
        Args:
            vision_model_type: Vision model type; must contain "radio".
            tile_size: Side length (pixels) of each square tile.
            use_thumbnail: Whether to append a square thumbnail for multi-tile images.
            min_num_tiles: Minimum number of tiles in a tiling.
            max_num_tiles: Maximum number of tiles in a tiling.
            embeddings_per_tile: Embeddings budget used for video frames.
            patch_size: Patch size used when packing tiles in stack().
            get_num_embeddings: Callable mapping (width, height) to embedding count.
            find_closest_aspect_ratio_fn: Function to find the closest tile grid.
            pixel_shuffle: Whether pixel shuffle is enabled downstream.
            conv_merging: Whether convolution merging is enabled downstream.
            tile_degradation_map: Map for degrading tiles when tokens are insufficient.
            video_frame_strategy: Strategy for processing video frames.
            enable_tile_degradation: Whether to enable tile degradation (default: True).
        """
        assert "radio" in vision_model_type, (
            "MaskedTilingDynamicResolution is only supported for radio models"
        )

        self._vision_model_type = vision_model_type
        self._tile_size = tile_size
        self._use_thumbnail = use_thumbnail
        self._min_num_tiles = min_num_tiles
        self._max_num_tiles = max_num_tiles
        self._embeddings_per_tile = embeddings_per_tile
        self._patch_size = patch_size
        self._get_num_embeddings = get_num_embeddings
        self._find_closest_aspect_ratio_fn = find_closest_aspect_ratio_fn
        self._pixel_shuffle = pixel_shuffle
        self._conv_merging = conv_merging
        self._enable_tile_degradation = enable_tile_degradation

        # Default degradation chain mirrors MatchTilingDynamicResolutionStrategy.
        if tile_degradation_map is None:
            self._tile_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1}
        else:
            self._tile_degradation_map = tile_degradation_map

        # Video frames are processed as a single fixed-size image by default.
        if video_frame_strategy is None:
            self._video_frame_strategy = NoTilingStrategy(
                vision_model_type=vision_model_type,
                target_width=tile_size,
                target_height=tile_size,
                embeddings_per_image=embeddings_per_tile,
            )
        else:
            self._video_frame_strategy = video_frame_strategy

        # Precomputed (x, y) tile grids per tile cap, sorted by area
        # (same construction as MatchTilingDynamicResolutionStrategy).
        self.target_ratios = {
            max_num_tiles: sorted(
                set(
                    (x, y)
                    for n in range(self._min_num_tiles, max_num_tiles + 1)
                    for x in range(1, n + 1)
                    for y in range(1, n + 1)
                    if x * y <= max_num_tiles and x * y >= self._min_num_tiles
                ),
                key=lambda x: x[0] * x[1],
            )
            for max_num_tiles in range(self._min_num_tiles, self._max_num_tiles + 1)
        }

        # RGB conversion -> tensor -> per-model mean/std normalization.
        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ]
        )
-
def apply_params(self, params: MaskedTilingDynamicResolutionParams, **kwargs) -> list[torch.Tensor]:
    """Resize, split into tile_size crops (row-major), and normalize.

    Unlike the match-tiling variant, every tile is returned as its own
    image so downstream packing keeps tiles attention-isolated. Video
    frames are delegated to the video-frame strategy. A square thumbnail
    is appended when thumbnails are enabled and the grid has more than
    one tile.
    """
    if isinstance(params.media, VideoFrameMedia):
        # Video frames bypass the tiling path entirely.
        return self._video_frame_strategy.apply_params(params, **kwargs)

    source = params.media.value
    tiles_x, tiles_y = params.tiling
    resized = source.resize((self._tile_size * tiles_x, self._tile_size * tiles_y))

    # Row-major (left, upper, right, lower) crop boxes, one per tile.
    boxes = [
        (
            col * self._tile_size,
            row * self._tile_size,
            (col + 1) * self._tile_size,
            (row + 1) * self._tile_size,
        )
        for row in range(tiles_y)
        for col in range(tiles_x)
    ]
    outputs = [resized.crop(box) for box in boxes]

    if self._use_thumbnail and tiles_x * tiles_y != 1:
        # Add a global low-resolution view alongside the tiles.
        outputs.append(source.resize((self._tile_size, self._tile_size)))

    return [self._transform(img) for img in outputs]
-
def compute_params(
    self,
    media_list: list[ImageMedia | VideoFrameMedia],
    num_tokens_available: int | None = None,
    max_num_tiles: int | None = None,
    data_augment: bool = False,
    tiling_augment_prob: float = 0.4,
    **kwargs,
) -> list[MaskedTilingDynamicResolutionParams]:
    """Choose a tile grid per media item, degrading the tile cap to fit the budget.

    Same degradation loop as MatchTilingDynamicResolutionStrategy, but each
    tile is budgeted at the fixed tile_size x tile_size embedding cost, and
    an optional +/- one-tile augmentation is applied to image tilings.

    Args:
        media_list: Media items to compute params for.
        num_tokens_available: Total embedding budget across all media;
            None disables the budget check (no degradation).
        max_num_tiles: Optional per-call tile cap; clamped to the instance cap.
        data_augment: Whether to randomly perturb image tilings.
        tiling_augment_prob: Per-image probability of applying augment_tiling.

    Returns:
        One MaskedTilingDynamicResolutionParams per media item.
    """
    effective_max_num_tiles = max_num_tiles if max_num_tiles is not None else self._max_num_tiles
    effective_max_num_tiles = min(effective_max_num_tiles, self._max_num_tiles)
    max_num_tiles_to_use = effective_max_num_tiles
    degradation_map = self._tile_degradation_map

    while True:
        params = []
        total_embeddings_needed = 0

        for media in media_list:
            if isinstance(media, ImageMedia):
                img_size = (media.width, media.height)
                aspect_ratio = img_size[0] / img_size[1]

                target_ratios = self.target_ratios[max_num_tiles_to_use]
                tiling = self._find_closest_aspect_ratio_fn(
                    aspect_ratio, target_ratios, img_size[0], img_size[1], self._tile_size
                )

                # Apply tiling augmentation if enabled
                # NOTE(review): the isinstance check below is redundant —
                # this branch already guarantees an ImageMedia.
                if data_augment and isinstance(media, ImageMedia) and random.random() < tiling_augment_prob:
                    tiling = self.augment_tiling(tiling)

                blocks = tiling[0] * tiling[1]
                # Each tile is tile_size x tile_size
                per_tile_emb = self._get_num_embeddings(self._tile_size, self._tile_size)
                num_embeddings = blocks * per_tile_emb

                num_tiles = blocks
                if self._use_thumbnail and blocks != 1:
                    # Thumbnail adds one more tile_size x tile_size image.
                    num_tiles += 1
                    num_embeddings += self._get_num_embeddings(self._tile_size, self._tile_size)

                media_params = MaskedTilingDynamicResolutionParams(
                    media=media,
                    num_tiles=num_tiles,
                    num_embeddings=num_embeddings,
                    tiling=tiling,
                )
            elif isinstance(media, VideoFrameMedia):
                # Video frames are delegated; they always use a 1x1 grid.
                video_params = self._video_frame_strategy.compute_params(
                    [media], 1 * self._embeddings_per_tile
                )[0]
                media_params = MaskedTilingDynamicResolutionParams(
                    media=media,
                    num_tiles=video_params.num_tiles,
                    num_embeddings=video_params.num_embeddings,
                    tiling=(1, 1),
                )
            else:
                raise ValueError(f"Unsupported media type: {type(media)}")

            params.append(media_params)
            total_embeddings_needed += media_params.num_embeddings

        # Degradation loop exit conditions mirror the match-tiling strategy.
        if not self._enable_tile_degradation:
            break
        if max_num_tiles_to_use == 1 or num_tokens_available is None:
            break
        if total_embeddings_needed > num_tokens_available:
            if max_num_tiles_to_use in degradation_map:
                max_num_tiles_to_use = degradation_map[max_num_tiles_to_use]
                # Degradation targets may fall outside the precomputed range.
                if max_num_tiles_to_use not in self.target_ratios:
                    self.target_ratios[max_num_tiles_to_use] = sorted(
                        set(
                            (x, y)
                            for n in range(self._min_num_tiles, max_num_tiles_to_use + 1)
                            for x in range(1, n + 1)
                            for y in range(1, n + 1)
                            if x * y <= max_num_tiles_to_use and x * y >= self._min_num_tiles
                        ),
                        key=lambda x: x[0] * x[1],
                    )
            else:
                break
        else:
            break

    return params
-
def augment_tiling(self, tiling: tuple[int, int]) -> tuple[int, int]:
    """Randomly grow or shrink the tile grid by one tile on one axis.

    Shrinks with probability 0.65 (never below 1 tile per axis) and grows
    otherwise, but only when the grown grid still fits within
    self._max_num_tiles. When both axes are eligible the axis is chosen
    uniformly at random. Returns the tiling unchanged when no move is legal.
    """
    cols, rows = tiling
    max_tiles = self._max_num_tiles

    if random.random() < 0.65:
        # Shrink one axis, keeping each axis at least 1.
        if cols == 1 and rows == 1:
            return tiling
        if cols == 1:
            return (cols, rows - 1)
        if rows == 1:
            return (cols - 1, rows)
        # Both axes eligible: pick one at random.
        if random.random() < 0.5:
            return (cols - 1, rows)
        return (cols, rows - 1)

    # Grow one axis, respecting the max-tile budget.
    if cols * rows < max_tiles:
        wider = (cols + 1, rows)
        taller = (cols, rows + 1)
        wider_fits = wider[0] * wider[1] <= max_tiles
        taller_fits = taller[0] * taller[1] <= max_tiles
        if not wider_fits and not taller_fits:
            return tiling
        if not wider_fits:
            return taller
        if not taller_fits:
            return wider
        # Both grown grids fit: pick one at random.
        if random.random() < 0.5:
            return wider
        return taller
    return tiling
-
- def stack(
- self, images: list[torch.Tensor]
- ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int], list[int]]:
- # Identical to dynamic resolution packing; each tile is already an independent image sample
- imgs_sizes = torch.tensor(
- [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
- )
-
- def rearrange_img(x):
- py = x.shape[-2] // self._patch_size
- px = x.shape[-1] // self._patch_size
- x = einops.rearrange(
- x,
- "c (py yy) (px xx) -> (py px) (c yy xx)",
- py=py,
- yy=self._patch_size,
- px=px,
- xx=self._patch_size,
- )
- return x
-
- if len(images) > 0:
- imgs = [rearrange_img(img) for img in images]
-
- current_length = 0
- max_length = 0
- vision_cu_lengths = [0]
- for img in imgs:
- if max_length < img.shape[0]:
- max_length = img.shape[0]
- current_length += img.shape[0]
- vision_cu_lengths.append(current_length)
-
- vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
- vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
-
- return (
- torch.cat(imgs, dim=0).unsqueeze(0),
- imgs_sizes,
- vision_cu_lengths,
- vision_max_lengths,
- )
- else:
- return (
- torch.tensor([[0]], dtype=torch.float32),
- torch.tensor([[0,0]], dtype=torch.int32),
- None,
- None,
- )
-
def __str__(self):
    """Human-readable summary of this strategy's configuration."""
    fields = (
        ("vision_model_type", self._vision_model_type),
        ("tile_size", self._tile_size),
        ("use_thumbnail", self._use_thumbnail),
        ("min_num_tiles", self._min_num_tiles),
        ("max_num_tiles", self._max_num_tiles),
        ("patch_size", self._patch_size),
        ("pixel_shuffle", self._pixel_shuffle),
        ("conv_merging", self._conv_merging),
        ("enable_tile_degradation", self._enable_tile_degradation),
        ("video_frame_strategy", self._video_frame_strategy),
    )
    body = ", ".join(f"{name}={value}" for name, value in fields)
    return f"MaskedTilingDynamicResolutionStrategy({body})"
-
-def create_image_tiling_strategy(args):
- """
- Create an image tiling strategy based on the provided arguments.
-
- This function encapsulates the logic for creating the appropriate image tiling strategy
- based on the training/evaluation configuration. It can be used by both training (task_encoder)
- and evaluation code outside of data_loading/.
-
- Args:
- args: Arguments object with the following relevant attributes:
- - img_h, img_w: Image height and width
- - patch_dim: Patch dimension
- - vision_model_type: Vision model type (e.g., 'radio', 'clip', 'siglip')
- - disable_vision_class_token: Whether to disable vision class token
- - pixel_shuffle: Whether to use pixel shuffle
- - use_tile_tags: Whether to use tile tags
- - max_num_tiles: Maximum number of tiles
- - tokenizer_prompt_format: Tokenizer prompt format
- - image_break_token: Image break token (optional)
- - conv_merging: Whether to use convolution merging
- - dynamic_resolution: Whether to use dynamic resolution
- - match_tiling_dynamic_resolution: Whether to match tiling with dynamic resolution
- - use_area_weighted_aspect_ratio: Whether to use area-weighted aspect ratio
- - use_thumbnail: Whether to use thumbnail
- - dynamic_resolution_min_patches: Minimum number of patches for dynamic resolution
- - dynamic_resolution_min_side: Minimum side length for dynamic resolution (optional)
- - thumbnail_area_threshold: Thumbnail area threshold (optional)
- - use_tiling: Whether to use tiling
-
- Returns:
- ImageTilingStrategy: The created image tiling strategy
- """
- from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
-
- assert args.img_h == args.img_w, "img_h and img_w must be the same"
-
- match_tiling_dynamic_resolution = args.match_tiling_dynamic_resolution
- masked_tiling_dynamic_resolution = getattr(args, "masked_tiling_dynamic_resolution", False)
- dynamic_resolution = args.dynamic_resolution
- use_tiling = args.use_tiling
- use_area_weighted_aspect_ratio = args.use_area_weighted_aspect_ratio
-
- if match_tiling_dynamic_resolution:
- assert dynamic_resolution, "must enable --dynamic-resolution if using --match-tiling-dynamic-resolution"
- assert not use_tiling, "cannot use --use-tiling and --match-tiling-dynamic-resolution together"
- if masked_tiling_dynamic_resolution:
- assert dynamic_resolution, "must enable --dynamic-resolution if using --masked-tiling-dynamic-resolution"
- assert not use_tiling, "cannot use --use-tiling and --masked-tiling-dynamic-resolution together"
- assert not match_tiling_dynamic_resolution, "cannot combine --masked-tiling-dynamic-resolution with --match-tiling-dynamic-resolution"
-
- if dynamic_resolution:
- if masked_tiling_dynamic_resolution:
- num_image_embeddings_per_tile = get_num_image_embeddings(
- img_h=args.img_h,
- img_w=args.img_w,
- patch_dim=args.patch_dim,
- vision_model_type=args.vision_model_type,
- disable_vision_class_token=args.disable_vision_class_token,
- class_token_len=1,
- pixel_shuffle=args.pixel_shuffle,
- use_tile_tags=args.use_tile_tags,
- max_num_tiles=args.max_num_tiles,
- tokenizer_type=args.tokenizer_prompt_format,
- use_image_break_token=args.image_break_token is not None,
- conv_merging=args.conv_merging,
- )
- image_tiling_strategy = MaskedTilingDynamicResolutionStrategy(
- vision_model_type=args.vision_model_type,
- tile_size=args.img_h,
- use_thumbnail=args.use_thumbnail,
- min_num_tiles=1,
- max_num_tiles=args.max_num_tiles,
- embeddings_per_tile=num_image_embeddings_per_tile,
- patch_size=args.patch_dim,
- get_num_embeddings=lambda width, height: get_num_image_embeddings(
- img_h=height,
- img_w=width,
- patch_dim=args.patch_dim,
- vision_model_type=args.vision_model_type,
- disable_vision_class_token=args.disable_vision_class_token,
- class_token_len=1,
- pixel_shuffle=args.pixel_shuffle,
- use_tile_tags=args.use_tile_tags,
- max_num_tiles=args.max_num_tiles,
- tokenizer_type=args.tokenizer_prompt_format,
- use_image_break_token=args.image_break_token is not None,
- conv_merging=args.conv_merging,
- ),
- find_closest_aspect_ratio_fn=(
- find_closest_area_weighted_aspect_ratio
- if use_area_weighted_aspect_ratio
- else find_closest_aspect_ratio
- ),
- pixel_shuffle=args.pixel_shuffle,
- conv_merging=args.conv_merging,
- )
- elif match_tiling_dynamic_resolution:
- num_image_embeddings_per_tile = get_num_image_embeddings(
- img_h=args.img_h,
- img_w=args.img_w,
- patch_dim=args.patch_dim,
- vision_model_type=args.vision_model_type,
- disable_vision_class_token=args.disable_vision_class_token,
- class_token_len=1,
- pixel_shuffle=args.pixel_shuffle,
- use_tile_tags=args.use_tile_tags,
- max_num_tiles=args.max_num_tiles,
- tokenizer_type=args.tokenizer_prompt_format,
- use_image_break_token=args.image_break_token is not None,
- conv_merging=args.conv_merging,
- )
- image_tiling_strategy = MatchTilingDynamicResolutionStrategy(
- vision_model_type=args.vision_model_type,
- tile_size=args.img_h,
- use_thumbnail=args.use_thumbnail,
- min_num_tiles=1,
- max_num_tiles=args.max_num_tiles,
- embeddings_per_tile=num_image_embeddings_per_tile,
- patch_size=args.patch_dim,
- get_num_embeddings=lambda width, height: get_num_image_embeddings(
- img_h=height,
- img_w=width,
- patch_dim=args.patch_dim,
- vision_model_type=args.vision_model_type,
- disable_vision_class_token=args.disable_vision_class_token,
- class_token_len=1,
- pixel_shuffle=args.pixel_shuffle,
- use_tile_tags=args.use_tile_tags,
- max_num_tiles=args.max_num_tiles,
- tokenizer_type=args.tokenizer_prompt_format,
- use_image_break_token=args.image_break_token is not None,
- conv_merging=args.conv_merging,
- ),
- find_closest_aspect_ratio_fn=(
- find_closest_area_weighted_aspect_ratio
- if use_area_weighted_aspect_ratio
- else find_closest_aspect_ratio
- ),
- pixel_shuffle=args.pixel_shuffle,
- conv_merging=args.conv_merging,
- )
- else:
- image_tiling_strategy = DynamicResolutionImageTilingStrategy(
- vision_model_type=args.vision_model_type,
- min_num_patches=args.dynamic_resolution_min_patches,
- patch_size=args.patch_dim,
- get_num_embeddings=lambda width, height: get_num_image_embeddings(
- img_h=height,
- img_w=width,
- patch_dim=args.patch_dim,
- vision_model_type=args.vision_model_type,
- disable_vision_class_token=args.disable_vision_class_token,
- class_token_len=1,
- pixel_shuffle=args.pixel_shuffle,
- use_tile_tags=args.use_tile_tags,
- max_num_tiles=args.max_num_tiles,
- tokenizer_type=args.tokenizer_prompt_format,
- use_image_break_token=args.image_break_token is not None,
- conv_merging=args.conv_merging,
- ),
- pixel_shuffle=args.pixel_shuffle,
- min_side=args.dynamic_resolution_min_side,
- conv_merging=args.conv_merging,
- use_thumbnail=args.use_thumbnail,
- thumbnail_size=args.img_h,
- thumbnail_area_threshold=args.thumbnail_area_threshold,
- max_num_patches=args.dynamic_resolution_max_patches,
- apply_data_augment=args.apply_data_augment,
- )
- else:
- num_image_embeddings_per_tile = get_num_image_embeddings(
- img_h=args.img_h,
- img_w=args.img_w,
- patch_dim=args.patch_dim,
- vision_model_type=args.vision_model_type,
- disable_vision_class_token=args.disable_vision_class_token,
- class_token_len=1,
- pixel_shuffle=args.pixel_shuffle,
- use_tile_tags=args.use_tile_tags,
- max_num_tiles=args.max_num_tiles,
- tokenizer_type=args.tokenizer_prompt_format,
- use_image_break_token=args.image_break_token is not None,
- conv_merging=args.conv_merging,
- )
- if use_tiling:
- image_strategy = ImageTilingStrategyV1(
- vision_model_type=args.vision_model_type,
- tile_size=args.img_h,
- use_thumbnail=args.use_thumbnail,
- min_num_tiles=1,
- max_num_tiles=args.max_num_tiles,
- embeddings_per_tile=num_image_embeddings_per_tile,
- find_closest_aspect_ratio_fn=(
- find_closest_area_weighted_aspect_ratio
- if use_area_weighted_aspect_ratio
- else find_closest_aspect_ratio
- ),
- )
- else:
- image_strategy = NoTilingStrategy(
- vision_model_type=args.vision_model_type,
- embeddings_per_image=num_image_embeddings_per_tile,
- target_width=args.img_w,
- target_height=args.img_h,
- )
- image_tiling_strategy = TileDegradationStrategy(
- image_strategy=image_strategy,
- video_frame_strategy=NoTilingStrategy(
- vision_model_type=args.vision_model_type,
- embeddings_per_image=num_image_embeddings_per_tile,
- target_width=args.img_w,
- target_height=args.img_h,
- ),
- embeddings_per_tile=num_image_embeddings_per_tile,
- max_num_tiles=args.max_num_tiles,
- )
-
- return image_tiling_strategy
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 82423891c11c7..3b8a3841cf938 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -113,7 +113,7 @@ pixel_statistics = {
@dataclass
class DynamicResolutionParams:
- image: Image.Image
+ media: Image.Image
num_tiles: int
num_embeddings: int
patch_size: tuple[int, int]
@@ -165,7 +165,7 @@ class DynamicResolutionImageTilingStrategy:
self, params: DynamicResolutionParams, **kwargs
) -> list[torch.Tensor]:
# resize the image
- resized_img = params.image.resize(
+ resized_img = params.media.resize(
(
params.patch_size[0] * self._patch_size,
params.patch_size[1] * self._patch_size,
@@ -183,7 +183,7 @@ class DynamicResolutionImageTilingStrategy:
# Only add thumbnail if resized image area is less than threshold % of
# thumbnail area
if area_ratio < self._thumbnail_area_threshold:
- thumbnail_img = params.image.resize(
+ thumbnail_img = params.media.resize(
(self._thumbnail_size, self._thumbnail_size)
)
processed_images.append(thumbnail_img)
@@ -192,7 +192,7 @@ class DynamicResolutionImageTilingStrategy:
def process_media(
self,
- image: Image.Image,
+ media: Image.Image,
num_tokens_available: int,
data_augment: bool = False,
tiling_augment_prob: float = 0.4,
@@ -207,10 +207,10 @@ class DynamicResolutionImageTilingStrategy:
DynamicResolutionParams for the media
"""
current_num_tokens_available = num_tokens_available
- assert isinstance(image, Image.Image), (
+ assert isinstance(media, Image.Image), (
"Dynamic resolution is only supported for image media"
)
- orig_width, orig_height = image.width, image.height
+ orig_width, orig_height = media.width, media.height
closest_patch_height = round(orig_height / self._patch_size + 0.5)
closest_patch_width = round(orig_width / self._patch_size + 0.5)
@@ -336,7 +336,7 @@ class DynamicResolutionImageTilingStrategy:
target_patch_width, target_patch_height, current_num_tokens_available
)
- assert isinstance(image, Image.Image), (
+ assert isinstance(media, Image.Image), (
"Dynamic resolution is only supported for image media"
)
@@ -374,7 +374,7 @@ class DynamicResolutionImageTilingStrategy:
)
return DynamicResolutionParams(
- image=image,
+ media=media,
num_tiles=num_tiles,
num_embeddings=num_embeddings,
patch_size=(target_patch_width, target_patch_height),
From 3be49316c91fcd6a7f9579d21a2fb7307e9c4e1e Mon Sep 17 00:00:00 2001
From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date: Mon, 15 Dec 2025 17:37:46 +0200
Subject: [PATCH 06/10] refactor
---
.../model_executor/models/nano_nemotron_vl.py | 1060 ++++++++---------
1 file changed, 511 insertions(+), 549 deletions(-)
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 3b8a3841cf938..3bac36744ef5e 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -12,7 +12,7 @@ import math
import random
import warnings
from abc import ABC, abstractmethod
-from collections.abc import Callable, Iterable, Mapping, Sequence
+from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
@@ -88,501 +88,9 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
-IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
-IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
-SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
-SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
-CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
-CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
-RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
-RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]
-
-pixel_statistics = {
- "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
- "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
- "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
- "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
- "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
-}
-
-
-@dataclass
-class DynamicResolutionParams:
- media: Image.Image
- num_tiles: int
- num_embeddings: int
- patch_size: tuple[int, int]
-
-
-class DynamicResolutionImageTilingStrategy:
- def __init__(
- self,
- vision_model_type: str,
- min_num_patches: int,
- patch_size: int,
- get_num_embeddings: Callable[[int, int], int],
- factor_max: float = 1.0,
- pixel_shuffle: bool = False,
- min_side: int | None = None,
- conv_merging: bool = False,
- use_thumbnail: bool = False,
- thumbnail_size: int = 448,
- thumbnail_area_threshold: float = 0.8,
- max_num_patches: int = 0,
- apply_data_augment: bool = False,
- ):
- assert "radio" in vision_model_type, (
- "Dynamic resolution is only supported for radio models"
- )
- self._vision_model_type = vision_model_type
- self._min_num_patches = min_num_patches
- self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
- self._patch_size = patch_size
- self._get_num_embeddings = get_num_embeddings
- self._factor_max = factor_max
- self._pixel_shuffle = pixel_shuffle
- self._min_side = min_side
- self._conv_merging = conv_merging
- self._use_thumbnail = use_thumbnail
- self._thumbnail_size = thumbnail_size
- self._thumbnail_area_threshold = thumbnail_area_threshold
- pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
- self._transform = T.Compose(
- [
- T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
- T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)),
- T.Normalize(mean=pixel_mean, std=pixel_std),
- ]
- )
- self._apply_data_augment = apply_data_augment
-
- def apply_params(
- self, params: DynamicResolutionParams, **kwargs
- ) -> list[torch.Tensor]:
- # resize the image
- resized_img = params.media.resize(
- (
- params.patch_size[0] * self._patch_size,
- params.patch_size[1] * self._patch_size,
- )
- )
- processed_images = [resized_img]
-
- # Add thumbnail if enabled and image area is below threshold
- if self._use_thumbnail:
- # Calculate areas
- resized_area = resized_img.size[0] * resized_img.size[1]
- thumbnail_area = self._thumbnail_size * self._thumbnail_size
- area_ratio = resized_area / thumbnail_area
-
- # Only add thumbnail if resized image area is less than threshold % of
- # thumbnail area
- if area_ratio < self._thumbnail_area_threshold:
- thumbnail_img = params.media.resize(
- (self._thumbnail_size, self._thumbnail_size)
- )
- processed_images.append(thumbnail_img)
-
- return [self._transform(img) for img in processed_images]
-
- def process_media(
- self,
- media: Image.Image,
- num_tokens_available: int,
- data_augment: bool = False,
- tiling_augment_prob: float = 0.4,
- ) -> DynamicResolutionParams:
- """Process a single media item and return its parameters.
- Args:
- media: The media item to process
- num_tokens_available: Number of tokens available for this media
- data_augment: Whether to apply data augmentation to the image. Defaults to
- False.
- Returns:
- DynamicResolutionParams for the media
- """
- current_num_tokens_available = num_tokens_available
- assert isinstance(media, Image.Image), (
- "Dynamic resolution is only supported for image media"
- )
- orig_width, orig_height = media.width, media.height
-
- closest_patch_height = round(orig_height / self._patch_size + 0.5)
- closest_patch_width = round(orig_width / self._patch_size + 0.5)
- patches = closest_patch_height * closest_patch_width
-
- factor = min(
- math.sqrt(current_num_tokens_available / patches), self._factor_max
- )
- target_patch_height = math.floor(factor * closest_patch_height)
- target_patch_width = math.floor(factor * closest_patch_width)
-
- # We only consider self._min_num_patches if it is greater than
- # current_num_tokens_available.
- if (
- current_num_tokens_available > self._min_num_patches
- and target_patch_height * target_patch_width < self._min_num_patches
- ):
- up_factor = math.sqrt(
- self._min_num_patches / (target_patch_height * target_patch_width)
- )
- target_patch_height = math.ceil(up_factor * target_patch_height)
- target_patch_width = math.ceil(up_factor * target_patch_width)
-
- if (
- self._min_side is not None
- and min(target_patch_width, target_patch_height) * self._patch_size
- < self._min_side
- ):
- if target_patch_width <= target_patch_height:
- up_factor = self._min_side / (target_patch_width * self._patch_size)
- new_patch_height = math.ceil(up_factor * target_patch_height)
- new_patch_width = math.ceil(up_factor * target_patch_width)
-
- if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at
- # native aspect ratio while staying below max_patches
- if (
- max(current_num_tokens_available // new_patch_width, 1)
- * self._patch_size
- < self._min_side
- ):
- up_factor = math.sqrt(
- current_num_tokens_available
- / (target_patch_height * target_patch_width)
- )
- target_patch_height = math.floor(
- up_factor * target_patch_height
- )
- target_patch_width = math.floor(up_factor * target_patch_width)
- target_patch_width = new_patch_width
- target_patch_height = max(
- current_num_tokens_available // new_patch_width, 1
- )
- else:
- target_patch_height = new_patch_height
- target_patch_width = new_patch_width
- else:
- up_factor = self._min_side / (target_patch_height * self._patch_size)
- new_patch_height = math.ceil(up_factor * target_patch_height)
- new_patch_width = math.ceil(up_factor * target_patch_width)
-
- if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at
- # native aspect ratio while staying below max_patches
- if (
- max(current_num_tokens_available // new_patch_height, 1)
- * self._patch_size
- < self._min_side
- ):
- up_factor = math.sqrt(
- current_num_tokens_available
- / (target_patch_height * target_patch_width)
- )
- target_patch_height = math.floor(
- up_factor * target_patch_height
- )
- target_patch_width = math.floor(up_factor * target_patch_width)
- else:
- target_patch_height = new_patch_height
- target_patch_width = max(
- current_num_tokens_available // new_patch_height, 1
- )
- else:
- target_patch_height = new_patch_height
- target_patch_width = new_patch_width
-
- # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
- # or by 4 when BOTH are enabled (two successive 2x reductions)
- if self._pixel_shuffle or self._conv_merging:
- required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
-
- rem_h = target_patch_height % required_divisor
- if rem_h != 0:
- inc_h = required_divisor - rem_h
- if (
- target_patch_height + inc_h
- ) * target_patch_width <= current_num_tokens_available:
- target_patch_height += inc_h
- else:
- target_patch_height = max(
- required_divisor, target_patch_height - rem_h
- )
-
- rem_w = target_patch_width % required_divisor
- if rem_w != 0:
- inc_w = required_divisor - rem_w
- if (
- target_patch_height * (target_patch_width + inc_w)
- <= current_num_tokens_available
- ):
- target_patch_width += inc_w
- else:
- target_patch_width = max(
- required_divisor, target_patch_width - rem_w
- )
-
- if (
- data_augment
- and self._apply_data_augment
- and random.random() < tiling_augment_prob
- ):
- target_patch_width, target_patch_height = self.augment_resolution(
- target_patch_width, target_patch_height, current_num_tokens_available
- )
-
- assert isinstance(media, Image.Image), (
- "Dynamic resolution is only supported for image media"
- )
-
- # Calculate embeddings for the main dynamic resolution image
- num_embeddings = self._get_num_embeddings(
- target_patch_width * self._patch_size,
- target_patch_height * self._patch_size,
- )
-
- token_count = target_patch_width * target_patch_height
-
- # Add thumbnail embeddings if enabled and image area is below threshold
- num_tiles = 1 # Base dynamic resolution image
- if self._use_thumbnail:
- # Calculate areas
- resized_area = (target_patch_width * self._patch_size) * (
- target_patch_height * self._patch_size
- )
- thumbnail_area = self._thumbnail_size * self._thumbnail_size
- area_ratio = resized_area / thumbnail_area
-
- # Only add thumbnail if resized image area is less than threshold % of
- # thumbnail area
- if area_ratio < self._thumbnail_area_threshold:
- num_tiles += 1 # Add 1 for thumbnail
- # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
- num_embeddings += self._get_num_embeddings(
- self._thumbnail_size, self._thumbnail_size
- )
- token_count += (
- self._thumbnail_size
- // self._patch_size
- * self._thumbnail_size
- // self._patch_size
- )
-
- return DynamicResolutionParams(
- media=media,
- num_tiles=num_tiles,
- num_embeddings=num_embeddings,
- patch_size=(target_patch_width, target_patch_height),
- ), token_count
-
- def augment_resolution(
- self,
- target_patch_width: int,
- target_patch_height: int,
- current_num_tokens_available: int,
- ) -> tuple[int, int]:
- min_num_patch_one_side = 32
-
- if random.random() < 0.5:
- # Minus one
- if (
- target_patch_width <= min_num_patch_one_side
- and target_patch_height <= min_num_patch_one_side
- ):
- return target_patch_width, target_patch_height
- elif target_patch_width <= min_num_patch_one_side:
- return target_patch_width, target_patch_height - min_num_patch_one_side
- elif target_patch_height <= min_num_patch_one_side:
- return target_patch_width - min_num_patch_one_side, target_patch_height
- else:
- if random.random() < 0.5:
- return (
- target_patch_width - min_num_patch_one_side,
- target_patch_height,
- )
- else:
- return (
- target_patch_width,
- target_patch_height - min_num_patch_one_side,
- )
- else:
- # Plus one
- if target_patch_width * target_patch_height < current_num_tokens_available:
- if random.random() < 0.5:
- return (
- target_patch_width + min_num_patch_one_side,
- target_patch_height,
- )
- else:
- return (
- target_patch_width,
- target_patch_height + min_num_patch_one_side,
- )
- return target_patch_width, target_patch_height
-
- def compute_params(
- self,
- media_list: list[Image.Image],
- num_tokens_available: int | None = None,
- max_num_tiles: int | None = None,
- data_augment: bool = False,
- **kwargs,
- ) -> list[DynamicResolutionParams]:
- """Compute parameters for all media with iterative token budgeting.
-
- Args:
- media_list: List of media items to process
- num_tokens_available: Total number of tokens available across all media
- max_num_tiles: Maximum number of tiles (unused in this implementation)
- data_augment: Whether to apply data augmentation to the image. Defaults to
- False.
- Returns:
- List of ImageTilingParams for each media item
- """
- num_tokens_available = (
- num_tokens_available
- * (4 if self._pixel_shuffle else 1)
- * (4 if self._conv_merging else 1)
- )
- # When the number of available token is too small, allow self._min_num_patches
- # per media and let the sample be truncated.
- num_tokens_available = max(
- num_tokens_available, self._min_num_patches * len(media_list)
- )
-
- # Clip the number of tokens available per media to be between min and max
- # patches.
- num_tokens_available_per_media = [
- max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
- for _ in range(len(media_list))
- ]
-
- # In theory this could be a while True loop, but in case the process_media
- # method slightly
- # changes, I want to make sure we don't get stuck in an infinite loop.
- for _ in range(10):
- # Step 1: Process each media with current token budget
- params = []
- token_counts = []
-
- for media, tokens_for_media in zip(
- media_list, num_tokens_available_per_media
- ):
- param, token_count = self.process_media(
- media, tokens_for_media, data_augment=data_augment
- )
- params.append(param)
- token_counts.append(token_count)
-
- # Step 2: Check if total tokens is within budget
- total_tokens = sum(token_counts)
-
- if total_tokens <= num_tokens_available:
- # We're within budget, return the params
- return params
-
- # Step 3: We're over budget, need to scale down
- # Calculate scaling factor to get under budget
- scaling_factor = num_tokens_available / total_tokens
-
- # Recalculate token budgets for each media based on scaling
- # Each media gets a proportional share of the total budget
- scaled_down_num_tokens_available_per_media = [
- max(self._min_num_patches, int(token_count * scaling_factor))
- for token_count in token_counts
- ]
- scaled_down = any(
- [
- scaled_down_num_tokens_available_per_media[i]
- < num_tokens_available_per_media[i]
- for i in range(len(num_tokens_available_per_media))
- ]
- )
- # If there was not scaling down, we're stuck just use min_num_patches per
- # media, else try with the scaled down num_tokens_available_per_media.
- if not scaled_down:
- num_tokens_available_per_media = [self._min_num_patches] * len(
- media_list
- )
- else:
- num_tokens_available_per_media = (
- scaled_down_num_tokens_available_per_media
- )
- return params
-
- def stack(
- self, images: list[torch.Tensor]
- ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
- imgs_sizes = torch.tensor(
- [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
- )
-
- def rearrange_img(x):
- py = x.shape[-2] // self._patch_size
- px = x.shape[-1] // self._patch_size
- x = einops.rearrange(
- x,
- "c (py yy) (px xx) -> (py px) (c yy xx)",
- py=py,
- yy=self._patch_size,
- px=px,
- xx=self._patch_size,
- )
- return x
-
- if len(images) > 0:
- imgs = [rearrange_img(img) for img in images]
-
- current_length = 0
- max_length = 0
- vision_cu_lengths = [0]
- for img in imgs:
- if max_length < img.shape[0]:
- max_length = img.shape[0]
- current_length += img.shape[0]
- vision_cu_lengths.append(current_length)
-
- vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
- vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
-
- return (
- torch.cat(imgs, dim=0).unsqueeze(0),
- imgs_sizes,
- vision_cu_lengths,
- vision_max_lengths,
- )
- else:
- return (
- torch.tensor([[0]], dtype=torch.float32),
- torch.tensor([[0, 0]], dtype=torch.int32),
- None,
- None,
- )
-
- def __str__(self):
- return f"DynamicResolutionImageTransform(\
- vision_model_type={self._vision_model_type}, \
- min_num_patches={self._min_num_patches}, \
- patch_size={self._patch_size}, \
- pixel_shuffle={self._pixel_shuffle}, \
- conv_merging={self._conv_merging}, \
- use_thumbnail={self._use_thumbnail}, \
- thumbnail_size={self._thumbnail_size}, \
- thumbnail_area_threshold={self._thumbnail_area_threshold})"
-
-
-image_tiling_strategy = DynamicResolutionImageTilingStrategy(
- vision_model_type="radio",
- min_num_patches=4,
- patch_size=16,
- get_num_embeddings=lambda x, y: x * y * 2,
- max_num_patches=64,
-)
+# TODO(nhaber): get 2048 from config
+# TODO(nhaber): does use_thumbnail=True work?
IMG_START = "
"
@@ -753,7 +261,12 @@ def video_to_pixel_values(
return torch.stack(frames_tensors)
-def input_conditioner(x, norm_mean, norm_std):
+def input_conditioner(
+ x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.Tensor
+) -> torch.Tensor:
+ assert isinstance(x, torch.Tensor), "x must be a tensor"
+ assert isinstance(norm_mean, torch.Tensor), "norm_mean must be a tensor"
+ assert isinstance(norm_std, torch.Tensor), "norm_std must be a tensor"
return (x - norm_mean) / norm_std
@@ -792,15 +305,20 @@ class BaseNanoNemotronVLProcessor(ABC):
self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES
image_size: int = config.force_image_size
- patch_size: int = config.patch_size
+ self.patch_size: int = getattr(config, "patch_size", 16)
+ self.downsample_ratio: float = self.config.downsample_ratio
- self.num_image_token = int(
- (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
- )
self.image_size = image_size
self.use_thumbnail: bool = config.use_thumbnail
- self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
- self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1)
+ self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1)
+ self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1)
+
+ def num_image_token(self, *, image_width: int, image_height: int) -> int:
+ image_size = math.sqrt(image_width * image_height)
+ num_tokens = int(
+ (image_size // self.patch_size) ** 2 * (self.downsample_ratio**2)
+ )
+ return num_tokens
@property
@abstractmethod
@@ -832,10 +350,13 @@ class BaseNanoNemotronVLProcessor(ABC):
use_thumbnail=self.use_thumbnail,
)
- return num_patches * self.num_image_token
+ return num_patches * self.num_image_token(
+ image_width=image_width, image_height=image_height
+ )
def _images_to_pixel_values_lst(
self,
+ text: list[str],
images: list[Image.Image],
max_num_tiles: int,
) -> list[torch.Tensor]:
@@ -859,7 +380,9 @@ class BaseNanoNemotronVLProcessor(ABC):
if len(images) == 0:
image_inputs = {}
else:
- pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
+ pixel_values_lst = self._images_to_pixel_values_lst(
+ text=text, images=images, max_num_tiles=max_num_tiles
+ )
image_inputs = {
"pixel_values_flat": input_conditioner(
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
@@ -881,7 +404,10 @@ class BaseNanoNemotronVLProcessor(ABC):
for i, pixel_values in enumerate(pixel_values_lst):
num_patches = pixel_values.shape[0]
- feature_size = num_patches * self.num_image_token
+ feature_size = num_patches * self.num_image_token(
+ image_width=pixel_values.shape[1],
+ image_height=pixel_values.shape[2],
+ )
image_repl = self.get_image_repl(feature_size, num_patches)
parts[i] = parts[i].replace("", image_repl.full)
text = ["".join(parts)]
@@ -894,6 +420,7 @@ class BaseNanoNemotronVLProcessor(ABC):
input_item = [input_item]
return input_item
+ @abstractmethod
def __call__(
self,
text: str | list[str] | None = None,
@@ -901,26 +428,487 @@ class BaseNanoNemotronVLProcessor(ABC):
return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None,
) -> BatchFeature:
- # Use default if not provided
- if max_num_tiles is None:
- max_num_tiles = self.max_num_tiles
+ raise NotImplementedError
- text, images = [self._make_batch_input(x) for x in (text, images)]
- text, image_inputs = self._preprocess_image(
- text=text,
- images=images,
- max_num_tiles=max_num_tiles,
+@dataclass
+class DynamicResolutionParams:
+ media: Image.Image
+ num_tiles: int
+ num_embeddings: int
+ patch_size: tuple[int, int]
+
+
+class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
+ CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
+ CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ tokenizer: TokenizerLike,
+ *args,
+ max_num_tiles: int | None = None,
+ min_num_patches: int = 4,
+ factor_max: float = 1.0,
+ pixel_shuffle: bool = True,
+ min_side: int | None = None,
+ conv_merging: bool = False,
+ use_thumbnail: bool = False,
+ thumbnail_size: int = 448,
+ thumbnail_area_threshold: float = 0.8,
+ apply_data_augment: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs
+ )
+ self._min_num_patches = min_num_patches
+ self._factor_max = factor_max
+ self._pixel_shuffle = pixel_shuffle
+ self._min_side = min_side
+ self._conv_merging = conv_merging
+ self._use_thumbnail = use_thumbnail
+ self._thumbnail_size = thumbnail_size
+ self._thumbnail_area_threshold = thumbnail_area_threshold
+ self._transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)),
+ ]
+ )
+ self._apply_data_augment = apply_data_augment
+
+ self.norm_mean = torch.tensor(self.CLIP_PIXEL_MEAN).reshape(1, 3, 1, 1)
+ self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1)
+ self.downsample_ratio = 2 if pixel_shuffle else 1
+
+ def apply_params(self, params: DynamicResolutionParams) -> torch.Tensor:
+ resized_img = params.media.resize(
+ (
+ params.patch_size[0] * self.patch_size,
+ params.patch_size[1] * self.patch_size,
+ )
+ )
+ # processed_images = [resized_img]
+
+ # # Add thumbnail if enabled and image area is below threshold
+ # if self._use_thumbnail:
+ # # Calculate areas
+ # resized_area = resized_img.size[0] * resized_img.size[1]
+ # thumbnail_area = self._thumbnail_size * self._thumbnail_size
+ # area_ratio = resized_area / thumbnail_area
+
+ # # Only add thumbnail if resized image area is less than threshold % of
+ # # thumbnail area
+ # if area_ratio < self._thumbnail_area_threshold:
+ # thumbnail_img = params.media.resize(
+ # (self._thumbnail_size, self._thumbnail_size)
+ # )
+ # processed_images.append(thumbnail_img)
+
+ return self._transform(resized_img)
+
+ def process_media(
+ self,
+ media: Image.Image,
+ num_tokens_available: int,
+ data_augment: bool = False,
+ tiling_augment_prob: float = 0.4,
+ ) -> DynamicResolutionParams:
+ """Process a single media item and return its parameters.
+ Args:
+ media: The media item to process
+ num_tokens_available: Number of tokens available for this media
+ data_augment: Whether to apply data augmentation to the image. Defaults to
+ False.
+        Returns:
+            Tuple of (DynamicResolutionParams for the media, its patch token count)
+ """
+ current_num_tokens_available = num_tokens_available
+ assert isinstance(media, Image.Image), (
+ "Dynamic resolution is only supported for image media"
+ )
+ orig_width, orig_height = media.width, media.height
+
+ closest_patch_height = round(orig_height / self.patch_size + 0.5)
+ closest_patch_width = round(orig_width / self.patch_size + 0.5)
+ patches = closest_patch_height * closest_patch_width
+
+ factor = min(
+ math.sqrt(current_num_tokens_available / patches), self._factor_max
+ )
+ target_patch_height = math.floor(factor * closest_patch_height)
+ target_patch_width = math.floor(factor * closest_patch_width)
+
+ # We only consider self._min_num_patches if it is greater than
+ # current_num_tokens_available.
+ if (
+ current_num_tokens_available > self._min_num_patches
+ and target_patch_height * target_patch_width < self._min_num_patches
+ ):
+ up_factor = math.sqrt(
+ self._min_num_patches / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.ceil(up_factor * target_patch_height)
+ target_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if (
+ self._min_side is not None
+ and min(target_patch_width, target_patch_height) * self.patch_size
+ < self._min_side
+ ):
+ if target_patch_width <= target_patch_height:
+ up_factor = self._min_side / (target_patch_width * self.patch_size)
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if new_patch_height * new_patch_width > current_num_tokens_available:
+ # If only one side can be min_side, make as big as possible at
+ # native aspect ratio while staying below max_patches
+ if (
+ max(current_num_tokens_available // new_patch_width, 1)
+ * self.patch_size
+ < self._min_side
+ ):
+ up_factor = math.sqrt(
+ current_num_tokens_available
+ / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.floor(
+ up_factor * target_patch_height
+ )
+ target_patch_width = math.floor(up_factor * target_patch_width)
+ target_patch_width = new_patch_width
+ target_patch_height = max(
+ current_num_tokens_available // new_patch_width, 1
+ )
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = new_patch_width
+ else:
+ up_factor = self._min_side / (target_patch_height * self.patch_size)
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if new_patch_height * new_patch_width > current_num_tokens_available:
+ # If only one side can be min_side, make as big as possible at
+ # native aspect ratio while staying below max_patches
+ if (
+ max(current_num_tokens_available // new_patch_height, 1)
+ * self.patch_size
+ < self._min_side
+ ):
+ up_factor = math.sqrt(
+ current_num_tokens_available
+ / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.floor(
+ up_factor * target_patch_height
+ )
+ target_patch_width = math.floor(up_factor * target_patch_width)
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = max(
+ current_num_tokens_available // new_patch_height, 1
+ )
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = new_patch_width
+
+ # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
+ # or by 4 when BOTH are enabled (two successive 2x reductions)
+ if self._pixel_shuffle or self._conv_merging:
+ required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
+
+ rem_h = target_patch_height % required_divisor
+ if rem_h != 0:
+ inc_h = required_divisor - rem_h
+ if (
+ target_patch_height + inc_h
+ ) * target_patch_width <= current_num_tokens_available:
+ target_patch_height += inc_h
+ else:
+ target_patch_height = max(
+ required_divisor, target_patch_height - rem_h
+ )
+
+ rem_w = target_patch_width % required_divisor
+ if rem_w != 0:
+ inc_w = required_divisor - rem_w
+ if (
+ target_patch_height * (target_patch_width + inc_w)
+ <= current_num_tokens_available
+ ):
+ target_patch_width += inc_w
+ else:
+ target_patch_width = max(
+ required_divisor, target_patch_width - rem_w
+ )
+
+ if (
+ data_augment
+ and self._apply_data_augment
+ and random.random() < tiling_augment_prob
+ ):
+ target_patch_width, target_patch_height = self.augment_resolution(
+ target_patch_width, target_patch_height, current_num_tokens_available
+ )
+
+ assert isinstance(media, Image.Image), (
+ "Dynamic resolution is only supported for image media"
)
- text_inputs = self.tokenizer(text, add_special_tokens=False)
+ # Calculate embeddings for the main dynamic resolution image
+ num_embeddings = self.num_image_token(
+ image_width=target_patch_width, image_height=target_patch_height
+ )
- combined_outputs = {**text_inputs, **image_inputs}
+ token_count = target_patch_width * target_patch_height
- return BatchFeature(combined_outputs, tensor_type=return_tensors)
+ # Add thumbnail embeddings if enabled and image area is below threshold
+ num_tiles = 1 # Base dynamic resolution image
+ if self._use_thumbnail:
+ # Calculate areas
+ resized_area = (target_patch_width * self.patch_size) * (
+ target_patch_height * self.patch_size
+ )
+ thumbnail_area = self._thumbnail_size * self._thumbnail_size
+ area_ratio = resized_area / thumbnail_area
+
+ # Only add thumbnail if resized image area is less than threshold % of
+ # thumbnail area
+ if area_ratio < self._thumbnail_area_threshold:
+ num_tiles += 1 # Add 1 for thumbnail
+ # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
+ num_embeddings += self.num_image_token(
+ image_width=self._thumbnail_size, image_height=self._thumbnail_size
+ )
+ token_count += (
+ self._thumbnail_size
+ // self.patch_size
+ * self._thumbnail_size
+ // self.patch_size
+ )
+
+ return DynamicResolutionParams(
+ media=media,
+ num_tiles=num_tiles,
+ num_embeddings=num_embeddings,
+ patch_size=(target_patch_width, target_patch_height),
+ ), token_count
+
+ def augment_resolution(
+ self,
+ target_patch_width: int,
+ target_patch_height: int,
+ current_num_tokens_available: int,
+ ) -> tuple[int, int]:
+ min_num_patch_one_side = 32
+
+ if random.random() < 0.5:
+ # Minus one
+ if (
+ target_patch_width <= min_num_patch_one_side
+ and target_patch_height <= min_num_patch_one_side
+ ):
+ return target_patch_width, target_patch_height
+ elif target_patch_width <= min_num_patch_one_side:
+ return target_patch_width, target_patch_height - min_num_patch_one_side
+ elif target_patch_height <= min_num_patch_one_side:
+ return target_patch_width - min_num_patch_one_side, target_patch_height
+ else:
+ if random.random() < 0.5:
+ return (
+ target_patch_width - min_num_patch_one_side,
+ target_patch_height,
+ )
+ else:
+ return (
+ target_patch_width,
+ target_patch_height - min_num_patch_one_side,
+ )
+ else:
+ # Plus one
+ if target_patch_width * target_patch_height < current_num_tokens_available:
+ if random.random() < 0.5:
+ return (
+ target_patch_width + min_num_patch_one_side,
+ target_patch_height,
+ )
+ else:
+ return (
+ target_patch_width,
+ target_patch_height + min_num_patch_one_side,
+ )
+ return target_patch_width, target_patch_height
+
+ def compute_params(
+ self,
+ media_list: list[Image.Image],
+ num_tokens_available: int | None = None,
+ data_augment: bool = False,
+ ) -> list[DynamicResolutionParams]:
+ """Compute parameters for all media with iterative token budgeting.
+
+ Args:
+ media_list: List of media items to process
+ num_tokens_available: Total number of tokens available across all media
+ data_augment: Whether to apply data augmentation to the image. Defaults to
+ False.
+        Returns:
+            List of DynamicResolutionParams for each media item
+ """
+ num_tokens_available = (
+ num_tokens_available
+ * (4 if self._pixel_shuffle else 1)
+ * (4 if self._conv_merging else 1)
+ )
+        # When the number of available tokens is too small, allow self._min_num_patches
+        # per media and let the sample be truncated.
+ num_tokens_available = max(
+ num_tokens_available, self._min_num_patches * len(media_list)
+ )
+
+ # Clip the number of tokens available per media to be between min and max
+ # patches.
+ num_tokens_available_per_media = [
+ max(num_tokens_available, self._min_num_patches)
+ for _ in range(len(media_list))
+ ]
+
+        # In theory this could be a while True loop, but in case the
+        # process_media method slightly changes, I want to make sure
+        # we don't get stuck in an infinite loop.
+ for _ in range(10):
+ # Step 1: Process each media with current token budget
+ params = []
+ token_counts = []
+
+ for media, tokens_for_media in zip(
+ media_list, num_tokens_available_per_media
+ ):
+ param, token_count = self.process_media(
+ media, tokens_for_media, data_augment=data_augment
+ )
+ params.append(param)
+ token_counts.append(token_count)
+
+ # Step 2: Check if total tokens is within budget
+ total_tokens = sum(token_counts)
+
+ if total_tokens <= num_tokens_available:
+ # We're within budget, return the params
+ return params
+
+ # Step 3: We're over budget, need to scale down
+ # Calculate scaling factor to get under budget
+ scaling_factor = num_tokens_available / total_tokens
+
+ # Recalculate token budgets for each media based on scaling
+ # Each media gets a proportional share of the total budget
+ scaled_down_num_tokens_available_per_media = [
+ max(self._min_num_patches, int(token_count * scaling_factor))
+ for token_count in token_counts
+ ]
+ scaled_down = any(
+ [
+ scaled_down_num_tokens_available_per_media[i]
+ < num_tokens_available_per_media[i]
+ for i in range(len(num_tokens_available_per_media))
+ ]
+ )
+            # If there was no scaling down, we're stuck; just use min_num_patches per
+            # media, else try with the scaled down num_tokens_available_per_media.
+ if not scaled_down:
+ num_tokens_available_per_media = [self._min_num_patches] * len(
+ media_list
+ )
+ else:
+ num_tokens_available_per_media = (
+ scaled_down_num_tokens_available_per_media
+ )
+ return params
+
+ def stack(
+ self, images: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
+ imgs_sizes = torch.tensor(
+ [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
+ )
+
+ def rearrange_img(x):
+ py = x.shape[-2] // self.patch_size
+ px = x.shape[-1] // self.patch_size
+ x = einops.rearrange(
+ x,
+ "c (py yy) (px xx) -> (py px) (c yy xx)",
+ py=py,
+ yy=self.patch_size,
+ px=px,
+ xx=self.patch_size,
+ )
+ return x
+
+ if len(images) > 0:
+ imgs = [rearrange_img(img) for img in images]
+
+ current_length = 0
+ max_length = 0
+ vision_cu_lengths = [0]
+ for img in imgs:
+ if max_length < img.shape[0]:
+ max_length = img.shape[0]
+ current_length += img.shape[0]
+ vision_cu_lengths.append(current_length)
+
+ vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
+ vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
+
+ return (
+ torch.cat(imgs, dim=0).unsqueeze(0),
+ imgs_sizes,
+ vision_cu_lengths,
+ vision_max_lengths,
+ )
+ else:
+ return (
+ torch.tensor([[0]], dtype=torch.float32),
+ torch.tensor([[0, 0]], dtype=torch.int32),
+ None,
+ None,
+ )
+
+ def _images_to_pixel_values_lst(
+ self,
+ text: list[str],
+ images: list[Image.Image],
+ max_num_tiles: int,
+ ) -> list[torch.Tensor]:
+ num_tokens_available = 2048 - len(text) - 4
+ params_per_image = self.compute_params(
+ images, num_tokens_available=num_tokens_available
+ )
+ images = []
+ for param in params_per_image:
+ t = self.apply_params(param)
+ if t.ndim == 3:
+ t = t.unsqueeze(0)
+ images.append(t)
+ return images
+
+ def __str__(self):
+ return f"DynamicResolutionImageTransform(\
+ min_num_patches={self._min_num_patches}, \
+ patch_size={self.patch_size}, \
+ pixel_shuffle={self._pixel_shuffle}, \
+ conv_merging={self._conv_merging}, \
+ use_thumbnail={self._use_thumbnail}, \
+ thumbnail_size={self._thumbnail_size}, \
+ thumbnail_area_threshold={self._thumbnail_area_threshold})"
-class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
+class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
"""
HF Processor with extended video processing logic.
Code for video processing is adapted from video example:
@@ -1312,7 +1300,9 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
processor = self.get_hf_processor() # we get the CustomProcessor here
max_image_tokens = self.get_max_image_tokens() * max_images
- max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
+ max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token(
+ image_width=256, image_height=256
+ ) # TODO(nhaber): get 256 dynamically
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
@@ -1457,7 +1447,9 @@ class NanoNemotronVLMultiModalProcessor(
video_num_patches = []
def get_video_replacement_internvl(item_idx: int):
- feature_size = hf_processor.num_image_token
+ feature_size = hf_processor.num_image_token(
+ image_width=256, image_height=256
+ ) # TODO(nhaber): get 256 dynamically
video, metadata = mm_items["video"][item_idx]
num_patches = video_num_patches[item_idx]
if num_patches is not None:
@@ -1633,9 +1625,6 @@ class NemotronH_Nano_VL_V2(
patch_size = config.patch_size
self.patch_size = patch_size
self.template = config.template
- self.num_image_token = int(
- (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
- )
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.image_tag_type = config.image_tag_type
@@ -2153,33 +2142,6 @@ class NemotronH_Nano_VL_V2(
if save_to_file and sys.stdout != original_stdout:
sys.stdout = original_stdout
- def get_model_info(self):
- """
- Get basic model information as a dictionary.
- """
- total_params = sum(p.numel() for p in self.parameters())
-
- component_info = {}
- for name, param in self.named_parameters():
- component = name.split(".")[0]
- if component not in component_info:
- component_info[component] = {"params": 0, "size": 0}
- component_info[component]["params"] += 1
- component_info[component]["size"] += param.numel()
-
- return {
- "model_name": "NemotronH_Nano_VL_V2",
- "total_parameters": total_params,
- "memory_estimate_mb": total_params * 2 / (1024**2), # bfloat16
- "components": component_info,
- "config": {
- "image_size": getattr(self.config, "force_image_size", None),
- "patch_size": getattr(self.config, "patch_size", None),
- "num_image_token": self.num_image_token,
- "downsample_ratio": self.downsample_ratio,
- },
- }
-
def get_vit_model_from_radio_config(self, hf_config):
hf_config_vision = hf_config.vision_config
model_name = hf_config_vision.args.get("model")
From 52e5e55a19436300877a36c518ac390088973c01 Mon Sep 17 00:00:00 2001
From: Netanel Haber
Date: Mon, 22 Dec 2025 03:41:37 -0800
Subject: [PATCH 07/10] it runs at least :shrug:
---
.../model_executor/models/nano_nemotron_vl.py | 308 ++++++++++--------
1 file changed, 171 insertions(+), 137 deletions(-)
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 3bac36744ef5e..ad5a57c511fe8 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -10,7 +10,6 @@
import copy
import math
import random
-import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
@@ -24,6 +23,7 @@ import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
+from typing_extensions import assert_never
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@@ -62,7 +62,6 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
- ImageSize,
MultiModalDataItems,
MultiModalDataParser,
)
@@ -91,6 +90,7 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# TODO(nhaber): get 2048 from config
# TODO(nhaber): does use_thumbnail=True work?
+# TODO(nhaber): mixing images and videos will mess up the "text_prompt_length" calculation.
IMG_START = "
"
@@ -102,6 +102,46 @@ IMG_CONTEXT = ""
DEFAULT_NUM_TILES = 12
+@dataclass(kw_only=True, frozen=True)
+class Dims:
+ height: int
+ width: int
+
+
+CONV_MERGING = False # This is assumed to be False for now
+PIXEL_SHUFFLE = True # This is assumed to be True for now
+REDUCTION_FACTOR = 2 ** (PIXEL_SHUFFLE + CONV_MERGING)
+
+def width_and_height_for_max_num_tokens_available(
+ *,
+ target_num_tokens_post_shuffle: int,
+ patch_size: int,
+) -> Dims:
+ """
+ TODO(nhaber): optimize this so it squeezes closer to target number of tokens.
+ Calculate image dimensions that produce approximately `target` tokens after
+ pixel_shuffle.
+
+ With pixel_shuffle enabled, each 2x2 patch grid becomes 1 token, so we
+ need 4*B patches to get B tokens.
+
+ Examples:
+ >>> dims = width_and_height_for_max_num_tokens_available(B=8192, patch_size=16)
+ >>> assert dims.width, dims.height == (2880, 2880)
+ >>> assert ((dims.width // 16) * (dims.height // 16) // 4) == 8100 # tokens after shuffle
+ """
+ side_pixels = math.isqrt(target_num_tokens_post_shuffle) * REDUCTION_FACTOR * patch_size
+ assert isinstance(side_pixels, int) and side_pixels % patch_size == 0
+ return Dims(width=side_pixels, height=side_pixels)
+
+@dataclass
+class DynamicResolutionParams:
+ media: Image.Image
+ num_tiles: int
+ num_embeddings: int
+ patch_size: tuple[int, int]
+
+
class NanoNemotronVLImagePixelInputs(TensorSchema):
"""
Dimensions:
@@ -313,10 +353,10 @@ class BaseNanoNemotronVLProcessor(ABC):
self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1)
self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1)
- def num_image_token(self, *, image_width: int, image_height: int) -> int:
- image_size = math.sqrt(image_width * image_height)
+ def num_image_token_per_tile(self, *, tile_width: int, tile_height: int) -> int:
+ tile_size = math.sqrt(tile_width * tile_height)
num_tokens = int(
- (image_size // self.patch_size) ** 2 * (self.downsample_ratio**2)
+ (tile_size // self.patch_size) ** 2 * (self.downsample_ratio**2)
)
return num_tokens
@@ -342,7 +382,7 @@ class BaseNanoNemotronVLProcessor(ABC):
) -> int:
target_ratios = get_internvl_target_ratios(1, max_num_tiles)
- num_patches, _, _ = calculate_internvl_targets(
+ num_tiles, _, _ = calculate_internvl_targets(
orig_width=image_width,
orig_height=image_height,
target_ratios=target_ratios,
@@ -350,16 +390,16 @@ class BaseNanoNemotronVLProcessor(ABC):
use_thumbnail=self.use_thumbnail,
)
- return num_patches * self.num_image_token(
- image_width=image_width, image_height=image_height
+ return num_tiles * self.num_image_token_per_tile(
+ tile_width=image_width, tile_height=image_height
)
def _images_to_pixel_values_lst(
self,
- text: list[str],
+ text_prompt_length: int,
images: list[Image.Image],
max_num_tiles: int,
- ) -> list[torch.Tensor]:
+ ) -> tuple[list[torch.Tensor], list[int]]:
return [
image_to_pixel_values(
image,
@@ -380,8 +420,19 @@ class BaseNanoNemotronVLProcessor(ABC):
if len(images) == 0:
image_inputs = {}
else:
- pixel_values_lst = self._images_to_pixel_values_lst(
- text=text, images=images, max_num_tiles=max_num_tiles
+ assert len(text) == 1, (
+ "hf_processor is called on the output of get_dummy_text, "
+ "which should be a single string"
+ )
+ text_prompt_length = len(
+ self.tokenizer(
+ text[0].replace("", ""), add_special_tokens=False
+ )["input_ids"]
+ )
+ pixel_values_lst, token_counts = self._images_to_pixel_values_lst(
+ text_prompt_length=text_prompt_length,
+ images=images,
+ max_num_tiles=max_num_tiles,
)
image_inputs = {
"pixel_values_flat": input_conditioner(
@@ -402,12 +453,10 @@ class BaseNanoNemotronVLProcessor(ABC):
"same as the number of images"
)
- for i, pixel_values in enumerate(pixel_values_lst):
+ for i, (pixel_values, feature_size) in enumerate(
+ zip(pixel_values_lst, token_counts, strict=True)
+ ):
num_patches = pixel_values.shape[0]
- feature_size = num_patches * self.num_image_token(
- image_width=pixel_values.shape[1],
- image_height=pixel_values.shape[2],
- )
image_repl = self.get_image_repl(feature_size, num_patches)
parts[i] = parts[i].replace("", image_repl.full)
text = ["".join(parts)]
@@ -431,14 +480,6 @@ class BaseNanoNemotronVLProcessor(ABC):
raise NotImplementedError
-@dataclass
-class DynamicResolutionParams:
- media: Image.Image
- num_tiles: int
- num_embeddings: int
- patch_size: tuple[int, int]
-
-
class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
@@ -448,6 +489,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
config: PretrainedConfig,
tokenizer: TokenizerLike,
*args,
+ max_model_len: int,
max_num_tiles: int | None = None,
min_num_patches: int = 4,
factor_max: float = 1.0,
@@ -463,6 +505,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
super().__init__(
config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs
)
+ self.max_model_len = max_model_len
+
self._min_num_patches = min_num_patches
self._factor_max = factor_max
self._pixel_shuffle = pixel_shuffle
@@ -483,6 +527,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1)
self.downsample_ratio = 2 if pixel_shuffle else 1
+ feature_size_cache: dict[
+ Image.Image, int
+ ] = {} # TODO(nhaber): Find a less silly way of doing this... Why can't this be a class variable?
+
def apply_params(self, params: DynamicResolutionParams) -> torch.Tensor:
resized_img = params.media.resize(
(
@@ -515,7 +563,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
num_tokens_available: int,
data_augment: bool = False,
tiling_augment_prob: float = 0.4,
- ) -> DynamicResolutionParams:
+ ) -> tuple[DynamicResolutionParams, int]:
"""Process a single media item and return its parameters.
Args:
media: The media item to process
@@ -531,8 +579,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
)
orig_width, orig_height = media.width, media.height
- closest_patch_height = round(orig_height / self.patch_size + 0.5)
- closest_patch_width = round(orig_width / self.patch_size + 0.5)
+ closest_patch_height = math.ceil(
+ orig_height / self.patch_size
+ ) # TODO(nhaber): Ask Tyler - the previous round + 0.5 code is dangerous [banker's rounding], no? If we flip this back to the round, the max_wh_fill_budget needs to do -1 for each of w;h to be safe
+ closest_patch_width = math.ceil(orig_width / self.patch_size)
patches = closest_patch_height * closest_patch_width
factor = min(
@@ -660,8 +710,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
)
# Calculate embeddings for the main dynamic resolution image
- num_embeddings = self.num_image_token(
- image_width=target_patch_width, image_height=target_patch_height
+ num_embeddings_per_tile = self.num_image_token_per_tile(
+ tile_width=target_patch_width, tile_height=target_patch_height
)
token_count = target_patch_width * target_patch_height
@@ -681,8 +731,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
if area_ratio < self._thumbnail_area_threshold:
num_tiles += 1 # Add 1 for thumbnail
# Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
- num_embeddings += self.num_image_token(
- image_width=self._thumbnail_size, image_height=self._thumbnail_size
+ num_embeddings += self.num_image_token_per_tile(
+ tile_width=self._thumbnail_size, tile_height=self._thumbnail_size
)
token_count += (
self._thumbnail_size
@@ -694,7 +744,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
return DynamicResolutionParams(
media=media,
num_tiles=num_tiles,
- num_embeddings=num_embeddings,
+ num_embeddings=num_embeddings_per_tile,
patch_size=(target_patch_width, target_patch_height),
), token_count
@@ -748,7 +798,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
media_list: list[Image.Image],
num_tokens_available: int | None = None,
data_augment: bool = False,
- ) -> list[DynamicResolutionParams]:
+ ) -> tuple[list[DynamicResolutionParams], list[int]]:
"""Compute parameters for all media with iterative token budgeting.
Args:
@@ -782,8 +832,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
# changes, I want to make sure we don't get stuck in an infinite loop.
for _ in range(10):
# Step 1: Process each media with current token budget
- params = []
- token_counts = []
+ params: list[DynamicResolutionParams] = []
+ token_counts: list[int] = []
for media, tokens_for_media in zip(
media_list, num_tokens_available_per_media
@@ -799,7 +849,12 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
if total_tokens <= num_tokens_available:
# We're within budget, return the params
- return params
+ # Convert from patch count to actual token count after downsampling
+ divisor = (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1)
+ adjusted_token_counts = [tc // divisor for tc in token_counts]
+ for param, feature_size in zip(params, adjusted_token_counts, strict=True):
+ self.feature_size_cache[id(param.media)] = feature_size
+ return params, adjusted_token_counts
# Step 3: We're over budget, need to scale down
# Calculate scaling factor to get under budget
@@ -828,7 +883,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
num_tokens_available_per_media = (
scaled_down_num_tokens_available_per_media
)
- return params
+ assert_never(num_tokens_available_per_media)
def stack(
self, images: list[torch.Tensor]
@@ -879,15 +934,18 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
None,
)
+ def max_num_tokens_available(self, text_prompt_length: int) -> int:
+ return self.max_model_len - text_prompt_length - 4
+
def _images_to_pixel_values_lst(
self,
- text: list[str],
+ text_prompt_length: int,
images: list[Image.Image],
max_num_tiles: int,
- ) -> list[torch.Tensor]:
- num_tokens_available = 2048 - len(text) - 4
- params_per_image = self.compute_params(
- images, num_tokens_available=num_tokens_available
+ ) -> tuple[list[torch.Tensor], list[int]]:
+ num_tokens_available = self.max_num_tokens_available(text_prompt_length)
+ params_per_image, feature_sizes = self.compute_params(
+ images, num_tokens_available
)
images = []
for param in params_per_image:
@@ -895,17 +953,12 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
if t.ndim == 3:
t = t.unsqueeze(0)
images.append(t)
- return images
+ return images, feature_sizes
- def __str__(self):
- return f"DynamicResolutionImageTransform(\
- min_num_patches={self._min_num_patches}, \
- patch_size={self.patch_size}, \
- pixel_shuffle={self._pixel_shuffle}, \
- conv_merging={self._conv_merging}, \
- use_thumbnail={self._use_thumbnail}, \
- thumbnail_size={self._thumbnail_size}, \
- thumbnail_area_threshold={self._thumbnail_area_threshold})"
+ def get_cached_feature_size(self, image: Image.Image) -> int:
+ feature_size = self.feature_size_cache[id(image)]
+ del self.feature_size_cache[id(image)]
+ return feature_size
class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
@@ -920,6 +973,7 @@ class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
config: PretrainedConfig,
tokenizer: TokenizerLike,
*,
+ max_model_len: int,
max_num_tiles: int | None = None,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
@@ -930,6 +984,7 @@ class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
super().__init__(
config=config,
tokenizer=tokenizer,
+ max_model_len=max_model_len,
max_num_tiles=max_num_tiles,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
@@ -1205,7 +1260,7 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
**kwargs: object,
- ) -> BaseNanoNemotronVLProcessor:
+ ) -> DynamicResolutionImageTiler:
raise NotImplementedError
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
@@ -1228,31 +1283,6 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
max_num_tiles=max_num_tiles,
)
- def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize:
- processor = self.get_hf_processor()
-
- base_size = processor.image_size
- target_ratios = get_internvl_target_ratios(1, max_num_tiles)
-
- largest_feature_size, largest_feature_pinpoint = 0, None
- for wr, hr in target_ratios:
- width, height = base_size * wr, base_size * hr
-
- feat_size = self.get_num_image_tokens(
- image_width=width,
- image_height=height,
- max_num_tiles=max_num_tiles,
- processor=processor,
- )
- if feat_size > largest_feature_size:
- largest_feature_size = feat_size
- largest_feature_pinpoint = ImageSize(width=width, height=height)
-
- if largest_feature_size == 0 or largest_feature_pinpoint is None:
- raise ValueError("Cannot have a largest feature size of 0!")
-
- return largest_feature_pinpoint
-
def get_max_image_tokens(self) -> int:
processor = self.get_hf_processor()
# Use default max_num_tiles for max tokens calculation
@@ -1277,7 +1307,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
@property
def supports_video(self):
- return self.get_hf_processor().supports_video
+ return False # TODO(nhaber): add video support
def get_supported_mm_limits(self):
video_limit = {"video": None} if self.supports_video else {}
@@ -1300,8 +1330,10 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
processor = self.get_hf_processor() # we get the CustomProcessor here
max_image_tokens = self.get_max_image_tokens() * max_images
- max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token(
- image_width=256, image_height=256
+ max_total_frames = (
+ seq_len - max_image_tokens
+ ) // processor.num_image_token_per_tile(
+ tile_width=256, tile_height=256
) # TODO(nhaber): get 256 dynamically
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
@@ -1313,6 +1345,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
tokenizer=self.get_tokenizer(),
video_token=self.get_video_token(),
video_pruning_rate=self.get_video_pruning_rate(),
+ max_model_len=self.ctx.model_config.max_model_len,
**kwargs,
)
@@ -1362,17 +1395,8 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
- image_size = images.get_image_size(item_idx)
- # Extract max_num_tiles from kwargs, default to 12
- max_num_tiles = hf_processor_mm_kwargs.get(
- "max_num_tiles", hf_processor.max_num_tiles
- )
- feature_size = self.info.get_num_image_tokens(
- image_width=image_size.width,
- image_height=image_size.height,
- max_num_tiles=max_num_tiles,
- processor=hf_processor,
- )
+ image = images.get(item_idx)
+ feature_size = hf_processor.get_cached_feature_size(image)
num_patches = None
local_image_num_patches = image_num_patches
@@ -1447,8 +1471,8 @@ class NanoNemotronVLMultiModalProcessor(
video_num_patches = []
def get_video_replacement_internvl(item_idx: int):
- feature_size = hf_processor.num_image_token(
- image_width=256, image_height=256
+ feature_size = hf_processor.num_image_token_per_tile(
+ tile_width=256, tile_height=256
) # TODO(nhaber): get 256 dynamically
video, metadata = mm_items["video"][item_idx]
num_patches = video_num_patches[item_idx]
@@ -1510,19 +1534,20 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
- # Use default max_num_tiles for dummy data generation
- max_num_tiles = 12
- target_width, target_height = self.info.get_image_size_with_most_features(
- max_num_tiles
- )
num_images = mm_counts.get("image", 0)
+ processor = self.info.get_hf_processor()
+ B = processor.max_num_tokens_available(text_prompt_length=num_images)
+ target_dims = width_and_height_for_max_num_tokens_available(
+ target_num_tokens_post_shuffle=B,
+ patch_size=processor.patch_size,
+ )
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
- width=target_width,
- height=target_height,
+ width=target_dims.width,
+ height=target_dims.height,
num_images=num_images,
overrides=image_overrides,
)
@@ -1672,33 +1697,36 @@ class NemotronH_Nano_VL_V2(
IMG_CONTEXT, add_special_tokens=False
)
- def pixel_shuffle(self, x, scale_factor=0.5):
- n, w, h, c = x.size()
- # N, W, H, C --> N, W, H * scale, C // scale
- x = x.view(
- n,
- w,
- int(h * scale_factor),
- int(c / scale_factor),
- )
- # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
- x = x.permute(0, 2, 1, 3).contiguous()
- # N, H * scale, W, C // scale -->
- # N, H * scale, W * scale, C // (scale ** 2)
- x = x.view(
- n,
- int(h * scale_factor),
- int(w * scale_factor),
- int(c / (scale_factor * scale_factor)),
- )
- if self.ps_version == "v1":
- warnings.warn(
- "In ps_version 'v1', the height and width have not "
- "been swapped back, which results in a transposed image.",
- stacklevel=2,
+ def pixel_shuffle_dynamic_res(self, x, *, imgs_sizes):
+ scale_factor = self.downsample_ratio
+ patch_dim = self.patch_size
+ seq_lens = torch.prod(imgs_sizes // patch_dim, dim=-1)
+ splits = torch.split(x, seq_lens.tolist(), dim=-2)
+ out = []
+ for i, sv in enumerate(splits):
+ h = imgs_sizes[i][0] // patch_dim
+ w = imgs_sizes[i][1] // patch_dim
+ sv = sv.reshape(sv.shape[0], h, w, -1)
+
+ n, h, w, c = sv.size()
+
+ sv = sv.view(n, h, int(w * scale_factor), int(c / scale_factor))
+ sv = sv.permute(0, 2, 1, 3).contiguous()
+ sv = sv.view(
+ n,
+ int(w * scale_factor),
+ int(h * scale_factor),
+ int(c / (scale_factor * scale_factor)),
)
- else:
- x = x.permute(0, 2, 1, 3).contiguous()
+
+ if self.ps_version == "v2":
+ sv = sv.permute(0, 2, 1, 3).contiguous()
+
+ sv = sv.reshape(sv.shape[0], -1, sv.shape[-1])
+ out.append(sv)
+
+ x = torch.cat(out, dim=-2)
+
return x
def extract_feature(self, pixel_values):
@@ -1710,16 +1738,22 @@ class NemotronH_Nano_VL_V2(
n = pixel_values.shape[0]
vit_embeds_list = []
for i in range(0, n, micro_batch_size):
- vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size])
+ current = pixel_values[i : i + micro_batch_size]
+ vit_embeds = self.vision_model(current)
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
- h = w = int(vit_embeds.shape[1] ** 0.5)
- vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
- vit_embeds = self.pixel_shuffle(
- vit_embeds, scale_factor=self.downsample_ratio
- )
- vit_embeds = vit_embeds.reshape(
- vit_embeds.shape[0], -1, vit_embeds.shape[-1]
- )
+
+ # pixel_shuffle_dynamic_res expects patches concatenated along dim=-2,
+ # but vision model outputs (batch, patches, hidden). Process each image
+ # individually to handle this correctly.
+ _, _, h, w = current.shape
+ shuffled_embeds = []
+ for j in range(vit_embeds.shape[0]):
+ single_embed = vit_embeds[j : j + 1] # (1, patches, hidden)
+ single_shuffled = self.pixel_shuffle_dynamic_res(
+ single_embed, imgs_sizes=torch.tensor([(h, w)])
+ )
+ shuffled_embeds.append(single_shuffled)
+ vit_embeds = torch.cat(shuffled_embeds, dim=0)
vit_embeds = self.mlp1(vit_embeds)
vit_embeds_list.append(vit_embeds)
From 2a7ea9ba37d5c70b9913f7eae5caeac2c44ffec8 Mon Sep 17 00:00:00 2001
From: Netanel Haber
Date: Mon, 22 Dec 2025 06:27:32 -0800
Subject: [PATCH 08/10] add logs
---
.../model_executor/models/nano_nemotron_vl.py | 47 ++++++++++++-------
1 file changed, 30 insertions(+), 17 deletions(-)
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index ad5a57c511fe8..cc9a839b87ef5 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -112,6 +112,13 @@ CONV_MERGING = False # This is assumed to be False for now
PIXEL_SHUFFLE = True # This is assumed to be True for now
REDUCTION_FACTOR = 2 ** (PIXEL_SHUFFLE + CONV_MERGING)
+def num_image_token_per_tile(*, tile_dims: Dims, patch_size: int, downsample_ratio: int) -> int:
+ tile_size = math.sqrt(tile_dims.width * tile_dims.height)
+ num_tokens = int(
+ (tile_size // patch_size) ** 2 * (downsample_ratio**2)
+ )
+ return num_tokens
+
def width_and_height_for_max_num_tokens_available(
*,
target_num_tokens_post_shuffle: int,
@@ -129,6 +136,7 @@ def width_and_height_for_max_num_tokens_available(
>>> dims = width_and_height_for_max_num_tokens_available(B=8192, patch_size=16)
>>> assert dims.width, dims.height == (2880, 2880)
>>> assert ((dims.width // 16) * (dims.height // 16) // 4) == 8100 # tokens after shuffle
+ >>> assert num_image_token_per_tile(tile_dims=dims, patch_size=16, downsample_ratio=2) == 8100
"""
side_pixels = math.isqrt(target_num_tokens_post_shuffle) * REDUCTION_FACTOR * patch_size
assert isinstance(side_pixels, int) and side_pixels % patch_size == 0
@@ -353,13 +361,6 @@ class BaseNanoNemotronVLProcessor(ABC):
self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1)
self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1)
- def num_image_token_per_tile(self, *, tile_width: int, tile_height: int) -> int:
- tile_size = math.sqrt(tile_width * tile_height)
- num_tokens = int(
- (tile_size // self.patch_size) ** 2 * (self.downsample_ratio**2)
- )
- return num_tokens
-
@property
@abstractmethod
def image_token_id(self) -> int:
@@ -390,8 +391,10 @@ class BaseNanoNemotronVLProcessor(ABC):
use_thumbnail=self.use_thumbnail,
)
- return num_tiles * self.num_image_token_per_tile(
- tile_width=image_width, tile_height=image_height
+ return num_tiles * num_image_token_per_tile(
+ tile_dims=Dims(width=image_width, height=image_height),
+ patch_size=self.patch_size,
+ downsample_ratio=self.downsample_ratio
)
def _images_to_pixel_values_lst(
@@ -710,8 +713,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
)
# Calculate embeddings for the main dynamic resolution image
- num_embeddings_per_tile = self.num_image_token_per_tile(
- tile_width=target_patch_width, tile_height=target_patch_height
+ num_embeddings_per_tile = num_image_token_per_tile(
+ tile_dims=Dims(width=target_patch_width, height=target_patch_height),
+ patch_size=self.patch_size,
+ downsample_ratio=self.downsample_ratio
)
token_count = target_patch_width * target_patch_height
@@ -731,8 +736,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
if area_ratio < self._thumbnail_area_threshold:
num_tiles += 1 # Add 1 for thumbnail
# Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
- num_embeddings += self.num_image_token_per_tile(
- tile_width=self._thumbnail_size, tile_height=self._thumbnail_size
+ num_embeddings_per_tile += num_image_token_per_tile(
+ tile_dims=Dims(width=self._thumbnail_size, height=self._thumbnail_size),
+ patch_size=self.patch_size,
+ downsample_ratio=self.downsample_ratio
)
token_count += (
self._thumbnail_size
@@ -947,6 +954,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
params_per_image, feature_sizes = self.compute_params(
images, num_tokens_available
)
+ print(f"{feature_sizes=}")
+ print(f"{params_per_image=}")
images = []
for param in params_per_image:
t = self.apply_params(param)
@@ -1332,8 +1341,10 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = (
seq_len - max_image_tokens
- ) // processor.num_image_token_per_tile(
- tile_width=256, tile_height=256
+ ) // num_image_token_per_tile(
+ tile_dims=Dims(width=256, height=256),
+ patch_size=processor.patch_size,
+ downsample_ratio=processor.downsample_ratio
) # TODO(nhaber): get 256 dynamically
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
@@ -1471,8 +1482,10 @@ class NanoNemotronVLMultiModalProcessor(
video_num_patches = []
def get_video_replacement_internvl(item_idx: int):
- feature_size = hf_processor.num_image_token_per_tile(
- tile_width=256, tile_height=256
+ feature_size = num_image_token_per_tile(
+ tile_dims=Dims(width=256, height=256),
+ patch_size=hf_processor.patch_size,
+ downsample_ratio=hf_processor.downsample_ratio
) # TODO(nhaber): get 256 dynamically
video, metadata = mm_items["video"][item_idx]
num_patches = video_num_patches[item_idx]
From eac0271b0fe3774281bdeef27c057fde6c263bc4 Mon Sep 17 00:00:00 2001
From: Netanel Haber
Date: Tue, 23 Dec 2025 04:14:41 -0800
Subject: [PATCH 09/10] get num_embeddings from params + cleanup + minimize
diff + ruff
---
.../model_executor/models/nano_nemotron_vl.py | 298 +++++++++---------
1 file changed, 150 insertions(+), 148 deletions(-)
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index cc9a839b87ef5..2a41d43ab9660 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -88,7 +88,6 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
-# TODO(nhaber): get 2048 from config
# TODO(nhaber): does use_thumbnail=True work?
# TODO(nhaber): mixing images and videos will mess up the "text_prompt_length" calculation.
@@ -102,28 +101,20 @@ IMG_CONTEXT = ""
DEFAULT_NUM_TILES = 12
-@dataclass(kw_only=True, frozen=True)
-class Dims:
- height: int
- width: int
-
-
-CONV_MERGING = False # This is assumed to be False for now
-PIXEL_SHUFFLE = True # This is assumed to be True for now
-REDUCTION_FACTOR = 2 ** (PIXEL_SHUFFLE + CONV_MERGING)
-
-def num_image_token_per_tile(*, tile_dims: Dims, patch_size: int, downsample_ratio: int) -> int:
- tile_size = math.sqrt(tile_dims.width * tile_dims.height)
- num_tokens = int(
- (tile_size // patch_size) ** 2 * (downsample_ratio**2)
- )
+def num_image_token_per_tile(
+ *, width: int, height: int, patch_size: int, downsample_ratio: int
+) -> int:
+ tile_size = math.sqrt((width // patch_size) * (height // patch_size))
+ num_tokens = int(tile_size**2 // (downsample_ratio**2))
return num_tokens
+
def width_and_height_for_max_num_tokens_available(
*,
target_num_tokens_post_shuffle: int,
patch_size: int,
-) -> Dims:
+ downsample_ratio: int,
+) -> tuple[int, int]:
"""
TODO(nhaber): optimize this so it squeezes closer to target number of tokens.
Calculate image dimensions that produce approximately `target` tokens after
@@ -133,14 +124,26 @@ def width_and_height_for_max_num_tokens_available(
need 4*B patches to get B tokens.
Examples:
- >>> dims = width_and_height_for_max_num_tokens_available(B=8192, patch_size=16)
- >>> assert dims.width, dims.height == (2880, 2880)
- >>> assert ((dims.width // 16) * (dims.height // 16) // 4) == 8100 # tokens after shuffle
- >>> assert num_image_token_per_tile(tile_dims=dims, patch_size=16, downsample_ratio=2) == 8100
+ >>> width, height = width_and_height_for_max_num_tokens_available(
+ ... target_num_tokens_post_shuffle=8192,
+ ... patch_size=16,
+ ... downsample_ratio=2,
+ ... )
+ >>> assert width, height == (2880, 2880)
+ >>> assert (width // 16) * (height // 16) // 2**2 == 8100 # tokens post-shuffle
+ >>> assert (
+ ... num_image_token_per_tile(
+ ... width=width, height=height, patch_size=16, downsample_ratio=2
+ ... )
+ ... == 8100
+ ... )
"""
- side_pixels = math.isqrt(target_num_tokens_post_shuffle) * REDUCTION_FACTOR * patch_size
+ side_pixels = (
+ math.isqrt(target_num_tokens_post_shuffle) * downsample_ratio * patch_size
+ )
assert isinstance(side_pixels, int) and side_pixels % patch_size == 0
- return Dims(width=side_pixels, height=side_pixels)
+ return side_pixels, side_pixels
+
@dataclass
class DynamicResolutionParams:
@@ -354,7 +357,7 @@ class BaseNanoNemotronVLProcessor(ABC):
self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES
image_size: int = config.force_image_size
self.patch_size: int = getattr(config, "patch_size", 16)
- self.downsample_ratio: float = self.config.downsample_ratio
+ # self.downsample_ratio: float = self.config.downsample_ratio
self.image_size = image_size
self.use_thumbnail: bool = config.use_thumbnail
@@ -392,9 +395,10 @@ class BaseNanoNemotronVLProcessor(ABC):
)
return num_tiles * num_image_token_per_tile(
- tile_dims=Dims(width=image_width, height=image_height),
+ width=image_width,
+ height=image_height,
patch_size=self.patch_size,
- downsample_ratio=self.downsample_ratio
+ downsample_ratio=self.downsample_ratio,
)
def _images_to_pixel_values_lst(
@@ -508,8 +512,9 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
super().__init__(
config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs
)
- self.max_model_len = max_model_len
+ self._patch_size: int = getattr(config, "patch_size", 16)
+ self.max_model_len = max_model_len
self._min_num_patches = min_num_patches
self._factor_max = factor_max
self._pixel_shuffle = pixel_shuffle
@@ -518,47 +523,90 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
self._use_thumbnail = use_thumbnail
self._thumbnail_size = thumbnail_size
self._thumbnail_area_threshold = thumbnail_area_threshold
+ self.norm_mean = torch.tensor(self.CLIP_PIXEL_MEAN).reshape(1, 3, 1, 1)
+ self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1)
self._transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)),
+ # T.Normalize(mean=pixel_mean, std=pixel_std), - This is done down below with input_conditioner
]
)
self._apply_data_augment = apply_data_augment
+ reduction_factor = 1 / self.config.downsample_ratio
+ assert reduction_factor == 2.0, (
+ "I don't understand what's going on if this isn't 4"
+ )
+ self.downsample_ratio = int(reduction_factor) ** (pixel_shuffle + conv_merging)
+ assert self.downsample_ratio == 2, (
+ f"I don't understand what's going on if {self.downsample_ratio=} isn't 2"
+ )
- self.norm_mean = torch.tensor(self.CLIP_PIXEL_MEAN).reshape(1, 3, 1, 1)
- self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1)
- self.downsample_ratio = 2 if pixel_shuffle else 1
+ def _get_num_embeddings(self, width: int, height: int) -> int:
+ return num_image_token_per_tile(
+ width=width,
+ height=height,
+ patch_size=self._patch_size,
+ downsample_ratio=self.downsample_ratio,
+ )
+
+ def max_num_tokens_available(self, text_prompt_length: int) -> int:
+ return self.max_model_len - text_prompt_length - 4
+
+ def _images_to_pixel_values_lst(
+ self,
+ text_prompt_length: int,
+ images: list[Image.Image],
+ max_num_tiles: int,
+ ) -> tuple[list[torch.Tensor], list[int]]:
+ num_tokens_available = self.max_num_tokens_available(text_prompt_length)
+ params_per_image = self.compute_params(images, num_tokens_available)
+
+ feature_sizes = []
+ images = []
+ for param in params_per_image:
+ for t in self.apply_params(param):
+ if t.ndim == 3:
+ t = t.unsqueeze(0)
+ images.append(t)
+ feature_sizes.append(param.num_embeddings)
+ print(f"{feature_sizes=}")
+ print(f"{params_per_image=}")
+ return images, feature_sizes
feature_size_cache: dict[
Image.Image, int
- ] = {} # TODO(nhaber): Find a less silly way of doing this... Why can't this be a class variable?
+ ] = {} # TODO(nhaber): Find a less silly way of doing this... Why can't this be an instance variable?
- def apply_params(self, params: DynamicResolutionParams) -> torch.Tensor:
+ def get_cached_feature_size(self, image: Image.Image) -> int:
+ feature_size = self.feature_size_cache[id(image)]
+ del self.feature_size_cache[id(image)]
+ return feature_size
+
+ def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
resized_img = params.media.resize(
(
- params.patch_size[0] * self.patch_size,
- params.patch_size[1] * self.patch_size,
+ params.patch_size[0] * self._patch_size,
+ params.patch_size[1] * self._patch_size,
)
)
- # processed_images = [resized_img]
+ processed_images = [resized_img]
- # # Add thumbnail if enabled and image area is below threshold
- # if self._use_thumbnail:
- # # Calculate areas
- # resized_area = resized_img.size[0] * resized_img.size[1]
- # thumbnail_area = self._thumbnail_size * self._thumbnail_size
- # area_ratio = resized_area / thumbnail_area
+ # Add thumbnail if enabled and image area is below threshold
+ if self._use_thumbnail:
+ # Calculate areas
+ resized_area = resized_img.size[0] * resized_img.size[1]
+ thumbnail_area = self._thumbnail_size * self._thumbnail_size
+ area_ratio = resized_area / thumbnail_area
- # # Only add thumbnail if resized image area is less than threshold % of
- # # thumbnail area
- # if area_ratio < self._thumbnail_area_threshold:
- # thumbnail_img = params.media.resize(
- # (self._thumbnail_size, self._thumbnail_size)
- # )
- # processed_images.append(thumbnail_img)
+ # Only add thumbnail if resized image area is less than threshold % of thumbnail area
+ if area_ratio < self._thumbnail_area_threshold:
+ thumbnail_img = params.media.resize(
+ (self._thumbnail_size, self._thumbnail_size)
+ )
+ processed_images.append(thumbnail_img)
- return self._transform(resized_img)
+ return [self._transform(img) for img in processed_images]
def process_media(
self,
@@ -568,11 +616,11 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
tiling_augment_prob: float = 0.4,
) -> tuple[DynamicResolutionParams, int]:
"""Process a single media item and return its parameters.
+
Args:
media: The media item to process
num_tokens_available: Number of tokens available for this media
- data_augment: Whether to apply data augmentation to the image. Defaults to
- False.
+ data_augment: Whether to apply data augmentation to the image. Defaults to False.
Returns:
DynamicResolutionParams for the media
"""
@@ -581,11 +629,9 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
"Dynamic resolution is only supported for image media"
)
orig_width, orig_height = media.width, media.height
-
- closest_patch_height = math.ceil(
- orig_height / self.patch_size
- ) # TODO(nhaber): Ask Tyler - the previous round + 0.5 code is dangerous [banker's rounding], no? If we flip this back to the round, the max_wh_fill_budget needs to do -1 for each of w;h to be safe
- closest_patch_width = math.ceil(orig_width / self.patch_size)
+ # TODO(nhaber): Ask Tyler - the round + 0.5 code is dangerous [banker's rounding], no?
+ closest_patch_height = round(orig_height / self._patch_size + 0.5)
+ closest_patch_width = round(orig_width / self._patch_size + 0.5)
patches = closest_patch_height * closest_patch_width
factor = min(
@@ -594,8 +640,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
target_patch_height = math.floor(factor * closest_patch_height)
target_patch_width = math.floor(factor * closest_patch_width)
- # We only consider self._min_num_patches if it is greater than
- # current_num_tokens_available.
+ # We only consider self._min_num_patches if it is greater than current_num_tokens_available.
if (
current_num_tokens_available > self._min_num_patches
and target_patch_height * target_patch_width < self._min_num_patches
@@ -608,20 +653,19 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
if (
self._min_side is not None
- and min(target_patch_width, target_patch_height) * self.patch_size
+ and min(target_patch_width, target_patch_height) * self._patch_size
< self._min_side
):
if target_patch_width <= target_patch_height:
- up_factor = self._min_side / (target_patch_width * self.patch_size)
+ up_factor = self._min_side / (target_patch_width * self._patch_size)
new_patch_height = math.ceil(up_factor * target_patch_height)
new_patch_width = math.ceil(up_factor * target_patch_width)
if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at
- # native aspect ratio while staying below max_patches
+ # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
if (
max(current_num_tokens_available // new_patch_width, 1)
- * self.patch_size
+ * self._patch_size
< self._min_side
):
up_factor = math.sqrt(
@@ -640,16 +684,15 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
target_patch_height = new_patch_height
target_patch_width = new_patch_width
else:
- up_factor = self._min_side / (target_patch_height * self.patch_size)
+ up_factor = self._min_side / (target_patch_height * self._patch_size)
new_patch_height = math.ceil(up_factor * target_patch_height)
new_patch_width = math.ceil(up_factor * target_patch_width)
if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at
- # native aspect ratio while staying below max_patches
+ # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
if (
max(current_num_tokens_available // new_patch_height, 1)
- * self.patch_size
+ * self._patch_size
< self._min_side
):
up_factor = math.sqrt(
@@ -708,15 +751,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
target_patch_width, target_patch_height, current_num_tokens_available
)
- assert isinstance(media, Image.Image), (
- "Dynamic resolution is only supported for image media"
- )
-
# Calculate embeddings for the main dynamic resolution image
- num_embeddings_per_tile = num_image_token_per_tile(
- tile_dims=Dims(width=target_patch_width, height=target_patch_height),
- patch_size=self.patch_size,
- downsample_ratio=self.downsample_ratio
+ num_embeddings = self._get_num_embeddings(
+ target_patch_width * self._patch_size,
+ target_patch_height * self._patch_size,
)
token_count = target_patch_width * target_patch_height
@@ -725,33 +763,30 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
num_tiles = 1 # Base dynamic resolution image
if self._use_thumbnail:
# Calculate areas
- resized_area = (target_patch_width * self.patch_size) * (
- target_patch_height * self.patch_size
+ resized_area = (target_patch_width * self._patch_size) * (
+ target_patch_height * self._patch_size
)
thumbnail_area = self._thumbnail_size * self._thumbnail_size
area_ratio = resized_area / thumbnail_area
- # Only add thumbnail if resized image area is less than threshold % of
- # thumbnail area
+ # Only add thumbnail if resized image area is less than threshold % of thumbnail area
if area_ratio < self._thumbnail_area_threshold:
num_tiles += 1 # Add 1 for thumbnail
# Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
- num_embeddings_per_tile += num_image_token_per_tile(
- tile_dims=Dims(width=self._thumbnail_size, height=self._thumbnail_size),
- patch_size=self.patch_size,
- downsample_ratio=self.downsample_ratio
+ num_embeddings += self._get_num_embeddings(
+ self._thumbnail_size, self._thumbnail_size
)
token_count += (
self._thumbnail_size
- // self.patch_size
+ // self._patch_size
* self._thumbnail_size
- // self.patch_size
+ // self._patch_size
)
return DynamicResolutionParams(
media=media,
num_tiles=num_tiles,
- num_embeddings=num_embeddings_per_tile,
+ num_embeddings=num_embeddings,
patch_size=(target_patch_width, target_patch_height),
), token_count
@@ -805,7 +840,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
media_list: list[Image.Image],
num_tokens_available: int | None = None,
data_augment: bool = False,
- ) -> tuple[list[DynamicResolutionParams], list[int]]:
+ ) -> list[DynamicResolutionParams]:
"""Compute parameters for all media with iterative token budgeting.
Args:
@@ -821,26 +856,24 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
* (4 if self._pixel_shuffle else 1)
* (4 if self._conv_merging else 1)
)
- # When the number of available token is too small, allow self._min_num_patches
- # per media and let the sample be truncated.
+ # When the number of available token is too small, allow self._min_num_patches per media and
+ # let the sample be truncated.
num_tokens_available = max(
num_tokens_available, self._min_num_patches * len(media_list)
)
- # Clip the number of tokens available per media to be between min and max
- # patches.
+ # Clip the number of tokens available per media to be between min and max patches.
num_tokens_available_per_media = [
max(num_tokens_available, self._min_num_patches)
for _ in range(len(media_list))
]
- # In theory this could be a while True loop, but in case the process_media
- # method slightly
+ # In theory this could be a while True loop, but in case the process_media method slightly
# changes, I want to make sure we don't get stuck in an infinite loop.
for _ in range(10):
# Step 1: Process each media with current token budget
- params: list[DynamicResolutionParams] = []
- token_counts: list[int] = []
+ params = []
+ token_counts = []
for media, tokens_for_media in zip(
media_list, num_tokens_available_per_media
@@ -850,18 +883,14 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
)
params.append(param)
token_counts.append(token_count)
+ self.feature_size_cache[id(param.media)] = param.num_embeddings
# Step 2: Check if total tokens is within budget
total_tokens = sum(token_counts)
if total_tokens <= num_tokens_available:
# We're within budget, return the params
- # Convert from patch count to actual token count after downsampling
- divisor = (4 if self._pixel_shuffle else 1) * (4 if self._conv_merging else 1)
- adjusted_token_counts = [tc // divisor for tc in token_counts]
- for param, feature_size in zip(params, adjusted_token_counts, strict=True):
- self.feature_size_cache[id(param.media)] = feature_size
- return params, adjusted_token_counts
+ return params
# Step 3: We're over budget, need to scale down
# Calculate scaling factor to get under budget
@@ -880,8 +909,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
for i in range(len(num_tokens_available_per_media))
]
)
- # If there was not scaling down, we're stuck just use min_num_patches per
- # media, else try with the scaled down num_tokens_available_per_media.
+ # If there was not scaling down, we're stuck just use min_num_patches per media, else
+ # try with the scaled down num_tokens_available_per_media.
if not scaled_down:
num_tokens_available_per_media = [self._min_num_patches] * len(
media_list
@@ -900,15 +929,15 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
)
def rearrange_img(x):
- py = x.shape[-2] // self.patch_size
- px = x.shape[-1] // self.patch_size
+ py = x.shape[-2] // self._patch_size
+ px = x.shape[-1] // self._patch_size
x = einops.rearrange(
x,
"c (py yy) (px xx) -> (py px) (c yy xx)",
py=py,
- yy=self.patch_size,
+ yy=self._patch_size,
px=px,
- xx=self.patch_size,
+ xx=self._patch_size,
)
return x
@@ -941,34 +970,6 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
None,
)
- def max_num_tokens_available(self, text_prompt_length: int) -> int:
- return self.max_model_len - text_prompt_length - 4
-
- def _images_to_pixel_values_lst(
- self,
- text_prompt_length: int,
- images: list[Image.Image],
- max_num_tiles: int,
- ) -> tuple[list[torch.Tensor], list[int]]:
- num_tokens_available = self.max_num_tokens_available(text_prompt_length)
- params_per_image, feature_sizes = self.compute_params(
- images, num_tokens_available
- )
- print(f"{feature_sizes=}")
- print(f"{params_per_image=}")
- images = []
- for param in params_per_image:
- t = self.apply_params(param)
- if t.ndim == 3:
- t = t.unsqueeze(0)
- images.append(t)
- return images, feature_sizes
-
- def get_cached_feature_size(self, image: Image.Image) -> int:
- feature_size = self.feature_size_cache[id(image)]
- del self.feature_size_cache[id(image)]
- return feature_size
-
class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
"""
@@ -1339,12 +1340,11 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
processor = self.get_hf_processor() # we get the CustomProcessor here
max_image_tokens = self.get_max_image_tokens() * max_images
- max_total_frames = (
- seq_len - max_image_tokens
- ) // num_image_token_per_tile(
- tile_dims=Dims(width=256, height=256),
- patch_size=processor.patch_size,
- downsample_ratio=processor.downsample_ratio
+ max_total_frames = (seq_len - max_image_tokens) // num_image_token_per_tile(
+ width=256,
+ height=256,
+ patch_size=processor._patch_size,
+ downsample_ratio=processor.downsample_ratio,
) # TODO(nhaber): get 256 dynamically
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
@@ -1483,9 +1483,10 @@ class NanoNemotronVLMultiModalProcessor(
def get_video_replacement_internvl(item_idx: int):
feature_size = num_image_token_per_tile(
- tile_dims=Dims(width=256, height=256),
- patch_size=hf_processor.patch_size,
- downsample_ratio=hf_processor.downsample_ratio
+ width=256,
+ height=256,
+ patch_size=hf_processor._patch_size,
+ downsample_ratio=hf_processor.downsample_ratio,
) # TODO(nhaber): get 256 dynamically
video, metadata = mm_items["video"][item_idx]
num_patches = video_num_patches[item_idx]
@@ -1550,17 +1551,18 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
B = processor.max_num_tokens_available(text_prompt_length=num_images)
- target_dims = width_and_height_for_max_num_tokens_available(
+ target_width, target_height = width_and_height_for_max_num_tokens_available(
target_num_tokens_post_shuffle=B,
- patch_size=processor.patch_size,
+ patch_size=processor._patch_size,
+ downsample_ratio=processor.downsample_ratio,
)
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
- width=target_dims.width,
- height=target_dims.height,
+ width=target_width,
+ height=target_height,
num_images=num_images,
overrides=image_overrides,
)
From 6f1249fdb28615972f3b89234d11f3b19a0bd72a Mon Sep 17 00:00:00 2001
From: Netanel Haber
Date: Tue, 23 Dec 2025 06:27:03 -0800
Subject: [PATCH 10/10] minimize diff
---
vllm/model_executor/models/nano_nemotron_vl.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 2a41d43ab9660..23d9496ea37c0 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -386,7 +386,7 @@ class BaseNanoNemotronVLProcessor(ABC):
) -> int:
target_ratios = get_internvl_target_ratios(1, max_num_tiles)
- num_tiles, _, _ = calculate_internvl_targets(
+ num_patches, _, _ = calculate_internvl_targets(
orig_width=image_width,
orig_height=image_height,
target_ratios=target_ratios,
@@ -394,7 +394,7 @@ class BaseNanoNemotronVLProcessor(ABC):
use_thumbnail=self.use_thumbnail,
)
- return num_tiles * num_image_token_per_tile(
+ return num_patches * num_image_token_per_tile(
width=image_width,
height=image_height,
patch_size=self.patch_size,