diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 3b8a3841cf938..3bac36744ef5e 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -12,7 +12,7 @@ import math
import random
import warnings
from abc import ABC, abstractmethod
-from collections.abc import Callable, Iterable, Mapping, Sequence
+from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
@@ -88,501 +88,9 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
-IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
-IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
-SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
-SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
-CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
-CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
-RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
-RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]
-
-pixel_statistics = {
- "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
- "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
- "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
- "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
- "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
- "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
-}
-
-
-@dataclass
-class DynamicResolutionParams:
- media: Image.Image
- num_tiles: int
- num_embeddings: int
- patch_size: tuple[int, int]
-
-
-class DynamicResolutionImageTilingStrategy:
- def __init__(
- self,
- vision_model_type: str,
- min_num_patches: int,
- patch_size: int,
- get_num_embeddings: Callable[[int, int], int],
- factor_max: float = 1.0,
- pixel_shuffle: bool = False,
- min_side: int | None = None,
- conv_merging: bool = False,
- use_thumbnail: bool = False,
- thumbnail_size: int = 448,
- thumbnail_area_threshold: float = 0.8,
- max_num_patches: int = 0,
- apply_data_augment: bool = False,
- ):
- assert "radio" in vision_model_type, (
- "Dynamic resolution is only supported for radio models"
- )
- self._vision_model_type = vision_model_type
- self._min_num_patches = min_num_patches
- self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
- self._patch_size = patch_size
- self._get_num_embeddings = get_num_embeddings
- self._factor_max = factor_max
- self._pixel_shuffle = pixel_shuffle
- self._min_side = min_side
- self._conv_merging = conv_merging
- self._use_thumbnail = use_thumbnail
- self._thumbnail_size = thumbnail_size
- self._thumbnail_area_threshold = thumbnail_area_threshold
- pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
- self._transform = T.Compose(
- [
- T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
- T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)),
- T.Normalize(mean=pixel_mean, std=pixel_std),
- ]
- )
- self._apply_data_augment = apply_data_augment
-
- def apply_params(
- self, params: DynamicResolutionParams, **kwargs
- ) -> list[torch.Tensor]:
- # resize the image
- resized_img = params.media.resize(
- (
- params.patch_size[0] * self._patch_size,
- params.patch_size[1] * self._patch_size,
- )
- )
- processed_images = [resized_img]
-
- # Add thumbnail if enabled and image area is below threshold
- if self._use_thumbnail:
- # Calculate areas
- resized_area = resized_img.size[0] * resized_img.size[1]
- thumbnail_area = self._thumbnail_size * self._thumbnail_size
- area_ratio = resized_area / thumbnail_area
-
- # Only add thumbnail if resized image area is less than threshold % of
- # thumbnail area
- if area_ratio < self._thumbnail_area_threshold:
- thumbnail_img = params.media.resize(
- (self._thumbnail_size, self._thumbnail_size)
- )
- processed_images.append(thumbnail_img)
-
- return [self._transform(img) for img in processed_images]
-
- def process_media(
- self,
- media: Image.Image,
- num_tokens_available: int,
- data_augment: bool = False,
- tiling_augment_prob: float = 0.4,
- ) -> DynamicResolutionParams:
- """Process a single media item and return its parameters.
- Args:
- media: The media item to process
- num_tokens_available: Number of tokens available for this media
- data_augment: Whether to apply data augmentation to the image. Defaults to
- False.
- Returns:
- DynamicResolutionParams for the media
- """
- current_num_tokens_available = num_tokens_available
- assert isinstance(media, Image.Image), (
- "Dynamic resolution is only supported for image media"
- )
- orig_width, orig_height = media.width, media.height
-
- closest_patch_height = round(orig_height / self._patch_size + 0.5)
- closest_patch_width = round(orig_width / self._patch_size + 0.5)
- patches = closest_patch_height * closest_patch_width
-
- factor = min(
- math.sqrt(current_num_tokens_available / patches), self._factor_max
- )
- target_patch_height = math.floor(factor * closest_patch_height)
- target_patch_width = math.floor(factor * closest_patch_width)
-
- # We only consider self._min_num_patches if it is greater than
- # current_num_tokens_available.
- if (
- current_num_tokens_available > self._min_num_patches
- and target_patch_height * target_patch_width < self._min_num_patches
- ):
- up_factor = math.sqrt(
- self._min_num_patches / (target_patch_height * target_patch_width)
- )
- target_patch_height = math.ceil(up_factor * target_patch_height)
- target_patch_width = math.ceil(up_factor * target_patch_width)
-
- if (
- self._min_side is not None
- and min(target_patch_width, target_patch_height) * self._patch_size
- < self._min_side
- ):
- if target_patch_width <= target_patch_height:
- up_factor = self._min_side / (target_patch_width * self._patch_size)
- new_patch_height = math.ceil(up_factor * target_patch_height)
- new_patch_width = math.ceil(up_factor * target_patch_width)
-
- if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at
- # native aspect ratio while staying below max_patches
- if (
- max(current_num_tokens_available // new_patch_width, 1)
- * self._patch_size
- < self._min_side
- ):
- up_factor = math.sqrt(
- current_num_tokens_available
- / (target_patch_height * target_patch_width)
- )
- target_patch_height = math.floor(
- up_factor * target_patch_height
- )
- target_patch_width = math.floor(up_factor * target_patch_width)
- target_patch_width = new_patch_width
- target_patch_height = max(
- current_num_tokens_available // new_patch_width, 1
- )
- else:
- target_patch_height = new_patch_height
- target_patch_width = new_patch_width
- else:
- up_factor = self._min_side / (target_patch_height * self._patch_size)
- new_patch_height = math.ceil(up_factor * target_patch_height)
- new_patch_width = math.ceil(up_factor * target_patch_width)
-
- if new_patch_height * new_patch_width > current_num_tokens_available:
- # If only one side can be min_side, make as big as possible at
- # native aspect ratio while staying below max_patches
- if (
- max(current_num_tokens_available // new_patch_height, 1)
- * self._patch_size
- < self._min_side
- ):
- up_factor = math.sqrt(
- current_num_tokens_available
- / (target_patch_height * target_patch_width)
- )
- target_patch_height = math.floor(
- up_factor * target_patch_height
- )
- target_patch_width = math.floor(up_factor * target_patch_width)
- else:
- target_patch_height = new_patch_height
- target_patch_width = max(
- current_num_tokens_available // new_patch_height, 1
- )
- else:
- target_patch_height = new_patch_height
- target_patch_width = new_patch_width
-
- # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
- # or by 4 when BOTH are enabled (two successive 2x reductions)
- if self._pixel_shuffle or self._conv_merging:
- required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
-
- rem_h = target_patch_height % required_divisor
- if rem_h != 0:
- inc_h = required_divisor - rem_h
- if (
- target_patch_height + inc_h
- ) * target_patch_width <= current_num_tokens_available:
- target_patch_height += inc_h
- else:
- target_patch_height = max(
- required_divisor, target_patch_height - rem_h
- )
-
- rem_w = target_patch_width % required_divisor
- if rem_w != 0:
- inc_w = required_divisor - rem_w
- if (
- target_patch_height * (target_patch_width + inc_w)
- <= current_num_tokens_available
- ):
- target_patch_width += inc_w
- else:
- target_patch_width = max(
- required_divisor, target_patch_width - rem_w
- )
-
- if (
- data_augment
- and self._apply_data_augment
- and random.random() < tiling_augment_prob
- ):
- target_patch_width, target_patch_height = self.augment_resolution(
- target_patch_width, target_patch_height, current_num_tokens_available
- )
-
- assert isinstance(media, Image.Image), (
- "Dynamic resolution is only supported for image media"
- )
-
- # Calculate embeddings for the main dynamic resolution image
- num_embeddings = self._get_num_embeddings(
- target_patch_width * self._patch_size,
- target_patch_height * self._patch_size,
- )
-
- token_count = target_patch_width * target_patch_height
-
- # Add thumbnail embeddings if enabled and image area is below threshold
- num_tiles = 1 # Base dynamic resolution image
- if self._use_thumbnail:
- # Calculate areas
- resized_area = (target_patch_width * self._patch_size) * (
- target_patch_height * self._patch_size
- )
- thumbnail_area = self._thumbnail_size * self._thumbnail_size
- area_ratio = resized_area / thumbnail_area
-
- # Only add thumbnail if resized image area is less than threshold % of
- # thumbnail area
- if area_ratio < self._thumbnail_area_threshold:
- num_tiles += 1 # Add 1 for thumbnail
- # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
- num_embeddings += self._get_num_embeddings(
- self._thumbnail_size, self._thumbnail_size
- )
- token_count += (
- self._thumbnail_size
- // self._patch_size
- * self._thumbnail_size
- // self._patch_size
- )
-
- return DynamicResolutionParams(
- media=media,
- num_tiles=num_tiles,
- num_embeddings=num_embeddings,
- patch_size=(target_patch_width, target_patch_height),
- ), token_count
-
- def augment_resolution(
- self,
- target_patch_width: int,
- target_patch_height: int,
- current_num_tokens_available: int,
- ) -> tuple[int, int]:
- min_num_patch_one_side = 32
-
- if random.random() < 0.5:
- # Minus one
- if (
- target_patch_width <= min_num_patch_one_side
- and target_patch_height <= min_num_patch_one_side
- ):
- return target_patch_width, target_patch_height
- elif target_patch_width <= min_num_patch_one_side:
- return target_patch_width, target_patch_height - min_num_patch_one_side
- elif target_patch_height <= min_num_patch_one_side:
- return target_patch_width - min_num_patch_one_side, target_patch_height
- else:
- if random.random() < 0.5:
- return (
- target_patch_width - min_num_patch_one_side,
- target_patch_height,
- )
- else:
- return (
- target_patch_width,
- target_patch_height - min_num_patch_one_side,
- )
- else:
- # Plus one
- if target_patch_width * target_patch_height < current_num_tokens_available:
- if random.random() < 0.5:
- return (
- target_patch_width + min_num_patch_one_side,
- target_patch_height,
- )
- else:
- return (
- target_patch_width,
- target_patch_height + min_num_patch_one_side,
- )
- return target_patch_width, target_patch_height
-
- def compute_params(
- self,
- media_list: list[Image.Image],
- num_tokens_available: int | None = None,
- max_num_tiles: int | None = None,
- data_augment: bool = False,
- **kwargs,
- ) -> list[DynamicResolutionParams]:
- """Compute parameters for all media with iterative token budgeting.
-
- Args:
- media_list: List of media items to process
- num_tokens_available: Total number of tokens available across all media
- max_num_tiles: Maximum number of tiles (unused in this implementation)
- data_augment: Whether to apply data augmentation to the image. Defaults to
- False.
- Returns:
- List of ImageTilingParams for each media item
- """
- num_tokens_available = (
- num_tokens_available
- * (4 if self._pixel_shuffle else 1)
- * (4 if self._conv_merging else 1)
- )
- # When the number of available token is too small, allow self._min_num_patches
- # per media and let the sample be truncated.
- num_tokens_available = max(
- num_tokens_available, self._min_num_patches * len(media_list)
- )
-
- # Clip the number of tokens available per media to be between min and max
- # patches.
- num_tokens_available_per_media = [
- max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
- for _ in range(len(media_list))
- ]
-
- # In theory this could be a while True loop, but in case the process_media
- # method slightly
- # changes, I want to make sure we don't get stuck in an infinite loop.
- for _ in range(10):
- # Step 1: Process each media with current token budget
- params = []
- token_counts = []
-
- for media, tokens_for_media in zip(
- media_list, num_tokens_available_per_media
- ):
- param, token_count = self.process_media(
- media, tokens_for_media, data_augment=data_augment
- )
- params.append(param)
- token_counts.append(token_count)
-
- # Step 2: Check if total tokens is within budget
- total_tokens = sum(token_counts)
-
- if total_tokens <= num_tokens_available:
- # We're within budget, return the params
- return params
-
- # Step 3: We're over budget, need to scale down
- # Calculate scaling factor to get under budget
- scaling_factor = num_tokens_available / total_tokens
-
- # Recalculate token budgets for each media based on scaling
- # Each media gets a proportional share of the total budget
- scaled_down_num_tokens_available_per_media = [
- max(self._min_num_patches, int(token_count * scaling_factor))
- for token_count in token_counts
- ]
- scaled_down = any(
- [
- scaled_down_num_tokens_available_per_media[i]
- < num_tokens_available_per_media[i]
- for i in range(len(num_tokens_available_per_media))
- ]
- )
- # If there was not scaling down, we're stuck just use min_num_patches per
- # media, else try with the scaled down num_tokens_available_per_media.
- if not scaled_down:
- num_tokens_available_per_media = [self._min_num_patches] * len(
- media_list
- )
- else:
- num_tokens_available_per_media = (
- scaled_down_num_tokens_available_per_media
- )
- return params
-
- def stack(
- self, images: list[torch.Tensor]
- ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
- imgs_sizes = torch.tensor(
- [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
- )
-
- def rearrange_img(x):
- py = x.shape[-2] // self._patch_size
- px = x.shape[-1] // self._patch_size
- x = einops.rearrange(
- x,
- "c (py yy) (px xx) -> (py px) (c yy xx)",
- py=py,
- yy=self._patch_size,
- px=px,
- xx=self._patch_size,
- )
- return x
-
- if len(images) > 0:
- imgs = [rearrange_img(img) for img in images]
-
- current_length = 0
- max_length = 0
- vision_cu_lengths = [0]
- for img in imgs:
- if max_length < img.shape[0]:
- max_length = img.shape[0]
- current_length += img.shape[0]
- vision_cu_lengths.append(current_length)
-
- vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
- vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
-
- return (
- torch.cat(imgs, dim=0).unsqueeze(0),
- imgs_sizes,
- vision_cu_lengths,
- vision_max_lengths,
- )
- else:
- return (
- torch.tensor([[0]], dtype=torch.float32),
- torch.tensor([[0, 0]], dtype=torch.int32),
- None,
- None,
- )
-
- def __str__(self):
- return f"DynamicResolutionImageTransform(\
- vision_model_type={self._vision_model_type}, \
- min_num_patches={self._min_num_patches}, \
- patch_size={self._patch_size}, \
- pixel_shuffle={self._pixel_shuffle}, \
- conv_merging={self._conv_merging}, \
- use_thumbnail={self._use_thumbnail}, \
- thumbnail_size={self._thumbnail_size}, \
- thumbnail_area_threshold={self._thumbnail_area_threshold})"
-
-
-image_tiling_strategy = DynamicResolutionImageTilingStrategy(
- vision_model_type="radio",
- min_num_patches=4,
- patch_size=16,
- get_num_embeddings=lambda x, y: x * y * 2,
- max_num_patches=64,
-)
+# TODO(nhaber): get 2048 from config
+# TODO(nhaber): does use_thumbnail=True work?
IMG_START = "
"
@@ -753,7 +261,12 @@ def video_to_pixel_values(
return torch.stack(frames_tensors)
-def input_conditioner(x, norm_mean, norm_std):
+def input_conditioner(
+ x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.Tensor
+) -> torch.Tensor:
+ assert isinstance(x, torch.Tensor), "x must be a tensor"
+ assert isinstance(norm_mean, torch.Tensor), "norm_mean must be a tensor"
+ assert isinstance(norm_std, torch.Tensor), "norm_std must be a tensor"
return (x - norm_mean) / norm_std
@@ -792,15 +305,20 @@ class BaseNanoNemotronVLProcessor(ABC):
self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES
image_size: int = config.force_image_size
- patch_size: int = config.patch_size
+ self.patch_size: int = getattr(config, "patch_size", 16)
+ self.downsample_ratio: float = self.config.downsample_ratio
- self.num_image_token = int(
- (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
- )
self.image_size = image_size
self.use_thumbnail: bool = config.use_thumbnail
- self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
- self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1)
+ self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1)
+ self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1)
+
+ def num_image_token(self, *, image_width: int, image_height: int) -> int:
+ image_size = math.sqrt(image_width * image_height)
+ num_tokens = int(
+ (image_size // self.patch_size) ** 2 * (self.downsample_ratio**2)
+ )
+ return num_tokens
@property
@abstractmethod
@@ -832,10 +350,13 @@ class BaseNanoNemotronVLProcessor(ABC):
use_thumbnail=self.use_thumbnail,
)
- return num_patches * self.num_image_token
+ return num_patches * self.num_image_token(
+ image_width=image_width, image_height=image_height
+ )
def _images_to_pixel_values_lst(
self,
+ text: list[str],
images: list[Image.Image],
max_num_tiles: int,
) -> list[torch.Tensor]:
@@ -859,7 +380,9 @@ class BaseNanoNemotronVLProcessor(ABC):
if len(images) == 0:
image_inputs = {}
else:
- pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
+ pixel_values_lst = self._images_to_pixel_values_lst(
+ text=text, images=images, max_num_tiles=max_num_tiles
+ )
image_inputs = {
"pixel_values_flat": input_conditioner(
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
@@ -881,7 +404,10 @@ class BaseNanoNemotronVLProcessor(ABC):
for i, pixel_values in enumerate(pixel_values_lst):
num_patches = pixel_values.shape[0]
- feature_size = num_patches * self.num_image_token
+ feature_size = num_patches * self.num_image_token(
+ image_width=pixel_values.shape[1],
+ image_height=pixel_values.shape[2],
+ )
image_repl = self.get_image_repl(feature_size, num_patches)
parts[i] = parts[i].replace("", image_repl.full)
text = ["".join(parts)]
@@ -894,6 +420,7 @@ class BaseNanoNemotronVLProcessor(ABC):
input_item = [input_item]
return input_item
+ @abstractmethod
def __call__(
self,
text: str | list[str] | None = None,
@@ -901,26 +428,487 @@ class BaseNanoNemotronVLProcessor(ABC):
return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None,
) -> BatchFeature:
- # Use default if not provided
- if max_num_tiles is None:
- max_num_tiles = self.max_num_tiles
+ raise NotImplementedError
- text, images = [self._make_batch_input(x) for x in (text, images)]
- text, image_inputs = self._preprocess_image(
- text=text,
- images=images,
- max_num_tiles=max_num_tiles,
+@dataclass
+class DynamicResolutionParams:
+ media: Image.Image
+ num_tiles: int
+ num_embeddings: int
+ patch_size: tuple[int, int]
+
+
+class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
+ CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
+ CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ tokenizer: TokenizerLike,
+ *args,
+ max_num_tiles: int | None = None,
+ min_num_patches: int = 4,
+ factor_max: float = 1.0,
+ pixel_shuffle: bool = True,
+ min_side: int | None = None,
+ conv_merging: bool = False,
+ use_thumbnail: bool = False,
+ thumbnail_size: int = 448,
+ thumbnail_area_threshold: float = 0.8,
+ apply_data_augment: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs
+ )
+ self._min_num_patches = min_num_patches
+ self._factor_max = factor_max
+ self._pixel_shuffle = pixel_shuffle
+ self._min_side = min_side
+ self._conv_merging = conv_merging
+ self._use_thumbnail = use_thumbnail
+ self._thumbnail_size = thumbnail_size
+ self._thumbnail_area_threshold = thumbnail_area_threshold
+ self._transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)),
+ ]
+ )
+ self._apply_data_augment = apply_data_augment
+
+ self.norm_mean = torch.tensor(self.CLIP_PIXEL_MEAN).reshape(1, 3, 1, 1)
+ self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1)
+        self.downsample_ratio = 0.5 if pixel_shuffle else 1  # 2x2 pixel shuffle quarters the token count
+
+ def apply_params(self, params: DynamicResolutionParams) -> torch.Tensor:
+ resized_img = params.media.resize(
+ (
+ params.patch_size[0] * self.patch_size,
+ params.patch_size[1] * self.patch_size,
+ )
+ )
+ # processed_images = [resized_img]
+
+ # # Add thumbnail if enabled and image area is below threshold
+ # if self._use_thumbnail:
+ # # Calculate areas
+ # resized_area = resized_img.size[0] * resized_img.size[1]
+ # thumbnail_area = self._thumbnail_size * self._thumbnail_size
+ # area_ratio = resized_area / thumbnail_area
+
+ # # Only add thumbnail if resized image area is less than threshold % of
+ # # thumbnail area
+ # if area_ratio < self._thumbnail_area_threshold:
+ # thumbnail_img = params.media.resize(
+ # (self._thumbnail_size, self._thumbnail_size)
+ # )
+ # processed_images.append(thumbnail_img)
+
+ return self._transform(resized_img)
+
+ def process_media(
+ self,
+ media: Image.Image,
+ num_tokens_available: int,
+ data_augment: bool = False,
+ tiling_augment_prob: float = 0.4,
+ ) -> DynamicResolutionParams:
+ """Process a single media item and return its parameters.
+ Args:
+ media: The media item to process
+ num_tokens_available: Number of tokens available for this media
+ data_augment: Whether to apply data augmentation to the image. Defaults to
+ False.
+ Returns:
+ DynamicResolutionParams for the media
+ """
+ current_num_tokens_available = num_tokens_available
+ assert isinstance(media, Image.Image), (
+ "Dynamic resolution is only supported for image media"
+ )
+ orig_width, orig_height = media.width, media.height
+
+ closest_patch_height = round(orig_height / self.patch_size + 0.5)
+ closest_patch_width = round(orig_width / self.patch_size + 0.5)
+ patches = closest_patch_height * closest_patch_width
+
+ factor = min(
+ math.sqrt(current_num_tokens_available / patches), self._factor_max
+ )
+ target_patch_height = math.floor(factor * closest_patch_height)
+ target_patch_width = math.floor(factor * closest_patch_width)
+
+ # We only consider self._min_num_patches if it is greater than
+ # current_num_tokens_available.
+ if (
+ current_num_tokens_available > self._min_num_patches
+ and target_patch_height * target_patch_width < self._min_num_patches
+ ):
+ up_factor = math.sqrt(
+ self._min_num_patches / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.ceil(up_factor * target_patch_height)
+ target_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if (
+ self._min_side is not None
+ and min(target_patch_width, target_patch_height) * self.patch_size
+ < self._min_side
+ ):
+ if target_patch_width <= target_patch_height:
+ up_factor = self._min_side / (target_patch_width * self.patch_size)
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if new_patch_height * new_patch_width > current_num_tokens_available:
+ # If only one side can be min_side, make as big as possible at
+ # native aspect ratio while staying below max_patches
+ if (
+ max(current_num_tokens_available // new_patch_width, 1)
+ * self.patch_size
+ < self._min_side
+ ):
+ up_factor = math.sqrt(
+ current_num_tokens_available
+ / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.floor(
+ up_factor * target_patch_height
+ )
+ target_patch_width = math.floor(up_factor * target_patch_width)
+ target_patch_width = new_patch_width
+ target_patch_height = max(
+ current_num_tokens_available // new_patch_width, 1
+ )
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = new_patch_width
+ else:
+ up_factor = self._min_side / (target_patch_height * self.patch_size)
+ new_patch_height = math.ceil(up_factor * target_patch_height)
+ new_patch_width = math.ceil(up_factor * target_patch_width)
+
+ if new_patch_height * new_patch_width > current_num_tokens_available:
+ # If only one side can be min_side, make as big as possible at
+ # native aspect ratio while staying below max_patches
+ if (
+ max(current_num_tokens_available // new_patch_height, 1)
+ * self.patch_size
+ < self._min_side
+ ):
+ up_factor = math.sqrt(
+ current_num_tokens_available
+ / (target_patch_height * target_patch_width)
+ )
+ target_patch_height = math.floor(
+ up_factor * target_patch_height
+ )
+ target_patch_width = math.floor(up_factor * target_patch_width)
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = max(
+ current_num_tokens_available // new_patch_height, 1
+ )
+ else:
+ target_patch_height = new_patch_height
+ target_patch_width = new_patch_width
+
+ # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
+ # or by 4 when BOTH are enabled (two successive 2x reductions)
+ if self._pixel_shuffle or self._conv_merging:
+ required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
+
+ rem_h = target_patch_height % required_divisor
+ if rem_h != 0:
+ inc_h = required_divisor - rem_h
+ if (
+ target_patch_height + inc_h
+ ) * target_patch_width <= current_num_tokens_available:
+ target_patch_height += inc_h
+ else:
+ target_patch_height = max(
+ required_divisor, target_patch_height - rem_h
+ )
+
+ rem_w = target_patch_width % required_divisor
+ if rem_w != 0:
+ inc_w = required_divisor - rem_w
+ if (
+ target_patch_height * (target_patch_width + inc_w)
+ <= current_num_tokens_available
+ ):
+ target_patch_width += inc_w
+ else:
+ target_patch_width = max(
+ required_divisor, target_patch_width - rem_w
+ )
+
+ if (
+ data_augment
+ and self._apply_data_augment
+ and random.random() < tiling_augment_prob
+ ):
+ target_patch_width, target_patch_height = self.augment_resolution(
+ target_patch_width, target_patch_height, current_num_tokens_available
+ )
+
+ assert isinstance(media, Image.Image), (
+ "Dynamic resolution is only supported for image media"
)
- text_inputs = self.tokenizer(text, add_special_tokens=False)
+ # Calculate embeddings for the main dynamic resolution image
+ num_embeddings = self.num_image_token(
+            image_width=target_patch_width * self.patch_size, image_height=target_patch_height * self.patch_size
+ )
- combined_outputs = {**text_inputs, **image_inputs}
+ token_count = target_patch_width * target_patch_height
- return BatchFeature(combined_outputs, tensor_type=return_tensors)
+ # Add thumbnail embeddings if enabled and image area is below threshold
+ num_tiles = 1 # Base dynamic resolution image
+ if self._use_thumbnail:
+ # Calculate areas
+ resized_area = (target_patch_width * self.patch_size) * (
+ target_patch_height * self.patch_size
+ )
+ thumbnail_area = self._thumbnail_size * self._thumbnail_size
+ area_ratio = resized_area / thumbnail_area
+
+ # Only add thumbnail if resized image area is less than threshold % of
+ # thumbnail area
+ if area_ratio < self._thumbnail_area_threshold:
+ num_tiles += 1 # Add 1 for thumbnail
+ # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
+ num_embeddings += self.num_image_token(
+ image_width=self._thumbnail_size, image_height=self._thumbnail_size
+ )
+ token_count += (
+ self._thumbnail_size
+ // self.patch_size
+ * self._thumbnail_size
+ // self.patch_size
+ )
+
+ return DynamicResolutionParams(
+ media=media,
+ num_tiles=num_tiles,
+ num_embeddings=num_embeddings,
+ patch_size=(target_patch_width, target_patch_height),
+ ), token_count
+
+ def augment_resolution(
+ self,
+ target_patch_width: int,
+ target_patch_height: int,
+ current_num_tokens_available: int,
+ ) -> tuple[int, int]:
+ min_num_patch_one_side = 32
+
+ if random.random() < 0.5:
+ # Minus one
+ if (
+ target_patch_width <= min_num_patch_one_side
+ and target_patch_height <= min_num_patch_one_side
+ ):
+ return target_patch_width, target_patch_height
+ elif target_patch_width <= min_num_patch_one_side:
+ return target_patch_width, target_patch_height - min_num_patch_one_side
+ elif target_patch_height <= min_num_patch_one_side:
+ return target_patch_width - min_num_patch_one_side, target_patch_height
+ else:
+ if random.random() < 0.5:
+ return (
+ target_patch_width - min_num_patch_one_side,
+ target_patch_height,
+ )
+ else:
+ return (
+ target_patch_width,
+ target_patch_height - min_num_patch_one_side,
+ )
+ else:
+ # Plus one
+ if target_patch_width * target_patch_height < current_num_tokens_available:
+ if random.random() < 0.5:
+ return (
+ target_patch_width + min_num_patch_one_side,
+ target_patch_height,
+ )
+ else:
+ return (
+ target_patch_width,
+ target_patch_height + min_num_patch_one_side,
+ )
+ return target_patch_width, target_patch_height
+
+ def compute_params(
+ self,
+ media_list: list[Image.Image],
+ num_tokens_available: int | None = None,
+ data_augment: bool = False,
+ ) -> list[DynamicResolutionParams]:
+ """Compute parameters for all media with iterative token budgeting.
+
+ Args:
+ media_list: List of media items to process
+ num_tokens_available: Total number of tokens available across all media
+ data_augment: Whether to apply data augmentation to the image. Defaults to
+ False.
+ Returns:
+ List of ImageTilingParams for each media item
+ """
+ num_tokens_available = (
+ num_tokens_available
+ * (4 if self._pixel_shuffle else 1)
+ * (4 if self._conv_merging else 1)
+ )
+        # When the number of available tokens is too small, allow self._min_num_patches
+ # per media and let the sample be truncated.
+ num_tokens_available = max(
+ num_tokens_available, self._min_num_patches * len(media_list)
+ )
+
+ # Clip the number of tokens available per media to be between min and max
+ # patches.
+ num_tokens_available_per_media = [
+ max(num_tokens_available, self._min_num_patches)
+ for _ in range(len(media_list))
+ ]
+
+ # In theory this could be a while True loop, but in case the process_media
+ # method slightly
+ # changes, I want to make sure we don't get stuck in an infinite loop.
+ for _ in range(10):
+ # Step 1: Process each media with current token budget
+ params = []
+ token_counts = []
+
+ for media, tokens_for_media in zip(
+ media_list, num_tokens_available_per_media
+ ):
+ param, token_count = self.process_media(
+ media, tokens_for_media, data_augment=data_augment
+ )
+ params.append(param)
+ token_counts.append(token_count)
+
+ # Step 2: Check if total tokens is within budget
+ total_tokens = sum(token_counts)
+
+ if total_tokens <= num_tokens_available:
+ # We're within budget, return the params
+ return params
+
+ # Step 3: We're over budget, need to scale down
+ # Calculate scaling factor to get under budget
+ scaling_factor = num_tokens_available / total_tokens
+
+ # Recalculate token budgets for each media based on scaling
+ # Each media gets a proportional share of the total budget
+ scaled_down_num_tokens_available_per_media = [
+ max(self._min_num_patches, int(token_count * scaling_factor))
+ for token_count in token_counts
+ ]
+ scaled_down = any(
+ [
+ scaled_down_num_tokens_available_per_media[i]
+ < num_tokens_available_per_media[i]
+ for i in range(len(num_tokens_available_per_media))
+ ]
+ )
+            # If nothing was scaled down, we're stuck: fall back to min_num_patches per
+ # media, else try with the scaled down num_tokens_available_per_media.
+ if not scaled_down:
+ num_tokens_available_per_media = [self._min_num_patches] * len(
+ media_list
+ )
+ else:
+ num_tokens_available_per_media = (
+ scaled_down_num_tokens_available_per_media
+ )
+ return params
+
+ def stack(
+ self, images: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
+ imgs_sizes = torch.tensor(
+ [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
+ )
+
+ def rearrange_img(x):
+ py = x.shape[-2] // self.patch_size
+ px = x.shape[-1] // self.patch_size
+ x = einops.rearrange(
+ x,
+ "c (py yy) (px xx) -> (py px) (c yy xx)",
+ py=py,
+ yy=self.patch_size,
+ px=px,
+ xx=self.patch_size,
+ )
+ return x
+
+ if len(images) > 0:
+ imgs = [rearrange_img(img) for img in images]
+
+ current_length = 0
+ max_length = 0
+ vision_cu_lengths = [0]
+ for img in imgs:
+ if max_length < img.shape[0]:
+ max_length = img.shape[0]
+ current_length += img.shape[0]
+ vision_cu_lengths.append(current_length)
+
+ vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
+ vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
+
+ return (
+ torch.cat(imgs, dim=0).unsqueeze(0),
+ imgs_sizes,
+ vision_cu_lengths,
+ vision_max_lengths,
+ )
+ else:
+ return (
+ torch.tensor([[0]], dtype=torch.float32),
+ torch.tensor([[0, 0]], dtype=torch.int32),
+ None,
+ None,
+ )
+
+ def _images_to_pixel_values_lst(
+ self,
+ text: list[str],
+ images: list[Image.Image],
+ max_num_tiles: int,
+ ) -> list[torch.Tensor]:
+ num_tokens_available = 2048 - len(text) - 4
+ params_per_image = self.compute_params(
+ images, num_tokens_available=num_tokens_available
+ )
+ images = []
+ for param in params_per_image:
+ t = self.apply_params(param)
+ if t.ndim == 3:
+ t = t.unsqueeze(0)
+ images.append(t)
+ return images
+
+ def __str__(self):
+ return f"DynamicResolutionImageTransform(\
+ min_num_patches={self._min_num_patches}, \
+ patch_size={self.patch_size}, \
+ pixel_shuffle={self._pixel_shuffle}, \
+ conv_merging={self._conv_merging}, \
+ use_thumbnail={self._use_thumbnail}, \
+ thumbnail_size={self._thumbnail_size}, \
+ thumbnail_area_threshold={self._thumbnail_area_threshold})"
-class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
+class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
"""
HF Processor with extended video processing logic.
Code for video processing is adapted from video example:
@@ -1312,7 +1300,9 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
processor = self.get_hf_processor() # we get the CustomProcessor here
max_image_tokens = self.get_max_image_tokens() * max_images
- max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
+ max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token(
+ image_width=256, image_height=256
+ ) # TODO(nhaber): get 256 dynamically
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
@@ -1457,7 +1447,9 @@ class NanoNemotronVLMultiModalProcessor(
video_num_patches = []
def get_video_replacement_internvl(item_idx: int):
- feature_size = hf_processor.num_image_token
+ feature_size = hf_processor.num_image_token(
+ image_width=256, image_height=256
+ ) # TODO(nhaber): get 256 dynamically
video, metadata = mm_items["video"][item_idx]
num_patches = video_num_patches[item_idx]
if num_patches is not None:
@@ -1633,9 +1625,6 @@ class NemotronH_Nano_VL_V2(
patch_size = config.patch_size
self.patch_size = patch_size
self.template = config.template
- self.num_image_token = int(
- (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
- )
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.image_tag_type = config.image_tag_type
@@ -2153,33 +2142,6 @@ class NemotronH_Nano_VL_V2(
if save_to_file and sys.stdout != original_stdout:
sys.stdout = original_stdout
- def get_model_info(self):
- """
- Get basic model information as a dictionary.
- """
- total_params = sum(p.numel() for p in self.parameters())
-
- component_info = {}
- for name, param in self.named_parameters():
- component = name.split(".")[0]
- if component not in component_info:
- component_info[component] = {"params": 0, "size": 0}
- component_info[component]["params"] += 1
- component_info[component]["size"] += param.numel()
-
- return {
- "model_name": "NemotronH_Nano_VL_V2",
- "total_parameters": total_params,
- "memory_estimate_mb": total_params * 2 / (1024**2), # bfloat16
- "components": component_info,
- "config": {
- "image_size": getattr(self.config, "force_image_size", None),
- "patch_size": getattr(self.config, "patch_size", None),
- "num_image_token": self.num_image_token,
- "downsample_ratio": self.downsample_ratio,
- },
- }
-
def get_vit_model_from_radio_config(self, hf_config):
hf_config_vision = hf_config.vision_config
model_name = hf_config_vision.args.get("model")