From 1bceb28678fcf6f7a5c39cd082222023b393e5e5 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Thu, 11 Dec 2025 17:03:35 +0200 Subject: [PATCH] import image processing into model_executor/models/nano_nemotron_vl.py --- .../model_executor/models/nano_nemotron_vl.py | 489 +++++++++++++++++- 1 file changed, 488 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 6dfab595e5b92..1fbbf06d0a695 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -7,10 +7,14 @@ # LICENSE is in root directory. # -------------------------------------------------------- +import math +import random +from dataclasses import dataclass + import copy import warnings from abc import ABC, abstractmethod -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence, Callable from typing import Annotated, Any, Literal, TypeAlias, TypeVar import numpy.typing as npt @@ -20,6 +24,7 @@ import torch.nn as nn import torchvision.transforms as T from PIL import Image from transformers import BatchFeature, PretrainedConfig, TensorType +import einops from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions @@ -84,6 +89,488 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely # Alternative: Set a specific higher limit # Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels +IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] +IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] +SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] +SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] +CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] +CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] +RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060] +RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250] + + +pixel_statistics = { + "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), + "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), + "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD), + "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), + "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), +} + + +@dataclass +class DynamicResolutionParams: + image: Image.Image + num_tiles: int + num_embeddings: int + patch_size: tuple[int, int] + + +class DynamicResolutionImageTilingStrategy: + def __init__( + self, + vision_model_type: str, + min_num_patches: int, + patch_size: int, + get_num_embeddings: Callable[[int, int], int], + factor_max: float = 1.0, + pixel_shuffle: bool = False, + min_side: int | None = None, + conv_merging: bool = False, + use_thumbnail: bool = False, + thumbnail_size: int = 448, + thumbnail_area_threshold: float = 0.8, + max_num_patches: int = 0, + apply_data_augment: bool = False, + ): + assert "radio" in vision_model_type, ( + "Dynamic resolution is only supported for radio models" + ) + self._vision_model_type = vision_model_type + self._min_num_patches = min_num_patches + self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf") + self._patch_size = patch_size + self._get_num_embeddings = get_num_embeddings + self._factor_max = factor_max + self._pixel_shuffle = pixel_shuffle + self._min_side = min_side + self._conv_merging = conv_merging + self._use_thumbnail = use_thumbnail + self._thumbnail_size = thumbnail_size + self._thumbnail_area_threshold = thumbnail_area_threshold + pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] + self._transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)), + T.Normalize(mean=pixel_mean, std=pixel_std), + ] + ) + self._apply_data_augment = apply_data_augment + + def apply_params( + self, params: DynamicResolutionParams, **kwargs + ) -> list[torch.Tensor]: + # resize the image + resized_img = params.image.resize( + ( + params.patch_size[0] * self._patch_size, + params.patch_size[1] * self._patch_size, + ) + ) + processed_images = [resized_img] + + # Add thumbnail if enabled and image area is below threshold + if self._use_thumbnail: + # Calculate areas + resized_area = resized_img.size[0] * resized_img.size[1] + thumbnail_area = self._thumbnail_size * self._thumbnail_size + area_ratio = resized_area / thumbnail_area + + # Only add thumbnail if resized image area is less than threshold % of thumbnail area + if area_ratio < self._thumbnail_area_threshold: + thumbnail_img = params.image.resize( + (self._thumbnail_size, self._thumbnail_size) + ) + processed_images.append(thumbnail_img) + + return [self._transform(img) for img in processed_images] + + def process_media( + self, + image: Image.Image, + num_tokens_available: int, + data_augment: bool = False, + tiling_augment_prob: float = 0.4, + ) -> DynamicResolutionParams: + """Process a single media item and return its parameters. + + Args: + media: The media item to process + num_tokens_available: Number of tokens available for this media + data_augment: Whether to apply data augmentation to the image. Defaults to False. + Returns: + DynamicResolutionParams for the media + """ + current_num_tokens_available = num_tokens_available + assert isinstance(image, Image.Image), ( + "Dynamic resolution is only supported for image media" + ) + orig_width, orig_height = image.width, image.height + + closest_patch_height = round(orig_height / self._patch_size + 0.5) + closest_patch_width = round(orig_width / self._patch_size + 0.5) + patches = closest_patch_height * closest_patch_width + + factor = min( + math.sqrt(current_num_tokens_available / patches), self._factor_max + ) + target_patch_height = math.floor(factor * closest_patch_height) + target_patch_width = math.floor(factor * closest_patch_width) + + # We only consider self._min_num_patches if it is greater than current_num_tokens_available. + if ( + current_num_tokens_available > self._min_num_patches + and target_patch_height * target_patch_width < self._min_num_patches + ): + up_factor = math.sqrt( + self._min_num_patches / (target_patch_height * target_patch_width) + ) + target_patch_height = math.ceil(up_factor * target_patch_height) + target_patch_width = math.ceil(up_factor * target_patch_width) + + if ( + self._min_side is not None + and min(target_patch_width, target_patch_height) * self._patch_size + < self._min_side + ): + if target_patch_width <= target_patch_height: + up_factor = self._min_side / (target_patch_width * self._patch_size) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) + + if new_patch_height * new_patch_width > current_num_tokens_available: + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + if ( + max(current_num_tokens_available // new_patch_width, 1) + * self._patch_size + < self._min_side + ): + up_factor = math.sqrt( + current_num_tokens_available + / (target_patch_height * target_patch_width) + ) + target_patch_height = math.floor( + up_factor * target_patch_height + ) + target_patch_width = math.floor(up_factor * target_patch_width) + target_patch_width = new_patch_width + target_patch_height = max( + current_num_tokens_available // new_patch_width, 1 + ) + else: + target_patch_height = new_patch_height + target_patch_width = new_patch_width + else: + up_factor = self._min_side / (target_patch_height * self._patch_size) + new_patch_height = math.ceil(up_factor * target_patch_height) + new_patch_width = math.ceil(up_factor * target_patch_width) + + if new_patch_height * new_patch_width > current_num_tokens_available: + # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches + if ( + max(current_num_tokens_available // new_patch_height, 1) + * self._patch_size + < self._min_side + ): + up_factor = math.sqrt( + current_num_tokens_available + / (target_patch_height * target_patch_width) + ) + target_patch_height = math.floor( + up_factor * target_patch_height + ) + target_patch_width = math.floor(up_factor * target_patch_width) + else: + target_patch_height = new_patch_height + target_patch_width = max( + current_num_tokens_available // new_patch_height, 1 + ) + else: + target_patch_height = new_patch_height + target_patch_width = new_patch_width + + # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) + # or by 4 when BOTH are enabled (two successive 2x reductions) + if self._pixel_shuffle or self._conv_merging: + required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2 + + rem_h = target_patch_height % required_divisor + if rem_h != 0: + inc_h = required_divisor - rem_h + if ( + target_patch_height + inc_h + ) * target_patch_width <= current_num_tokens_available: + target_patch_height += inc_h + else: + target_patch_height = max( + required_divisor, target_patch_height - rem_h + ) + + rem_w = target_patch_width % required_divisor + if rem_w != 0: + inc_w = required_divisor - rem_w + if ( + target_patch_height * (target_patch_width + inc_w) + <= current_num_tokens_available + ): + target_patch_width += inc_w + else: + target_patch_width = max( + required_divisor, target_patch_width - rem_w + ) + + if ( + data_augment + and self._apply_data_augment + and random.random() < tiling_augment_prob + ): + target_patch_width, target_patch_height = self.augment_resolution( + target_patch_width, target_patch_height, current_num_tokens_available + ) + + assert isinstance(image, Image.Image), ( + "Dynamic resolution is only supported for image media" + ) + + # Calculate embeddings for the main dynamic resolution image + num_embeddings = self._get_num_embeddings( + target_patch_width * self._patch_size, + target_patch_height * self._patch_size, + ) + + token_count = target_patch_width * target_patch_height + + # Add thumbnail embeddings if enabled and image area is below threshold + num_tiles = 1 # Base dynamic resolution image + if self._use_thumbnail: + # Calculate areas + resized_area = (target_patch_width * self._patch_size) * ( + target_patch_height * self._patch_size + ) + thumbnail_area = self._thumbnail_size * self._thumbnail_size + area_ratio = resized_area / thumbnail_area + + # Only add thumbnail if resized image area is less than threshold % of thumbnail area + if area_ratio < self._thumbnail_area_threshold: + num_tiles += 1 # Add 1 for thumbnail + # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) + num_embeddings += self._get_num_embeddings( + self._thumbnail_size, self._thumbnail_size + ) + token_count += ( + self._thumbnail_size + // self._patch_size + * self._thumbnail_size + // self._patch_size + ) + + return DynamicResolutionParams( + image=image, + num_tiles=num_tiles, + num_embeddings=num_embeddings, + patch_size=(target_patch_width, target_patch_height), + ), token_count + + def augment_resolution( + self, + target_patch_width: int, + target_patch_height: int, + current_num_tokens_available: int, + ) -> tuple[int, int]: + min_num_patch_one_side = 32 + + if random.random() < 0.5: + # Minus one + if ( + target_patch_width <= min_num_patch_one_side + and target_patch_height <= min_num_patch_one_side + ): + return target_patch_width, target_patch_height + elif target_patch_width <= min_num_patch_one_side: + return target_patch_width, target_patch_height - min_num_patch_one_side + elif target_patch_height <= min_num_patch_one_side: + return target_patch_width - min_num_patch_one_side, target_patch_height + else: + if random.random() < 0.5: + return ( + target_patch_width - min_num_patch_one_side, + target_patch_height, + ) + else: + return ( + target_patch_width, + target_patch_height - min_num_patch_one_side, + ) + else: + # Plus one + if target_patch_width * target_patch_height < current_num_tokens_available: + if random.random() < 0.5: + return ( + target_patch_width + min_num_patch_one_side, + target_patch_height, + ) + else: + return ( + target_patch_width, + target_patch_height + min_num_patch_one_side, + ) + return target_patch_width, target_patch_height + + def compute_params( + self, + media_list: list[Image.Image], + num_tokens_available: int | None = None, + max_num_tiles: int | None = None, + data_augment: bool = False, + **kwargs, + ) -> list[DynamicResolutionParams]: + """Compute parameters for all media with iterative token budgeting. + + Args: + media_list: List of media items to process + num_tokens_available: Total number of tokens available across all media + max_num_tiles: Maximum number of tiles (unused in this implementation) + data_augment: Whether to apply data augmentation to the image. Defaults to False. + Returns: + List of ImageTilingParams for each media item + """ + num_tokens_available = ( + num_tokens_available + * (4 if self._pixel_shuffle else 1) + * (4 if self._conv_merging else 1) + ) + # When the number of available token is too small, allow self._min_num_patches per media and + # let the sample be truncated. + num_tokens_available = max( + num_tokens_available, self._min_num_patches * len(media_list) + ) + + # Clip the number of tokens available per media to be between min and max patches. + num_tokens_available_per_media = [ + max(min(num_tokens_available, self._max_num_patches), self._min_num_patches) + for _ in range(len(media_list)) + ] + + # In theory this could be a while True loop, but in case the process_media method slightly + # changes, I want to make sure we don't get stuck in an infinite loop. + for _ in range(10): + # Step 1: Process each media with current token budget + params = [] + token_counts = [] + + for media, tokens_for_media in zip( + media_list, num_tokens_available_per_media + ): + param, token_count = self.process_media( + media, tokens_for_media, data_augment=data_augment + ) + params.append(param) + token_counts.append(token_count) + + # Step 2: Check if total tokens is within budget + total_tokens = sum(token_counts) + + if total_tokens <= num_tokens_available: + # We're within budget, return the params + return params + + # Step 3: We're over budget, need to scale down + # Calculate scaling factor to get under budget + scaling_factor = num_tokens_available / total_tokens + + # Recalculate token budgets for each media based on scaling + # Each media gets a proportional share of the total budget + scaled_down_num_tokens_available_per_media = [ + max(self._min_num_patches, int(token_count * scaling_factor)) + for token_count in token_counts + ] + scaled_down = any( + [ + scaled_down_num_tokens_available_per_media[i] + < num_tokens_available_per_media[i] + for i in range(len(num_tokens_available_per_media)) + ] + ) + # If there was not scaling down, we're stuck just use min_num_patches per media, else + # try with the scaled down num_tokens_available_per_media. + if not scaled_down: + num_tokens_available_per_media = [self._min_num_patches] * len( + media_list + ) + else: + num_tokens_available_per_media = ( + scaled_down_num_tokens_available_per_media + ) + return params + + def stack( + self, images: list[torch.Tensor] + ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: + imgs_sizes = torch.tensor( + [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 + ) + + def rearrange_img(x): + py = x.shape[-2] // self._patch_size + px = x.shape[-1] // self._patch_size + x = einops.rearrange( + x, + "c (py yy) (px xx) -> (py px) (c yy xx)", + py=py, + yy=self._patch_size, + px=px, + xx=self._patch_size, + ) + return x + + if len(images) > 0: + imgs = [rearrange_img(img) for img in images] + + current_length = 0 + max_length = 0 + vision_cu_lengths = [0] + for img in imgs: + if max_length < img.shape[0]: + max_length = img.shape[0] + current_length += img.shape[0] + vision_cu_lengths.append(current_length) + + vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) + vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) + + return ( + torch.cat(imgs, dim=0).unsqueeze(0), + imgs_sizes, + vision_cu_lengths, + vision_max_lengths, + ) + else: + return ( + torch.tensor([[0]], dtype=torch.float32), + torch.tensor([[0, 0]], dtype=torch.int32), + None, + None, + ) + + def __str__(self): + return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})" + + + +image_tiling_strategy = DynamicResolutionImageTilingStrategy( + vision_model_type="radio", + min_num_patches=4, + patch_size=16, + get_num_embeddings=lambda x, y: x * y * 2, + max_num_patches=64, +) + + IMG_START = "" IMG_END = "" IMG_CONTEXT = ""