# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
import math

import torch
from einops import rearrange
from torchvision import transforms as T
from torchvision.transforms import Compose
from torchvision.transforms.functional import InterpolationMode


# Per-channel RGB normalization statistics for the supported vision backbones.
IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]


# Maps a vision model type name to its (mean, std) normalization statistics.
# NOTE(review): "radio" and "cradio-g" intentionally reuse CLIP statistics and
# "huggingface" reuses SigLIP statistics — confirm against the upstream model cards.
pixel_statistics = {
    "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
    "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "internvit300M": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
}


# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
# Copyright (c) 2023 OpenGVLab.
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio (float): width / height of the original image.
        target_ratios (Iterable[tuple[int, int]]): candidate (cols, rows) grids.
        width (int): original image width in pixels.
        height (int): original image height in pixels.
        image_size (int): side length of one square tile in pixels.

    Returns:
        tuple[int, int]: the chosen (cols, rows) grid; (1, 1) if nothing is closer.
    """
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie-break: prefer the grid with more tiles, but only when the image
            # area exceeds half of the area the candidate grid would cover.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """
    Find the best number of tiles based on the aspect ratio and the area covered by the tiles.

    Unlike ``find_closest_aspect_ratio``, the score jointly weights (a) how much of
    the grid's pixel budget the image fills (capped at 0.6) and (b) how similar the
    grid's aspect ratio is to the image's.
    """
    best_factor = float('-inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        # score = min(covered_area_ratio, 0.6) * aspect_ratio_similarity (both in (0, 1]).
        factor_based_on_area_n_ratio = (
            min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6) *
            min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio))
        if factor_based_on_area_n_ratio > best_factor:
            best_factor = factor_based_on_area_n_ratio
            best_ratio = ratio
    return best_ratio


def process_images(sample_imgs, patch_dim, dynamic_resolution, batch_mode=False):
    """Process a batch of images for multimodal training or evaluation.

    This function handles image preprocessing with support for both static and dynamic
    resolution processing. For dynamic resolution, it rearranges images into patches
    and computes cumulative sequence lengths for efficient batching.

    Args:
        sample_imgs (List[torch.Tensor]): List of image tensors with shape (C, H, W).
        patch_dim (int): Dimension of each patch (e.g., 14 for 14x14 patches).
        dynamic_resolution (bool): Whether to use dynamic resolution processing.
            If True, images are rearranged into patches with variable sequence lengths.
            If False, images are simply stacked into a batch tensor.
        batch_mode (bool, optional): Whether this is being called from training batch processing.
            If True, wraps tensors in additional list dimension for consistency with batch format.
            If False, returns tensors directly as used in evaluation. Defaults to False.

    Returns:
        tuple: A 4-tuple containing:
            - images (torch.Tensor): Processed image tensor.
              For dynamic resolution: shape (1, total_patches, patch_features)
              regardless of batch_mode.
              For static resolution: shape (batch_size, C, H, W)
            - imgs_sizes (torch.Tensor): Image sizes tensor with shape (N, 2)
              containing [height, width] for each image, or [[0,0]] if no images.
            - vision_cu_lengths (torch.Tensor or None): Cumulative sequence lengths
              for dynamic resolution. Shape (batch_size + 1,) for evaluation mode,
              or shape (1, batch_size + 1) for batch mode. None for static resolution.
            - vision_max_lengths (torch.Tensor or None): Maximum sequence length
              among all images for dynamic resolution. Scalar tensor for evaluation mode,
              or shape (1,) for batch mode. None for static resolution.

    Note:
        This function is designed for processing one microbatch at a time for dynamic resolution.
    """
    vision_cu_lengths = None
    vision_max_lengths = None

    if len(sample_imgs) > 0:
        # shape[1] is H and shape[2] is W for a (C, H, W) tensor.
        imgs_sizes = torch.tensor([[img.shape[1], img.shape[2]] for img in sample_imgs], dtype=torch.int32)
        if dynamic_resolution:
            def rearrange_img(x):
                # Flatten a (C, H, W) image into (num_patches, patch_features) where
                # patch_features = C * patch_dim * patch_dim. Any H/W remainder beyond a
                # multiple of patch_dim is dropped by the integer division.
                py = x.shape[-2] // patch_dim
                px = x.shape[-1] // patch_dim
                x = rearrange(x, 'c (py yy) (px xx) -> (py px) (c yy xx)',
                              py=py, yy=patch_dim,
                              px=px, xx=patch_dim,
                              )
                return x
            imgs = [rearrange_img(img) for img in sample_imgs]

            # Build cumulative patch-sequence offsets ([0, n0, n0+n1, ...]) and track
            # the longest per-image patch sequence.
            current_length = 0
            max_length = 0
            vision_cu_lengths = [0]
            for img in imgs:
                if max_length < img.shape[0]:
                    max_length = img.shape[0]
                current_length += img.shape[0]
                vision_cu_lengths.append(current_length)

            vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
            vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)

            # For batch mode, wrap in additional dimension for consistency
            if batch_mode:
                vision_cu_lengths = vision_cu_lengths.unsqueeze(0)  # Shape: (1, batch_size + 1)
                vision_max_lengths = vision_max_lengths.unsqueeze(0)  # Shape: (1,)

            # Concatenate all per-image patch sequences into one (1, total_patches, feat) tensor.
            imgs = torch.cat(imgs, dim=0)
            images = imgs.unsqueeze(0)
        else:
            # Static resolution: all images share one shape, so a plain stack works.
            images = torch.stack(sample_imgs)
    else:
        imgs_sizes = torch.tensor([[0,0]], dtype=torch.int32)
        if len(sample_imgs) == 0 and batch_mode:
            # For batch mode when no images, use appropriate dummy tensor
            images = torch.tensor([[0]], dtype=torch.float32)
        else:
            # NOTE(review): torch.stack([]) raises RuntimeError, so the no-image
            # evaluation path cannot actually succeed — confirm whether callers ever
            # reach here with batch_mode=False.
            images = torch.stack(sample_imgs)

    return images, imgs_sizes, vision_cu_lengths, vision_max_lengths


class ImageTransform:
    """Image transformation.

    Wraps a backbone-specific normalization transform and dispatches between the
    static-resolution path and several dynamic-resolution / tiling strategies in
    ``__call__``.
    """

    def __init__(self, input_size, vision_model_type, *, dynamic_resolution=False, res_step=16, min_num_patches=1, max_num_patches=128, pixel_shuffle=False, min_side=None, conv_merging=False, match_tiling_dynamic_resolution=False, masked_tiling_dynamic_resolution=False, thumbnail_area_threshold=0.8):
        # Static-resolution transform; also used for each tile in the use_tiling path.
        self._transform = _build_transform(input_size, vision_model_type)
        self._vision_model_type = vision_model_type

        # Dynamic-resolution knobs; see dynamic_res_preprocess for their semantics.
        self._dynamic_resolution = dynamic_resolution
        self._res_step = res_step
        self._min_num_patches = min_num_patches
        self._max_num_patches = max_num_patches
        self._pixel_shuffle = pixel_shuffle
        self._min_side = min_side
        self._conv_merging = conv_merging
        self._match_tiling_dynamic_resolution = match_tiling_dynamic_resolution
        self._masked_tiling_dynamic_resolution = masked_tiling_dynamic_resolution
        # Thumbnail is added in the dynamic-resolution path only when the resized
        # image covers less than this fraction of the thumbnail area.
        self._thumbnail_area_threshold = thumbnail_area_threshold

    def __call__(self, img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, find_closest_aspect_ratio_fn=find_closest_aspect_ratio, is_video=False):
        """Transform a PIL image into a list of normalized tensors (one per tile).

        Args:
            img (PIL.Image): input image.
            img_h (int): target tile / output height in pixels.
            img_w (int): target tile / output width in pixels (must equal img_h for
                all tiling-based paths).
            use_tiling (bool): split the image into up to max_num_tiles square tiles.
            max_num_tiles (int): tile budget for the tiling-based paths.
            use_thumbnail (bool): optionally append a square thumbnail of the full image.
            augment (bool): unsupported; must be False.
            find_closest_aspect_ratio_fn (callable): strategy used to pick the tile grid.
            is_video (bool): forwarded to dynamic_res_preprocess (forces a fixed grid).

        Returns:
            list[torch.Tensor]: normalized image/tile tensors.
        """
        assert not augment, "Image augmentation not implemented."
        if use_tiling:
            assert img_h == img_w, "dynamic tiling expects equal tile height and width"
            imgs = dynamic_preprocess(
                img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail,
                find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn)
            imgs = [self._transform(img) for img in imgs]
        elif self._masked_tiling_dynamic_resolution:
            assert img_h == img_w, "masked tiling dynamic resolution expects equal tile height and width"
            assert "radio" in self._vision_model_type, "Masked tiling dynamic resolution is only supported for radio models"

            # Use tiling logic to determine tile grid (nx, ny)
            orig_width, orig_height = img.size
            aspect_ratio = orig_width / orig_height

            target_ratios = set(
                (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1)
                if i * j <= max_num_tiles and i * j >= 1
            )
            target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

            tiling = find_closest_aspect_ratio_fn(
                aspect_ratio, target_ratios, orig_width, orig_height, img_h
            )

            # Resize and split into tiles of size (img_h x img_h).
            # NOTE(review): target_height uses img_w while the tile boxes below use
            # img_h; this is only consistent because of the img_h == img_w assert.
            target_width = img_h * tiling[0]
            target_height = img_w * tiling[1]
            blocks = tiling[0] * tiling[1]

            resized_img = img.resize((target_width, target_height))

            processed_images = []
            for i in range(blocks):
                # Row-major crop box for tile i: (left, upper, right, lower).
                box = (
                    (i % (target_width // img_h)) * img_h,
                    (i // (target_width // img_h)) * img_h,
                    ((i % (target_width // img_h)) + 1) * img_h,
                    ((i // (target_width // img_h)) + 1) * img_h,
                )
                tile_img = resized_img.crop(box)
                processed_images.append(tile_img)
            assert len(processed_images) == blocks

            # Optional thumbnail
            if use_thumbnail and blocks != 1:
                thumbnail_img = img.resize((img_h, img_h))
                processed_images.append(thumbnail_img)

            # Normalize without resizing (tiles are already at the target size).
            pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ])
            imgs = [transform(im) for im in processed_images]
        elif self._match_tiling_dynamic_resolution:
            assert img_h == img_w, "match tiling dynamic resolution expects equal tile height and width"
            assert "radio" in self._vision_model_type, "Match tiling dynamic resolution is only supported for radio models"

            # Use tiling logic to determine optimal dimensions
            orig_width, orig_height = img.size
            aspect_ratio = orig_width / orig_height

            # Calculate target ratios (same logic as tiling)
            target_ratios = set(
                (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
                i * j <= max_num_tiles and i * j >= 1)
            target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

            # Find the closest aspect ratio to the target
            target_aspect_ratio = find_closest_aspect_ratio_fn(
                aspect_ratio, target_ratios, orig_width, orig_height, img_h)

            # Calculate the target width and height using tiling logic
            # (img_h == img_w is asserted above, so mixing them here is safe).
            target_width = img_h * target_aspect_ratio[0]
            target_height = img_w * target_aspect_ratio[1]

            # Resize image to target dimensions (same as tiling, but don't split)
            resized_img = img.resize((target_width, target_height))

            # Process as single dynamic resolution image
            pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ])
            processed_images = [resized_img]

            # Add thumbnail if use_thumbnail=True and there's more than 1 tile
            blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
            if use_thumbnail and blocks != 1:
                thumbnail_img = img.resize((img_h, img_h))
                processed_images.append(thumbnail_img)

            imgs = [transform(img) for img in processed_images]
        elif self._dynamic_resolution:
            pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ])
            processed_img = dynamic_res_preprocess(img, min_patches=self._min_num_patches, max_patches=self._max_num_patches, res_step=self._res_step, pixel_shuffle=self._pixel_shuffle, min_side=self._min_side, conv_merging=self._conv_merging, is_video=is_video)
            processed_images = [processed_img]

            # Add thumbnail if enabled and image area is below threshold
            if use_thumbnail:
                # Calculate areas
                processed_width, processed_height = processed_img.size
                resized_area = processed_width * processed_height
                thumbnail_area = img_h * img_h  # img_h should be square thumbnail size
                area_ratio = resized_area / thumbnail_area

                # Only add thumbnail if resized image area is less than threshold % of thumbnail area
                if area_ratio < self._thumbnail_area_threshold:
                    thumbnail_img = img.resize((img_h, img_h))  # Use square thumbnail with img_h size
                    processed_images.append(thumbnail_img)

            imgs = [transform(img) for img in processed_images]
        else:
            # Static resolution: a single resized + normalized tensor.
            imgs = [self._transform(img)]

        return imgs
def dynamic_preprocess(
        image, min_num=1, max_num=6, image_size=448, use_thumbnail=False,
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio):
    """Split an image into square tiles that best match its aspect ratio.

    Args:
        image (PIL.Image): input image.
        min_num (int): minimum total number of tiles.
        max_num (int): maximum total number of tiles.
        image_size (int): side length of each square tile in pixels.
        use_thumbnail (bool): if True and more than one tile was produced, append a
            square thumbnail of the whole image.
        find_closest_aspect_ratio_fn (callable): strategy used to pick the tile grid.

    Returns:
        list[PIL.Image]: the cropped tiles (plus optional thumbnail).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio_fn(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        # Row-major crop box for tile i: (left, upper, right, lower).
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def dynamic_res_preprocess(image, min_patches=1, max_patches=128, res_step=16, factor_max=1., pixel_shuffle=False, min_side=None, conv_merging=False, is_video=False):
    """Preprocess an image with dynamic resolution for vision transformers.

    This function resizes an image to optimize the number of patches while respecting
    constraints on minimum/maximum patches, minimum side length, and compatibility
    with pixel shuffle or convolution merging operations.

    The algorithm works by:
    1. Computing the initial patch grid size based on the image dimensions and res_step
    2. Scaling the patch grid to fit within the max_patches constraint
    3. Ensuring the result has at least min_patches
    4. Optionally enforcing a minimum side length constraint
    5. Rounding patch dimensions to even numbers for pixel_shuffle/conv_merging compatibility
    6. Resizing the image to the computed target dimensions

    Args:
        image (PIL.Image): Input image to preprocess.
        min_patches (int, optional): Minimum number of patches required. Defaults to 1.
        max_patches (int, optional): Maximum number of patches allowed. Defaults to 128.
        res_step (int, optional): Resolution step size (patch dimension). Defaults to 16.
        factor_max (float, optional): Maximum scaling factor to apply. Defaults to 1.0.
        pixel_shuffle (bool, optional): Whether to ensure compatibility with pixel shuffle
            operations by rounding to even patch dimensions. Defaults to False.
        min_side (int, optional): Minimum side length in pixels. If specified, ensures
            at least one side meets this constraint. Defaults to None.
        conv_merging (bool, optional): Whether to ensure compatibility with convolution
            merging by rounding to even patch dimensions. Defaults to False.
        is_video (bool, optional): If True, overrides everything with a fixed 32x32
            patch grid (temporary hack to match training). Defaults to False.

    Returns:
        PIL.Image: Resized image with dimensions optimized for patch-based processing.
        The output dimensions will be (target_patch_width * res_step, target_patch_height * res_step).

    Note:
        The function preserves aspect ratio as much as possible while satisfying all constraints.
        When constraints conflict (e.g., min_side vs max_patches), the function prioritizes
        staying within max_patches while maximizing the image size.

    Example:
        >>> from PIL import Image
        >>> img = Image.open("example.jpg")  # 800x600 image
        >>> resized_img = dynamic_res_preprocess(img, min_patches=4, max_patches=64, res_step=14)
        >>> # Returns image resized to maintain aspect ratio with 4-64 patches of size 14x14
    """
    orig_width, orig_height = image.size

    # round(x + 0.5) is a ceiling: the grid always covers the whole image.
    closest_patch_height = round(orig_height / res_step + 0.5)
    closest_patch_width = round(orig_width / res_step + 0.5)
    patches = closest_patch_height * closest_patch_width

    # Uniformly scale the grid so it fits the patch budget (never upscale past factor_max).
    factor = min(math.sqrt(max_patches / patches), factor_max)
    target_patch_height = math.floor(factor * closest_patch_height)
    target_patch_width = math.floor(factor * closest_patch_width)

    # Scale back up if flooring pushed us below the minimum patch count.
    if target_patch_height * target_patch_width < min_patches:
        up_factor = math.sqrt(min_patches / (target_patch_height * target_patch_width))
        target_patch_height = math.ceil(up_factor * target_patch_height)
        target_patch_width = math.ceil(up_factor * target_patch_width)

    # Enforce the minimum pixel side length, scaling up the shorter side first.
    if min_side is not None and min(target_patch_width, target_patch_height) * res_step < min_side:
        if target_patch_width <= target_patch_height:
            up_factor = min_side / (target_patch_width * res_step)
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)

            if new_patch_height * new_patch_width > max_patches:
                # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
                if max(max_patches // new_patch_width, 1) * res_step < min_side:
                    up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
                    target_patch_height = math.floor(up_factor * target_patch_height)
                    target_patch_width = math.floor(up_factor * target_patch_width)
                # NOTE(review): the two assignments below unconditionally overwrite the
                # sqrt-rescaled values computed just above (dead code); the height branch
                # below guards the equivalent assignments with an else. Confirm against
                # the upstream implementation whether this asymmetry is intentional.
                target_patch_width = new_patch_width
                target_patch_height = max(max_patches // new_patch_width, 1)
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width
        else:
            up_factor = min_side / (target_patch_height * res_step)
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)

            if new_patch_height * new_patch_width > max_patches:
                # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
                if max(max_patches // new_patch_height, 1) * res_step < min_side:
                    up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
                    target_patch_height = math.floor(up_factor * target_patch_height)
                    target_patch_width = math.floor(up_factor * target_patch_width)
                else:
                    target_patch_height = new_patch_height
                    target_patch_width = max(max_patches // new_patch_height, 1)
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width

    # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
    # or by 4 when BOTH are enabled (two successive 2x reductions)
    if pixel_shuffle or conv_merging:
        required_divisor = 4 if (pixel_shuffle and conv_merging) else 2

        # Round up when the budget allows, otherwise round down (never below 1 row).
        rem_h = target_patch_height % required_divisor
        if rem_h != 0:
            inc_h = required_divisor - rem_h
            if (target_patch_height + inc_h) * target_patch_width <= max_patches:
                target_patch_height += inc_h
            else:
                target_patch_height = max(1, target_patch_height - rem_h)

        rem_w = target_patch_width % required_divisor
        if rem_w != 0:
            inc_w = required_divisor - rem_w
            if target_patch_height * (target_patch_width + inc_w) <= max_patches:
                target_patch_width += inc_w
            else:
                target_patch_width = max(1, target_patch_width - rem_w)
    assert target_patch_height * target_patch_width <= max_patches

    # TEMP: hacky way to process video same as in training
    if is_video:
        # max_patches = 1024
        # min_patches = 512
        target_patch_width = 32
        target_patch_height = 32

    # resize the image
    resized_img = image.resize((target_patch_width * res_step, target_patch_height * res_step))

    return resized_img


# Based on
# https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
def _build_transform(input_size, vision_model_type):
    """Build the preprocessing pipeline for the given vision backbone.

    Args:
        input_size (int): side length of the square output image in pixels.
        vision_model_type (str): backbone identifier; one of the keys of
            ``pixel_statistics`` or an ``hf://`` huggingface model reference.

    Returns:
        callable: maps a PIL image to a normalized (C, H, W) tensor.

    Raises:
        NotImplementedError: for an unrecognized vision model type.
    """
    square = (input_size, input_size)

    if vision_model_type in ("siglip", "internvit", "internvit300M", "radio", "radio-g", "cradio-g"):
        mean, std = pixel_statistics[vision_model_type]
        # Convert to RGB first, then resize with bicubic interpolation.
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(square, interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    if vision_model_type == "clip":
        mean, std = pixel_statistics[vision_model_type]
        # CLIP resizes before the RGB conversion, matching the reference implementation.
        return Compose([
            T.Resize(square, interpolation=InterpolationMode.BICUBIC),
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    if vision_model_type.startswith("hf://"):
        from megatron.core.models.huggingface.module import get_hf_model_type

        if "siglip" not in get_hf_model_type(vision_model_type):
            raise NotImplementedError(f"image processing not defined for huggingface model {vision_model_type}")

        from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor

        processor = SiglipImageProcessor(size={"height": input_size, "width": input_size})

        def transform(x):
            x = x.convert("RGB") if x.mode != "RGB" else x
            return processor(x, return_tensors="pt")["pixel_values"][0]

        return transform

    raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}")