Import image processing into model_executor/models/nano_nemotron_vl.py

This commit is contained in:
Netanel Haber 2025-12-11 17:03:35 +02:00 committed by Netanel Haber
parent 50ffea9826
commit 1bceb28678

View File

@@ -7,10 +7,14 @@
# LICENSE is in root directory.
# --------------------------------------------------------
import math
import random
from dataclasses import dataclass
import copy
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence, Callable
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy.typing as npt
@@ -20,6 +24,7 @@ import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
import einops
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@@ -84,6 +89,488 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]
pixel_statistics = {
"clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
"internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
"radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
"huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
"radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
}
@dataclass
class DynamicResolutionParams:
image: Image.Image
num_tiles: int
num_embeddings: int
patch_size: tuple[int, int]
class DynamicResolutionImageTilingStrategy:
def __init__(
self,
vision_model_type: str,
min_num_patches: int,
patch_size: int,
get_num_embeddings: Callable[[int, int], int],
factor_max: float = 1.0,
pixel_shuffle: bool = False,
min_side: int | None = None,
conv_merging: bool = False,
use_thumbnail: bool = False,
thumbnail_size: int = 448,
thumbnail_area_threshold: float = 0.8,
max_num_patches: int = 0,
apply_data_augment: bool = False,
):
assert "radio" in vision_model_type, (
"Dynamic resolution is only supported for radio models"
)
self._vision_model_type = vision_model_type
self._min_num_patches = min_num_patches
self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
self._patch_size = patch_size
self._get_num_embeddings = get_num_embeddings
self._factor_max = factor_max
self._pixel_shuffle = pixel_shuffle
self._min_side = min_side
self._conv_merging = conv_merging
self._use_thumbnail = use_thumbnail
self._thumbnail_size = thumbnail_size
self._thumbnail_area_threshold = thumbnail_area_threshold
pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
self._transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)),
T.Normalize(mean=pixel_mean, std=pixel_std),
]
)
self._apply_data_augment = apply_data_augment
def apply_params(
self, params: DynamicResolutionParams, **kwargs
) -> list[torch.Tensor]:
# resize the image
resized_img = params.image.resize(
(
params.patch_size[0] * self._patch_size,
params.patch_size[1] * self._patch_size,
)
)
processed_images = [resized_img]
# Add thumbnail if enabled and image area is below threshold
if self._use_thumbnail:
# Calculate areas
resized_area = resized_img.size[0] * resized_img.size[1]
thumbnail_area = self._thumbnail_size * self._thumbnail_size
area_ratio = resized_area / thumbnail_area
# Only add thumbnail if resized image area is less than threshold % of thumbnail area
if area_ratio < self._thumbnail_area_threshold:
thumbnail_img = params.image.resize(
(self._thumbnail_size, self._thumbnail_size)
)
processed_images.append(thumbnail_img)
return [self._transform(img) for img in processed_images]
def process_media(
self,
image: Image.Image,
num_tokens_available: int,
data_augment: bool = False,
tiling_augment_prob: float = 0.4,
) -> DynamicResolutionParams:
"""Process a single media item and return its parameters.
Args:
media: The media item to process
num_tokens_available: Number of tokens available for this media
data_augment: Whether to apply data augmentation to the image. Defaults to False.
Returns:
DynamicResolutionParams for the media
"""
current_num_tokens_available = num_tokens_available
assert isinstance(image, Image.Image), (
"Dynamic resolution is only supported for image media"
)
orig_width, orig_height = image.width, image.height
closest_patch_height = round(orig_height / self._patch_size + 0.5)
closest_patch_width = round(orig_width / self._patch_size + 0.5)
patches = closest_patch_height * closest_patch_width
factor = min(
math.sqrt(current_num_tokens_available / patches), self._factor_max
)
target_patch_height = math.floor(factor * closest_patch_height)
target_patch_width = math.floor(factor * closest_patch_width)
# We only consider self._min_num_patches if it is greater than current_num_tokens_available.
if (
current_num_tokens_available > self._min_num_patches
and target_patch_height * target_patch_width < self._min_num_patches
):
up_factor = math.sqrt(
self._min_num_patches / (target_patch_height * target_patch_width)
)
target_patch_height = math.ceil(up_factor * target_patch_height)
target_patch_width = math.ceil(up_factor * target_patch_width)
if (
self._min_side is not None
and min(target_patch_width, target_patch_height) * self._patch_size
< self._min_side
):
if target_patch_width <= target_patch_height:
up_factor = self._min_side / (target_patch_width * self._patch_size)
new_patch_height = math.ceil(up_factor * target_patch_height)
new_patch_width = math.ceil(up_factor * target_patch_width)
if new_patch_height * new_patch_width > current_num_tokens_available:
# If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
if (
max(current_num_tokens_available // new_patch_width, 1)
* self._patch_size
< self._min_side
):
up_factor = math.sqrt(
current_num_tokens_available
/ (target_patch_height * target_patch_width)
)
target_patch_height = math.floor(
up_factor * target_patch_height
)
target_patch_width = math.floor(up_factor * target_patch_width)
target_patch_width = new_patch_width
target_patch_height = max(
current_num_tokens_available // new_patch_width, 1
)
else:
target_patch_height = new_patch_height
target_patch_width = new_patch_width
else:
up_factor = self._min_side / (target_patch_height * self._patch_size)
new_patch_height = math.ceil(up_factor * target_patch_height)
new_patch_width = math.ceil(up_factor * target_patch_width)
if new_patch_height * new_patch_width > current_num_tokens_available:
# If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
if (
max(current_num_tokens_available // new_patch_height, 1)
* self._patch_size
< self._min_side
):
up_factor = math.sqrt(
current_num_tokens_available
/ (target_patch_height * target_patch_width)
)
target_patch_height = math.floor(
up_factor * target_patch_height
)
target_patch_width = math.floor(up_factor * target_patch_width)
else:
target_patch_height = new_patch_height
target_patch_width = max(
current_num_tokens_available // new_patch_height, 1
)
else:
target_patch_height = new_patch_height
target_patch_width = new_patch_width
# Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
# or by 4 when BOTH are enabled (two successive 2x reductions)
if self._pixel_shuffle or self._conv_merging:
required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
rem_h = target_patch_height % required_divisor
if rem_h != 0:
inc_h = required_divisor - rem_h
if (
target_patch_height + inc_h
) * target_patch_width <= current_num_tokens_available:
target_patch_height += inc_h
else:
target_patch_height = max(
required_divisor, target_patch_height - rem_h
)
rem_w = target_patch_width % required_divisor
if rem_w != 0:
inc_w = required_divisor - rem_w
if (
target_patch_height * (target_patch_width + inc_w)
<= current_num_tokens_available
):
target_patch_width += inc_w
else:
target_patch_width = max(
required_divisor, target_patch_width - rem_w
)
if (
data_augment
and self._apply_data_augment
and random.random() < tiling_augment_prob
):
target_patch_width, target_patch_height = self.augment_resolution(
target_patch_width, target_patch_height, current_num_tokens_available
)
assert isinstance(image, Image.Image), (
"Dynamic resolution is only supported for image media"
)
# Calculate embeddings for the main dynamic resolution image
num_embeddings = self._get_num_embeddings(
target_patch_width * self._patch_size,
target_patch_height * self._patch_size,
)
token_count = target_patch_width * target_patch_height
# Add thumbnail embeddings if enabled and image area is below threshold
num_tiles = 1 # Base dynamic resolution image
if self._use_thumbnail:
# Calculate areas
resized_area = (target_patch_width * self._patch_size) * (
target_patch_height * self._patch_size
)
thumbnail_area = self._thumbnail_size * self._thumbnail_size
area_ratio = resized_area / thumbnail_area
# Only add thumbnail if resized image area is less than threshold % of thumbnail area
if area_ratio < self._thumbnail_area_threshold:
num_tiles += 1 # Add 1 for thumbnail
# Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
num_embeddings += self._get_num_embeddings(
self._thumbnail_size, self._thumbnail_size
)
token_count += (
self._thumbnail_size
// self._patch_size
* self._thumbnail_size
// self._patch_size
)
return DynamicResolutionParams(
image=image,
num_tiles=num_tiles,
num_embeddings=num_embeddings,
patch_size=(target_patch_width, target_patch_height),
), token_count
def augment_resolution(
self,
target_patch_width: int,
target_patch_height: int,
current_num_tokens_available: int,
) -> tuple[int, int]:
min_num_patch_one_side = 32
if random.random() < 0.5:
# Minus one
if (
target_patch_width <= min_num_patch_one_side
and target_patch_height <= min_num_patch_one_side
):
return target_patch_width, target_patch_height
elif target_patch_width <= min_num_patch_one_side:
return target_patch_width, target_patch_height - min_num_patch_one_side
elif target_patch_height <= min_num_patch_one_side:
return target_patch_width - min_num_patch_one_side, target_patch_height
else:
if random.random() < 0.5:
return (
target_patch_width - min_num_patch_one_side,
target_patch_height,
)
else:
return (
target_patch_width,
target_patch_height - min_num_patch_one_side,
)
else:
# Plus one
if target_patch_width * target_patch_height < current_num_tokens_available:
if random.random() < 0.5:
return (
target_patch_width + min_num_patch_one_side,
target_patch_height,
)
else:
return (
target_patch_width,
target_patch_height + min_num_patch_one_side,
)
return target_patch_width, target_patch_height
def compute_params(
self,
media_list: list[Image.Image],
num_tokens_available: int | None = None,
max_num_tiles: int | None = None,
data_augment: bool = False,
**kwargs,
) -> list[DynamicResolutionParams]:
"""Compute parameters for all media with iterative token budgeting.
Args:
media_list: List of media items to process
num_tokens_available: Total number of tokens available across all media
max_num_tiles: Maximum number of tiles (unused in this implementation)
data_augment: Whether to apply data augmentation to the image. Defaults to False.
Returns:
List of ImageTilingParams for each media item
"""
num_tokens_available = (
num_tokens_available
* (4 if self._pixel_shuffle else 1)
* (4 if self._conv_merging else 1)
)
# When the number of available token is too small, allow self._min_num_patches per media and
# let the sample be truncated.
num_tokens_available = max(
num_tokens_available, self._min_num_patches * len(media_list)
)
# Clip the number of tokens available per media to be between min and max patches.
num_tokens_available_per_media = [
max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
for _ in range(len(media_list))
]
# In theory this could be a while True loop, but in case the process_media method slightly
# changes, I want to make sure we don't get stuck in an infinite loop.
for _ in range(10):
# Step 1: Process each media with current token budget
params = []
token_counts = []
for media, tokens_for_media in zip(
media_list, num_tokens_available_per_media
):
param, token_count = self.process_media(
media, tokens_for_media, data_augment=data_augment
)
params.append(param)
token_counts.append(token_count)
# Step 2: Check if total tokens is within budget
total_tokens = sum(token_counts)
if total_tokens <= num_tokens_available:
# We're within budget, return the params
return params
# Step 3: We're over budget, need to scale down
# Calculate scaling factor to get under budget
scaling_factor = num_tokens_available / total_tokens
# Recalculate token budgets for each media based on scaling
# Each media gets a proportional share of the total budget
scaled_down_num_tokens_available_per_media = [
max(self._min_num_patches, int(token_count * scaling_factor))
for token_count in token_counts
]
scaled_down = any(
[
scaled_down_num_tokens_available_per_media[i]
< num_tokens_available_per_media[i]
for i in range(len(num_tokens_available_per_media))
]
)
# If there was not scaling down, we're stuck just use min_num_patches per media, else
# try with the scaled down num_tokens_available_per_media.
if not scaled_down:
num_tokens_available_per_media = [self._min_num_patches] * len(
media_list
)
else:
num_tokens_available_per_media = (
scaled_down_num_tokens_available_per_media
)
return params
def stack(
self, images: list[torch.Tensor]
) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
imgs_sizes = torch.tensor(
[[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
)
def rearrange_img(x):
py = x.shape[-2] // self._patch_size
px = x.shape[-1] // self._patch_size
x = einops.rearrange(
x,
"c (py yy) (px xx) -> (py px) (c yy xx)",
py=py,
yy=self._patch_size,
px=px,
xx=self._patch_size,
)
return x
if len(images) > 0:
imgs = [rearrange_img(img) for img in images]
current_length = 0
max_length = 0
vision_cu_lengths = [0]
for img in imgs:
if max_length < img.shape[0]:
max_length = img.shape[0]
current_length += img.shape[0]
vision_cu_lengths.append(current_length)
vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
return (
torch.cat(imgs, dim=0).unsqueeze(0),
imgs_sizes,
vision_cu_lengths,
vision_max_lengths,
)
else:
return (
torch.tensor([[0]], dtype=torch.float32),
torch.tensor([[0, 0]], dtype=torch.int32),
None,
None,
)
def __str__(self):
return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})"
image_tiling_strategy = DynamicResolutionImageTilingStrategy(
vision_model_type="radio",
min_num_patches=4,
patch_size=16,
get_num_embeddings=lambda x, y: x * y * 2,
max_num_patches=64,
)
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"