diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 3b8a3841cf938..3bac36744ef5e 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -12,7 +12,7 @@ import math import random import warnings from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from typing import Annotated, Any, Literal, TypeAlias, TypeVar @@ -88,501 +88,9 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely # Alternative: Set a specific higher limit # Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels -IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] -IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] -SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] -SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] -CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] -CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] -RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060] -RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250] - -pixel_statistics = { - "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), - "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), - "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD), - "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), - "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), -} - - -@dataclass -class DynamicResolutionParams: - media: Image.Image - num_tiles: int - num_embeddings: int - patch_size: tuple[int, int] - - -class DynamicResolutionImageTilingStrategy: - def __init__( - self, - vision_model_type: str, - min_num_patches: int, - patch_size: int, - get_num_embeddings: Callable[[int, int], int], - factor_max: float = 1.0, - pixel_shuffle: bool = False, - 
min_side: int | None = None, - conv_merging: bool = False, - use_thumbnail: bool = False, - thumbnail_size: int = 448, - thumbnail_area_threshold: float = 0.8, - max_num_patches: int = 0, - apply_data_augment: bool = False, - ): - assert "radio" in vision_model_type, ( - "Dynamic resolution is only supported for radio models" - ) - self._vision_model_type = vision_model_type - self._min_num_patches = min_num_patches - self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf") - self._patch_size = patch_size - self._get_num_embeddings = get_num_embeddings - self._factor_max = factor_max - self._pixel_shuffle = pixel_shuffle - self._min_side = min_side - self._conv_merging = conv_merging - self._use_thumbnail = use_thumbnail - self._thumbnail_size = thumbnail_size - self._thumbnail_area_threshold = thumbnail_area_threshold - pixel_mean, pixel_std = pixel_statistics[self._vision_model_type] - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)), - T.Normalize(mean=pixel_mean, std=pixel_std), - ] - ) - self._apply_data_augment = apply_data_augment - - def apply_params( - self, params: DynamicResolutionParams, **kwargs - ) -> list[torch.Tensor]: - # resize the image - resized_img = params.media.resize( - ( - params.patch_size[0] * self._patch_size, - params.patch_size[1] * self._patch_size, - ) - ) - processed_images = [resized_img] - - # Add thumbnail if enabled and image area is below threshold - if self._use_thumbnail: - # Calculate areas - resized_area = resized_img.size[0] * resized_img.size[1] - thumbnail_area = self._thumbnail_size * self._thumbnail_size - area_ratio = resized_area / thumbnail_area - - # Only add thumbnail if resized image area is less than threshold % of - # thumbnail area - if area_ratio < self._thumbnail_area_threshold: - thumbnail_img = params.media.resize( - (self._thumbnail_size, self._thumbnail_size) - ) 
- processed_images.append(thumbnail_img) - - return [self._transform(img) for img in processed_images] - - def process_media( - self, - media: Image.Image, - num_tokens_available: int, - data_augment: bool = False, - tiling_augment_prob: float = 0.4, - ) -> DynamicResolutionParams: - """Process a single media item and return its parameters. - Args: - media: The media item to process - num_tokens_available: Number of tokens available for this media - data_augment: Whether to apply data augmentation to the image. Defaults to - False. - Returns: - DynamicResolutionParams for the media - """ - current_num_tokens_available = num_tokens_available - assert isinstance(media, Image.Image), ( - "Dynamic resolution is only supported for image media" - ) - orig_width, orig_height = media.width, media.height - - closest_patch_height = round(orig_height / self._patch_size + 0.5) - closest_patch_width = round(orig_width / self._patch_size + 0.5) - patches = closest_patch_height * closest_patch_width - - factor = min( - math.sqrt(current_num_tokens_available / patches), self._factor_max - ) - target_patch_height = math.floor(factor * closest_patch_height) - target_patch_width = math.floor(factor * closest_patch_width) - - # We only consider self._min_num_patches if it is greater than - # current_num_tokens_available. 
- if ( - current_num_tokens_available > self._min_num_patches - and target_patch_height * target_patch_width < self._min_num_patches - ): - up_factor = math.sqrt( - self._min_num_patches / (target_patch_height * target_patch_width) - ) - target_patch_height = math.ceil(up_factor * target_patch_height) - target_patch_width = math.ceil(up_factor * target_patch_width) - - if ( - self._min_side is not None - and min(target_patch_width, target_patch_height) * self._patch_size - < self._min_side - ): - if target_patch_width <= target_patch_height: - up_factor = self._min_side / (target_patch_width * self._patch_size) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) - - if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at - # native aspect ratio while staying below max_patches - if ( - max(current_num_tokens_available // new_patch_width, 1) - * self._patch_size - < self._min_side - ): - up_factor = math.sqrt( - current_num_tokens_available - / (target_patch_height * target_patch_width) - ) - target_patch_height = math.floor( - up_factor * target_patch_height - ) - target_patch_width = math.floor(up_factor * target_patch_width) - target_patch_width = new_patch_width - target_patch_height = max( - current_num_tokens_available // new_patch_width, 1 - ) - else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width - else: - up_factor = self._min_side / (target_patch_height * self._patch_size) - new_patch_height = math.ceil(up_factor * target_patch_height) - new_patch_width = math.ceil(up_factor * target_patch_width) - - if new_patch_height * new_patch_width > current_num_tokens_available: - # If only one side can be min_side, make as big as possible at - # native aspect ratio while staying below max_patches - if ( - max(current_num_tokens_available // new_patch_height, 1) - * self._patch_size - 
< self._min_side - ): - up_factor = math.sqrt( - current_num_tokens_available - / (target_patch_height * target_patch_width) - ) - target_patch_height = math.floor( - up_factor * target_patch_height - ) - target_patch_width = math.floor(up_factor * target_patch_width) - else: - target_patch_height = new_patch_height - target_patch_width = max( - current_num_tokens_available // new_patch_height, 1 - ) - else: - target_patch_height = new_patch_height - target_patch_width = new_patch_width - - # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging) - # or by 4 when BOTH are enabled (two successive 2x reductions) - if self._pixel_shuffle or self._conv_merging: - required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2 - - rem_h = target_patch_height % required_divisor - if rem_h != 0: - inc_h = required_divisor - rem_h - if ( - target_patch_height + inc_h - ) * target_patch_width <= current_num_tokens_available: - target_patch_height += inc_h - else: - target_patch_height = max( - required_divisor, target_patch_height - rem_h - ) - - rem_w = target_patch_width % required_divisor - if rem_w != 0: - inc_w = required_divisor - rem_w - if ( - target_patch_height * (target_patch_width + inc_w) - <= current_num_tokens_available - ): - target_patch_width += inc_w - else: - target_patch_width = max( - required_divisor, target_patch_width - rem_w - ) - - if ( - data_augment - and self._apply_data_augment - and random.random() < tiling_augment_prob - ): - target_patch_width, target_patch_height = self.augment_resolution( - target_patch_width, target_patch_height, current_num_tokens_available - ) - - assert isinstance(media, Image.Image), ( - "Dynamic resolution is only supported for image media" - ) - - # Calculate embeddings for the main dynamic resolution image - num_embeddings = self._get_num_embeddings( - target_patch_width * self._patch_size, - target_patch_height * self._patch_size, - ) - - token_count = target_patch_width * 
target_patch_height - - # Add thumbnail embeddings if enabled and image area is below threshold - num_tiles = 1 # Base dynamic resolution image - if self._use_thumbnail: - # Calculate areas - resized_area = (target_patch_width * self._patch_size) * ( - target_patch_height * self._patch_size - ) - thumbnail_area = self._thumbnail_size * self._thumbnail_size - area_ratio = resized_area / thumbnail_area - - # Only add thumbnail if resized image area is less than threshold % of - # thumbnail area - if area_ratio < self._thumbnail_area_threshold: - num_tiles += 1 # Add 1 for thumbnail - # Add embeddings for thumbnail (thumbnail_size x thumbnail_size) - num_embeddings += self._get_num_embeddings( - self._thumbnail_size, self._thumbnail_size - ) - token_count += ( - self._thumbnail_size - // self._patch_size - * self._thumbnail_size - // self._patch_size - ) - - return DynamicResolutionParams( - media=media, - num_tiles=num_tiles, - num_embeddings=num_embeddings, - patch_size=(target_patch_width, target_patch_height), - ), token_count - - def augment_resolution( - self, - target_patch_width: int, - target_patch_height: int, - current_num_tokens_available: int, - ) -> tuple[int, int]: - min_num_patch_one_side = 32 - - if random.random() < 0.5: - # Minus one - if ( - target_patch_width <= min_num_patch_one_side - and target_patch_height <= min_num_patch_one_side - ): - return target_patch_width, target_patch_height - elif target_patch_width <= min_num_patch_one_side: - return target_patch_width, target_patch_height - min_num_patch_one_side - elif target_patch_height <= min_num_patch_one_side: - return target_patch_width - min_num_patch_one_side, target_patch_height - else: - if random.random() < 0.5: - return ( - target_patch_width - min_num_patch_one_side, - target_patch_height, - ) - else: - return ( - target_patch_width, - target_patch_height - min_num_patch_one_side, - ) - else: - # Plus one - if target_patch_width * target_patch_height < current_num_tokens_available: - 
if random.random() < 0.5: - return ( - target_patch_width + min_num_patch_one_side, - target_patch_height, - ) - else: - return ( - target_patch_width, - target_patch_height + min_num_patch_one_side, - ) - return target_patch_width, target_patch_height - - def compute_params( - self, - media_list: list[Image.Image], - num_tokens_available: int | None = None, - max_num_tiles: int | None = None, - data_augment: bool = False, - **kwargs, - ) -> list[DynamicResolutionParams]: - """Compute parameters for all media with iterative token budgeting. - - Args: - media_list: List of media items to process - num_tokens_available: Total number of tokens available across all media - max_num_tiles: Maximum number of tiles (unused in this implementation) - data_augment: Whether to apply data augmentation to the image. Defaults to - False. - Returns: - List of ImageTilingParams for each media item - """ - num_tokens_available = ( - num_tokens_available - * (4 if self._pixel_shuffle else 1) - * (4 if self._conv_merging else 1) - ) - # When the number of available token is too small, allow self._min_num_patches - # per media and let the sample be truncated. - num_tokens_available = max( - num_tokens_available, self._min_num_patches * len(media_list) - ) - - # Clip the number of tokens available per media to be between min and max - # patches. - num_tokens_available_per_media = [ - max(min(num_tokens_available, self._max_num_patches), self._min_num_patches) - for _ in range(len(media_list)) - ] - - # In theory this could be a while True loop, but in case the process_media - # method slightly - # changes, I want to make sure we don't get stuck in an infinite loop. 
- for _ in range(10): - # Step 1: Process each media with current token budget - params = [] - token_counts = [] - - for media, tokens_for_media in zip( - media_list, num_tokens_available_per_media - ): - param, token_count = self.process_media( - media, tokens_for_media, data_augment=data_augment - ) - params.append(param) - token_counts.append(token_count) - - # Step 2: Check if total tokens is within budget - total_tokens = sum(token_counts) - - if total_tokens <= num_tokens_available: - # We're within budget, return the params - return params - - # Step 3: We're over budget, need to scale down - # Calculate scaling factor to get under budget - scaling_factor = num_tokens_available / total_tokens - - # Recalculate token budgets for each media based on scaling - # Each media gets a proportional share of the total budget - scaled_down_num_tokens_available_per_media = [ - max(self._min_num_patches, int(token_count * scaling_factor)) - for token_count in token_counts - ] - scaled_down = any( - [ - scaled_down_num_tokens_available_per_media[i] - < num_tokens_available_per_media[i] - for i in range(len(num_tokens_available_per_media)) - ] - ) - # If there was not scaling down, we're stuck just use min_num_patches per - # media, else try with the scaled down num_tokens_available_per_media. 
- if not scaled_down: - num_tokens_available_per_media = [self._min_num_patches] * len( - media_list - ) - else: - num_tokens_available_per_media = ( - scaled_down_num_tokens_available_per_media - ) - return params - - def stack( - self, images: list[torch.Tensor] - ) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]: - imgs_sizes = torch.tensor( - [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32 - ) - - def rearrange_img(x): - py = x.shape[-2] // self._patch_size - px = x.shape[-1] // self._patch_size - x = einops.rearrange( - x, - "c (py yy) (px xx) -> (py px) (c yy xx)", - py=py, - yy=self._patch_size, - px=px, - xx=self._patch_size, - ) - return x - - if len(images) > 0: - imgs = [rearrange_img(img) for img in images] - - current_length = 0 - max_length = 0 - vision_cu_lengths = [0] - for img in imgs: - if max_length < img.shape[0]: - max_length = img.shape[0] - current_length += img.shape[0] - vision_cu_lengths.append(current_length) - - vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32) - vision_max_lengths = torch.tensor(max_length, dtype=torch.int32) - - return ( - torch.cat(imgs, dim=0).unsqueeze(0), - imgs_sizes, - vision_cu_lengths, - vision_max_lengths, - ) - else: - return ( - torch.tensor([[0]], dtype=torch.float32), - torch.tensor([[0, 0]], dtype=torch.int32), - None, - None, - ) - - def __str__(self): - return f"DynamicResolutionImageTransform(\ - vision_model_type={self._vision_model_type}, \ - min_num_patches={self._min_num_patches}, \ - patch_size={self._patch_size}, \ - pixel_shuffle={self._pixel_shuffle}, \ - conv_merging={self._conv_merging}, \ - use_thumbnail={self._use_thumbnail}, \ - thumbnail_size={self._thumbnail_size}, \ - thumbnail_area_threshold={self._thumbnail_area_threshold})" - - -image_tiling_strategy = DynamicResolutionImageTilingStrategy( - vision_model_type="radio", - min_num_patches=4, - patch_size=16, - get_num_embeddings=lambda x, y: x * y * 2, - 
def input_conditioner(
    x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.Tensor
) -> torch.Tensor:
    """Normalize ``x`` element-wise as ``(x - norm_mean) / norm_std``.

    Args:
        x: Tensor to normalize.
        norm_mean: Tensor subtracted from ``x``.
        norm_std: Tensor the centered values are divided by.

    Returns:
        The normalized tensor.

    Raises:
        TypeError: If any argument is not a ``torch.Tensor``.
    """
    # Explicit raise instead of ``assert``: assertions are removed when Python
    # runs with -O, which would silently disable this validation.
    for name, value in (("x", x), ("norm_mean", norm_mean), ("norm_std", norm_std)):
        if not isinstance(value, torch.Tensor):
            raise TypeError(f"{name} must be a tensor, got {type(value).__name__}")
    return (x - norm_mean) / norm_std
@dataclass
class DynamicResolutionParams:
    """Resize plan for a single image, as produced by ``process_media``."""

    # Image the plan applies to.
    media: Image.Image
    # 1 for the resized image, plus 1 when a thumbnail is accounted for.
    num_tiles: int
    # Sum of ``num_image_token`` for the resized image (and thumbnail, if any).
    num_embeddings: int
    # Target grid as (patches along width, patches along height).
    patch_size: tuple[int, int]


class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
    """Processor mixin that resizes each image to a budget-dependent patch grid.

    For every image a grid of ``patch_size``-pixel patches is chosen so that the
    total number of patches over all images fits a token budget; the image is
    then resized to exactly that grid.

    This class does not implement ``__call__``; concrete subclasses must.
    """

    CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
    CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]

    # TODO(nhaber): get 2048 from config
    _TOKEN_BUDGET = 2048

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: TokenizerLike,
        *args,
        max_num_tiles: int | None = None,
        min_num_patches: int = 4,
        factor_max: float = 1.0,
        pixel_shuffle: bool = True,
        min_side: int | None = None,
        conv_merging: bool = False,
        use_thumbnail: bool = False,
        thumbnail_size: int = 448,
        thumbnail_area_threshold: float = 0.8,
        apply_data_augment: bool = False,
        **kwargs,
    ) -> None:
        """Store the tiling configuration.

        Args:
            config: Forwarded to the base processor.
            tokenizer: Forwarded to the base processor.
            *args: Accepted for signature compatibility; not forwarded.
            max_num_tiles: Forwarded to the base processor.
            min_num_patches: Lower bound on the patch budget of one image.
            factor_max: Upper bound on the scale factor applied to the image's
                native patch grid (1.0 means "never upscale" in the first step).
            pixel_shuffle: Whether the grid must be aligned for a 2x reduction.
            min_side: Optional minimum side length in pixels.
            conv_merging: Whether the grid must be aligned for a further 2x
                reduction.
            use_thumbnail: Whether ``process_media`` accounts for a thumbnail.
            thumbnail_size: Thumbnail side length in pixels.
            thumbnail_area_threshold: A thumbnail is only accounted for when
                resized-area / thumbnail-area is below this value.
            apply_data_augment: Enables random grid augmentation in
                ``process_media`` (together with its ``data_augment`` flag).
            **kwargs: Forwarded to the base processor.
        """
        super().__init__(
            config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs
        )
        self._min_num_patches = min_num_patches
        self._factor_max = factor_max
        self._pixel_shuffle = pixel_shuffle
        self._min_side = min_side
        self._conv_merging = conv_merging
        self._use_thumbnail = use_thumbnail
        self._thumbnail_size = thumbnail_size
        self._thumbnail_area_threshold = thumbnail_area_threshold
        self._apply_data_augment = apply_data_augment

        # Normalization is not part of the transform; the CLIP statistics are
        # stored on the instance instead.
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
            ]
        )
        self.norm_mean = torch.tensor(self.CLIP_PIXEL_MEAN).reshape(1, 3, 1, 1)
        self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1)

        # ``num_image_token`` multiplies the patch count by ``downsample_ratio**2``
        # and ``compute_params`` budgets 4 patches per token under pixel shuffle,
        # so the ratio has to be 0.5 (not 2) for the two to agree.
        # NOTE(review): conv_merging is not reflected here - confirm intent.
        self.downsample_ratio = 0.5 if pixel_shuffle else 1.0

    def num_image_token(self, *, image_width: int, image_height: int) -> int:
        """Return the token count for an image of the given pixel size.

        Counts whole patches per axis, so non-square images are not
        under-counted the way a ``sqrt(width * height)`` side length would.
        """
        grid_w = image_width // self.patch_size
        grid_h = image_height // self.patch_size
        return int(grid_w * grid_h * self.downsample_ratio**2)

    def apply_params(self, params: DynamicResolutionParams) -> torch.Tensor:
        """Resize ``params.media`` to its target grid and convert it to a tensor.

        NOTE(review): no thumbnail is produced here, although ``process_media``
        counts one when ``use_thumbnail`` is enabled.
        TODO(nhaber): does use_thumbnail=True work?
        """
        grid_w, grid_h = params.patch_size
        resized_img = params.media.resize(
            (grid_w * self.patch_size, grid_h * self.patch_size)
        )
        return self._transform(resized_img)

    def _enforce_min_side(
        self, width: int, height: int, budget: int
    ) -> tuple[int, int]:
        """Grow a (width, height) patch grid so its short side reaches min_side.

        If the grown grid exceeds ``budget`` the long side is cut to what the
        budget allows; if that would push the long side below ``min_side`` too,
        the original grid is instead scaled to the budget at its own aspect
        ratio.
        """
        width_is_short = width <= height
        short, long_ = (width, height) if width_is_short else (height, width)

        grow = self._min_side / (short * self.patch_size)
        new_short = math.ceil(grow * short)
        new_long = math.ceil(grow * long_)

        if new_short * new_long > budget:
            affordable_long = max(budget // new_short, 1)
            if affordable_long * self.patch_size < self._min_side:
                # Neither option keeps both sides at min_side: fit the budget at
                # the native aspect ratio (never collapsing a side to 0).
                fit = math.sqrt(budget / (short * long_))
                new_short = max(1, math.floor(fit * short))
                new_long = max(1, math.floor(fit * long_))
            else:
                new_long = affordable_long

        return (new_short, new_long) if width_is_short else (new_long, new_short)

    @staticmethod
    def _align_side(side: int, other_side: int, divisor: int, budget: int) -> int:
        """Round ``side`` to a multiple of ``divisor``, up if the budget allows."""
        remainder = side % divisor
        if remainder == 0:
            return side
        rounded_up = side + divisor - remainder
        if rounded_up * other_side <= budget:
            return rounded_up
        return max(divisor, side - remainder)

    def process_media(
        self,
        media: Image.Image,
        num_tokens_available: int,
        data_augment: bool = False,
        tiling_augment_prob: float = 0.4,
    ) -> tuple[DynamicResolutionParams, int]:
        """Choose the target patch grid for one image.

        Args:
            media: The image to process.
            num_tokens_available: Patch budget for this image.
            data_augment: Enables random grid augmentation; only effective when
                the instance was created with ``apply_data_augment=True``.
            tiling_augment_prob: Probability of augmenting when enabled.

        Returns:
            ``(params, token_count)`` where ``token_count`` is the number of
            patches of the chosen grid (plus the thumbnail grid, if counted).

        Raises:
            TypeError: If ``media`` is not a PIL image.
        """
        if not isinstance(media, Image.Image):
            raise TypeError("Dynamic resolution is only supported for image media")

        budget = num_tokens_available
        patch = self.patch_size

        closest_patch_height = round(media.height / patch + 0.5)
        closest_patch_width = round(media.width / patch + 0.5)
        patches = closest_patch_height * closest_patch_width

        # Scale the native grid to the budget, capped at factor_max. Clamp to
        # one patch per side: for very elongated images floor() can otherwise
        # yield 0, which breaks the divisions and the resize below.
        factor = min(math.sqrt(budget / patches), self._factor_max)
        target_patch_height = max(1, math.floor(factor * closest_patch_height))
        target_patch_width = max(1, math.floor(factor * closest_patch_width))

        # The minimum patch count is only enforced when the budget exceeds it.
        grid_area = target_patch_height * target_patch_width
        if budget > self._min_num_patches and grid_area < self._min_num_patches:
            up_factor = math.sqrt(self._min_num_patches / grid_area)
            target_patch_height = math.ceil(up_factor * target_patch_height)
            target_patch_width = math.ceil(up_factor * target_patch_width)

        if (
            self._min_side is not None
            and min(target_patch_width, target_patch_height) * patch < self._min_side
        ):
            target_patch_width, target_patch_height = self._enforce_min_side(
                target_patch_width, target_patch_height, budget
            )

        # Grid must be divisible by 2 for pixel-shuffle OR conv-merging, and by 4
        # when both are enabled (two successive 2x reductions). Height is aligned
        # first; width is then aligned against the already-aligned height.
        if self._pixel_shuffle or self._conv_merging:
            divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
            target_patch_height = self._align_side(
                target_patch_height, target_patch_width, divisor, budget
            )
            target_patch_width = self._align_side(
                target_patch_width, target_patch_height, divisor, budget
            )

        if (
            data_augment
            and self._apply_data_augment
            and random.random() < tiling_augment_prob
        ):
            target_patch_width, target_patch_height = self.augment_resolution(
                target_patch_width, target_patch_height, budget
            )

        # num_image_token expects pixel dimensions, not patch counts.
        num_embeddings = self.num_image_token(
            image_width=target_patch_width * patch,
            image_height=target_patch_height * patch,
        )
        token_count = target_patch_width * target_patch_height

        num_tiles = 1  # the resized image itself
        if self._use_thumbnail:
            resized_area = (target_patch_width * patch) * (target_patch_height * patch)
            thumbnail_area = self._thumbnail_size * self._thumbnail_size
            # A thumbnail is only counted for images that end up small relative
            # to the thumbnail itself.
            if resized_area / thumbnail_area < self._thumbnail_area_threshold:
                num_tiles += 1
                num_embeddings += self.num_image_token(
                    image_width=self._thumbnail_size,
                    image_height=self._thumbnail_size,
                )
                token_count += (self._thumbnail_size // patch) ** 2

        params = DynamicResolutionParams(
            media=media,
            num_tiles=num_tiles,
            num_embeddings=num_embeddings,
            patch_size=(target_patch_width, target_patch_height),
        )
        return params, token_count

    def augment_resolution(
        self,
        target_patch_width: int,
        target_patch_height: int,
        current_num_tokens_available: int,
    ) -> tuple[int, int]:
        """Randomly shrink or grow one side of the grid by 32 patches.

        With probability 0.5 a side longer than 32 patches is shortened by 32
        (a random one when both qualify); otherwise, if the grid is below the
        budget, a random side is lengthened by 32. The grown grid is not
        re-checked against the budget.

        Returns:
            The (width, height) grid, possibly unchanged.
        """
        step = 32
        width, height = target_patch_width, target_patch_height

        if random.random() < 0.5:
            width_can_shrink = width > step
            height_can_shrink = height > step
            if not width_can_shrink and not height_can_shrink:
                return width, height
            if not width_can_shrink:
                return width, height - step
            if not height_can_shrink:
                return width - step, height
            if random.random() < 0.5:
                return width - step, height
            return width, height - step

        if width * height < current_num_tokens_available:
            if random.random() < 0.5:
                return width + step, height
            return width, height + step
        return width, height

    def compute_params(
        self,
        media_list: list[Image.Image],
        num_tokens_available: int | None = None,
        data_augment: bool = False,
    ) -> list[DynamicResolutionParams]:
        """Compute parameters for all media with iterative token budgeting.

        Args:
            media_list: Images to process.
            num_tokens_available: Total number of tokens available across all
                media. Required despite the ``None`` default.
            data_augment: Forwarded to ``process_media``.

        Returns:
            One ``DynamicResolutionParams`` per image. If the budget cannot be
            met after 10 rounds, the last (over-budget) result is returned.

        Raises:
            TypeError: If ``num_tokens_available`` is ``None``.
        """
        if num_tokens_available is None:
            raise TypeError("num_tokens_available must be an int, got None")

        min_patches = self._min_num_patches
        # Convert the token budget into a patch budget: each enabled 2x
        # reduction merges 4 patches into one token.
        budget = (
            num_tokens_available
            * (4 if self._pixel_shuffle else 1)
            * (4 if self._conv_merging else 1)
        )
        # When the budget is too small, allow min_num_patches per media and let
        # the sample be truncated.
        budget = max(budget, min_patches * len(media_list))

        # Every image starts with the full budget (but at least min_patches).
        budget_per_media = [max(budget, min_patches)] * len(media_list)

        params: list[DynamicResolutionParams] = []
        # Bounded instead of ``while True`` so a change in process_media cannot
        # cause an infinite loop.
        for _ in range(10):
            params = []
            token_counts = []
            for media, tokens_for_media in zip(media_list, budget_per_media):
                param, token_count = self.process_media(
                    media, tokens_for_media, data_augment=data_augment
                )
                params.append(param)
                token_counts.append(token_count)

            total_tokens = sum(token_counts)
            if total_tokens <= budget:
                return params

            # Over budget: give each image a share proportional to what it used.
            scaling_factor = budget / total_tokens
            scaled_budgets = [
                max(min_patches, int(token_count * scaling_factor))
                for token_count in token_counts
            ]
            shrank = any(
                new < old for new, old in zip(scaled_budgets, budget_per_media)
            )
            # If nothing shrank we are stuck: fall back to the minimum per image.
            budget_per_media = (
                scaled_budgets if shrank else [min_patches] * len(media_list)
            )
        return params

    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Flatten images into one patch sequence.

        Each image is indexed as ``(c, h, w)`` and split into
        ``patch_size x patch_size`` patches.

        Returns:
            ``(patches, sizes, cu_lengths, max_length)``: the concatenated
            patches with a leading dimension of 1, the per-image
            ``[shape[1], shape[2]]`` as int32, the cumulative patch counts
            starting at 0 (int32), and the largest per-image patch count as an
            int32 scalar tensor. For an empty list, placeholder tensors are
            returned and the last two entries are ``None``.
        """
        if len(images) == 0:
            return (
                torch.tensor([[0]], dtype=torch.float32),
                torch.tensor([[0, 0]], dtype=torch.int32),
                None,
                None,
            )

        patch = self.patch_size
        imgs_sizes = torch.tensor(
            [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
        )
        sequences = [
            einops.rearrange(
                img,
                "c (py yy) (px xx) -> (py px) (c yy xx)",
                py=img.shape[-2] // patch,
                yy=patch,
                px=img.shape[-1] // patch,
                xx=patch,
            )
            for img in images
        ]

        lengths = [seq.shape[0] for seq in sequences]
        cu_lengths = [0]
        for length in lengths:
            cu_lengths.append(cu_lengths[-1] + length)

        return (
            torch.cat(sequences, dim=0).unsqueeze(0),
            imgs_sizes,
            torch.tensor(cu_lengths, dtype=torch.int32),
            torch.tensor(max(lengths), dtype=torch.int32),
        )

    def _images_to_pixel_values_lst(
        self,
        text: list[str],
        images: list[Image.Image],
        max_num_tiles: int,
    ) -> list[torch.Tensor]:
        """Resize every image to its budgeted grid.

        Args:
            text: Prompts accompanying the images.
            images: Images to convert.
            max_num_tiles: Unused by this implementation.

        Returns:
            One tensor per image with a leading dimension of 1 added to
            3-dimensional results.
        """
        # NOTE(review): len(text) is the number of prompt strings, not their
        # token count - confirm this is the intended budget reduction.
        num_tokens_available = self._TOKEN_BUDGET - len(text) - 4
        params_per_image = self.compute_params(
            images, num_tokens_available=num_tokens_available
        )
        pixel_values_lst = []
        for param in params_per_image:
            pixel_values = self.apply_params(param)
            if pixel_values.ndim == 3:
                pixel_values = pixel_values.unsqueeze(0)
            pixel_values_lst.append(pixel_values)
        return pixel_values_lst

    def __str__(self) -> str:
        """Return a one-line summary of the tiling configuration."""
        return (
            f"{type(self).__name__}("
            f"min_num_patches={self._min_num_patches}, "
            f"patch_size={self.patch_size}, "
            f"pixel_shuffle={self._pixel_shuffle}, "
            f"conv_merging={self._conv_merging}, "
            f"use_thumbnail={self._use_thumbnail}, "
            f"thumbnail_size={self._thumbnail_size}, "
            f"thumbnail_area_threshold={self._thumbnail_area_threshold})"
        )
as a dictionary. - """ - total_params = sum(p.numel() for p in self.parameters()) - - component_info = {} - for name, param in self.named_parameters(): - component = name.split(".")[0] - if component not in component_info: - component_info[component] = {"params": 0, "size": 0} - component_info[component]["params"] += 1 - component_info[component]["size"] += param.numel() - - return { - "model_name": "NemotronH_Nano_VL_V2", - "total_parameters": total_params, - "memory_estimate_mb": total_params * 2 / (1024**2), # bfloat16 - "components": component_info, - "config": { - "image_size": getattr(self.config, "force_image_size", None), - "patch_size": getattr(self.config, "patch_size", None), - "num_image_token": self.num_image_token, - "downsample_ratio": self.downsample_ratio, - }, - } - def get_vit_model_from_radio_config(self, hf_config): hf_config_vision = hf_config.vision_config model_name = hf_config_vision.args.get("model")