diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 3bac36744ef5e..ad5a57c511fe8 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -10,7 +10,6 @@
 import copy
 import math
 import random
-import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
@@ -24,6 +23,7 @@ import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
 from transformers import BatchFeature, PretrainedConfig, TensorType
+from typing_extensions import assert_never
 
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@@ -62,7 +62,6 @@ from vllm.multimodal.inputs import (
 from vllm.multimodal.parse import (
     ImageEmbeddingItems,
     ImageProcessorItems,
-    ImageSize,
     MultiModalDataItems,
     MultiModalDataParser,
 )
@@ -91,6 +90,7 @@ Image.MAX_IMAGE_PIXELS = None  # Disable the limit entirely
 
 # TODO(nhaber): get 2048 from config
 # TODO(nhaber): does use_thumbnail=True work?
+# TODO(nhaber): mixing images and videos will mess up the "text_prompt_length" calculation.
 IMG_START = "<img>"
@@ -102,6 +102,46 @@ IMG_CONTEXT = "<image>"
 DEFAULT_NUM_TILES = 12
 
 
+@dataclass(kw_only=True, frozen=True)
+class Dims:
+    height: int
+    width: int
+
+
+CONV_MERGING = False  # This is assumed to be False for now
+PIXEL_SHUFFLE = True  # This is assumed to be True for now
+REDUCTION_FACTOR = 2 ** (PIXEL_SHUFFLE + CONV_MERGING)
+
+
+def width_and_height_for_max_num_tokens_available(
+    *,
+    target_num_tokens_post_shuffle: int,
+    patch_size: int,
+) -> Dims:
+    """
+    TODO(nhaber): optimize this so it squeezes closer to target number of tokens.
+    Calculate image dimensions that produce at most
+    `target_num_tokens_post_shuffle` tokens after pixel_shuffle.
+
+    With pixel_shuffle enabled, each 2x2 patch grid becomes 1 token, so we
+    need 4*B patches to get B tokens.
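+    The target is floored to a perfect square of tokens, so the result can
+    undershoot: e.g. for B=8192, math.isqrt(8192) == 90 gives a 90x90 token
+    grid, i.e. a 180x180 patch grid, i.e. 2880x2880 pixels at patch_size=16,
+    which is 8100 tokens rather than 8192.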
+
+    Examples:
+        >>> dims = width_and_height_for_max_num_tokens_available(
+        ...     target_num_tokens_post_shuffle=8192, patch_size=16
+        ... )
+        >>> assert (dims.width, dims.height) == (2880, 2880)
+        >>> # tokens after shuffle
+        >>> assert (dims.width // 16) * (dims.height // 16) // 4 == 8100
+    """
+    side_pixels = (
+        math.isqrt(target_num_tokens_post_shuffle) * REDUCTION_FACTOR * patch_size
+    )
+    assert isinstance(side_pixels, int) and side_pixels % patch_size == 0
+    return Dims(width=side_pixels, height=side_pixels)
+
+
+@dataclass
+class DynamicResolutionParams:
+    media: Image.Image
+    num_tiles: int
+    num_embeddings: int
+    patch_size: tuple[int, int]
+
+
 class NanoNemotronVLImagePixelInputs(TensorSchema):
     """
     Dimensions:
@@ -313,10 +353,10 @@ class BaseNanoNemotronVLProcessor(ABC):
         self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1)
         self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1)
 
-    def num_image_token(self, *, image_width: int, image_height: int) -> int:
-        image_size = math.sqrt(image_width * image_height)
+    def num_image_token_per_tile(self, *, tile_width: int, tile_height: int) -> int:
+        tile_size = math.sqrt(tile_width * tile_height)
         num_tokens = int(
-            (image_size // self.patch_size) ** 2 * (self.downsample_ratio**2)
+            (tile_size // self.patch_size) ** 2 * (self.downsample_ratio**2)
         )
         return num_tokens
 
@@ -342,7 +382,7 @@ class BaseNanoNemotronVLProcessor(ABC):
     ) -> int:
         target_ratios = get_internvl_target_ratios(1, max_num_tiles)
 
-        num_patches, _, _ = calculate_internvl_targets(
+        num_tiles, _, _ = calculate_internvl_targets(
             orig_width=image_width,
             orig_height=image_height,
             target_ratios=target_ratios,
@@ -350,16 +390,16 @@
             use_thumbnail=self.use_thumbnail,
         )
 
-        return num_patches * self.num_image_token(
-            image_width=image_width, image_height=image_height
+        return num_tiles * self.num_image_token_per_tile(
+            tile_width=image_width, tile_height=image_height
         )
 
     def _images_to_pixel_values_lst(
         self,
-        text: list[str],
+        text_prompt_length: int,
         images: list[Image.Image],
         max_num_tiles: int,
-    ) -> list[torch.Tensor]:
+    ) -> tuple[list[torch.Tensor], list[int]]:
         return [
             image_to_pixel_values(
                 image,
@@ -380,8 +420,19 @@
         if len(images) == 0:
             image_inputs = {}
         else:
-            pixel_values_lst = self._images_to_pixel_values_lst(
-                text=text, images=images, max_num_tiles=max_num_tiles
+            assert len(text) == 1, (
+                "hf_processor is called on the output of get_dummy_text, "
+                "which should be a single string"
+            )
+            text_prompt_length = len(
+                self.tokenizer(
+                    text[0].replace("<image>", ""), add_special_tokens=False
+                )["input_ids"]
+            )
+            pixel_values_lst, token_counts = self._images_to_pixel_values_lst(
+                text_prompt_length=text_prompt_length,
+                images=images,
+                max_num_tiles=max_num_tiles,
             )
             image_inputs = {
                 "pixel_values_flat": input_conditioner(
@@ -402,12 +453,10 @@
                 "same as the number of images"
             )
 
-            for i, pixel_values in enumerate(pixel_values_lst):
+            for i, (pixel_values, feature_size) in enumerate(
+                zip(pixel_values_lst, token_counts, strict=True)
+            ):
                 num_patches = pixel_values.shape[0]
-                feature_size = num_patches * self.num_image_token(
-                    image_width=pixel_values.shape[1],
-                    image_height=pixel_values.shape[2],
-                )
                 image_repl = self.get_image_repl(feature_size, num_patches)
                 parts[i] = parts[i].replace("<image>", image_repl.full)
         text = ["".join(parts)]
@@ -431,14 +480,6 @@ class BaseNanoNemotronVLProcessor(ABC):
         raise NotImplementedError
 
 
-@dataclass
-class DynamicResolutionParams:
-    media: Image.Image
-    num_tiles: int
-    num_embeddings: int
-    patch_size: tuple[int, int]
-
-
 class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
     CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
     CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
@@ -448,6 +489,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
         config: PretrainedConfig,
         tokenizer: TokenizerLike,
         *args,
+        max_model_len: int,
         max_num_tiles: int | None = None,
         min_num_patches: int = 4,
         factor_max: float = 1.0,
@@ -463,6 +505,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
         super().__init__(
             config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs
        )
+        self.max_model_len = max_model_len
+
         self._min_num_patches = min_num_patches
         self._factor_max = factor_max
         self._pixel_shuffle = pixel_shuffle
@@ -483,6 +527,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
         self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1)
         self.downsample_ratio = 2 if pixel_shuffle else 1
 
+    # Keyed by id(image); see get_cached_feature_size.
+    # TODO(nhaber): Find a less silly way of doing this... Why can't this be
+    # a class variable?
+    feature_size_cache: dict[int, int] = {}
+
     def apply_params(self, params: DynamicResolutionParams) -> torch.Tensor:
         resized_img = params.media.resize(
             (
@@ -515,7 +563,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
         num_tokens_available: int,
         data_augment: bool = False,
         tiling_augment_prob: float = 0.4,
-    ) -> DynamicResolutionParams:
+    ) -> tuple[DynamicResolutionParams, int]:
         """Process a single media item and return its parameters.
         Args:
             media: The media item to process
@@ -531,8 +579,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
         )
 
         orig_width, orig_height = media.width, media.height
-        closest_patch_height = round(orig_height / self.patch_size + 0.5)
-        closest_patch_width = round(orig_width / self.patch_size + 0.5)
+        # TODO(nhaber): Ask Tyler - the previous round(x + 0.5) code is
+        # dangerous (banker's rounding), no?
+        # If we flip this back to round(), the max_wh_fill_budget needs to do
+        # -1 for each of w/h to be safe.
+        closest_patch_height = math.ceil(orig_height / self.patch_size)
+        closest_patch_width = math.ceil(orig_width / self.patch_size)
 
         patches = closest_patch_height * closest_patch_width
         factor = min(
@@ -660,8 +710,8 @@
         )
 
         # Calculate embeddings for the main dynamic resolution image
-        num_embeddings = self.num_image_token(
-            image_width=target_patch_width, image_height=target_patch_height
+        num_embeddings_per_tile = self.num_image_token_per_tile(
+            tile_width=target_patch_width, tile_height=target_patch_height
         )
         token_count = target_patch_width * target_patch_height
 
@@ -681,8 +731,8 @@
         if area_ratio < self._thumbnail_area_threshold:
             num_tiles += 1  # Add 1 for thumbnail
             # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
-            num_embeddings += self.num_image_token(
-                image_width=self._thumbnail_size, image_height=self._thumbnail_size
+            num_embeddings_per_tile += self.num_image_token_per_tile(
+                tile_width=self._thumbnail_size, tile_height=self._thumbnail_size
             )
             token_count += (
                 self._thumbnail_size
@@ -694,7 +744,7 @@
         return DynamicResolutionParams(
             media=media,
             num_tiles=num_tiles,
-            num_embeddings=num_embeddings,
+            num_embeddings=num_embeddings_per_tile,
             patch_size=(target_patch_width, target_patch_height),
         ), token_count
@@ -748,7 +798,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
         media_list: list[Image.Image],
         num_tokens_available: int | None = None,
         data_augment: bool = False,
-    ) -> list[DynamicResolutionParams]:
+    ) -> tuple[list[DynamicResolutionParams], list[int]]:
         """Compute parameters for all media with iterative token budgeting.
 
         Args:
@@ -782,8 +832,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
         # changes, I want to make sure we don't get stuck in an infinite loop.
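+        # Each pass either returns (we are within budget) or scales every
+        # per-media budget down, so ten passes is a generous bound; falling
+        # out of the loop raises via assert_never below.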
         for _ in range(10):
             # Step 1: Process each media with current token budget
-            params = []
-            token_counts = []
+            params: list[DynamicResolutionParams] = []
+            token_counts: list[int] = []
 
             for media, tokens_for_media in zip(
                 media_list, num_tokens_available_per_media
@@ -799,7 +849,12 @@
 
             if total_tokens <= num_tokens_available:
                 # We're within budget, return the params
-                return params
+                # Convert from patch count to actual token count after downsampling
+                divisor = (4 if self._pixel_shuffle else 1) * (
+                    4 if self._conv_merging else 1
+                )
+                adjusted_token_counts = [tc // divisor for tc in token_counts]
+                for param, feature_size in zip(
+                    params, adjusted_token_counts, strict=True
+                ):
+                    self.feature_size_cache[id(param.media)] = feature_size
+                return params, adjusted_token_counts
 
             # Step 3: We're over budget, need to scale down
             # Calculate scaling factor to get under budget
@@ -828,7 +883,7 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
             num_tokens_available_per_media = (
                 scaled_down_num_tokens_available_per_media
             )
-        return params
+        assert_never(num_tokens_available_per_media)
 
     def stack(
         self, images: list[torch.Tensor]
@@ -879,15 +934,18 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
             None,
         )
 
+    def max_num_tokens_available(self, text_prompt_length: int) -> int:
+        return self.max_model_len - text_prompt_length - 4
+
     def _images_to_pixel_values_lst(
         self,
-        text: list[str],
+        text_prompt_length: int,
         images: list[Image.Image],
         max_num_tiles: int,
-    ) -> list[torch.Tensor]:
-        num_tokens_available = 2048 - len(text) - 4
-        params_per_image = self.compute_params(
-            images, num_tokens_available=num_tokens_available
+    ) -> tuple[list[torch.Tensor], list[int]]:
+        num_tokens_available = self.max_num_tokens_available(text_prompt_length)
+        params_per_image, feature_sizes = self.compute_params(
+            images, num_tokens_available
         )
         images = []
         for param in params_per_image:
@@ -895,17 +953,12 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
             if t.ndim == 3:
                 t = t.unsqueeze(0)
             images.append(t)
-        return images
+        return images, feature_sizes
 
-    def __str__(self):
-        return f"DynamicResolutionImageTransform(\
-            min_num_patches={self._min_num_patches}, \
-            patch_size={self.patch_size}, \
-            pixel_shuffle={self._pixel_shuffle}, \
-            conv_merging={self._conv_merging}, \
-            use_thumbnail={self._use_thumbnail}, \
-            thumbnail_size={self._thumbnail_size}, \
-            thumbnail_area_threshold={self._thumbnail_area_threshold})"
+    def get_cached_feature_size(self, image: Image.Image) -> int:
+        return self.feature_size_cache.pop(id(image))
 
 
 class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
@@ -920,6 +973,7 @@ class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
         config: PretrainedConfig,
         tokenizer: TokenizerLike,
         *,
+        max_model_len: int,
         max_num_tiles: int | None = None,
         min_dynamic_patch: int | None = None,
         max_dynamic_patch: int | None = None,
@@ -930,6 +984,7 @@ class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
         super().__init__(
             config=config,
             tokenizer=tokenizer,
+            max_model_len=max_model_len,
             max_num_tiles=max_num_tiles,
             min_dynamic_patch=min_dynamic_patch,
             max_dynamic_patch=max_dynamic_patch,
@@ -1205,7 +1260,7 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
     def get_hf_processor(
         self,
         **kwargs: object,
-    ) -> BaseNanoNemotronVLProcessor:
+    ) -> DynamicResolutionImageTiler:
         raise NotImplementedError
 
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
@@ -1228,31 +1283,6 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
             max_num_tiles=max_num_tiles,
         )
 
-    def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize:
-        processor = self.get_hf_processor()
-
-        base_size = processor.image_size
-        target_ratios = get_internvl_target_ratios(1, max_num_tiles)
-
-        largest_feature_size, largest_feature_pinpoint = 0, None
-        for wr, hr in target_ratios:
-            width, height = base_size * wr, base_size * hr
-
-            feat_size = self.get_num_image_tokens(
-                image_width=width,
-                image_height=height,
-                max_num_tiles=max_num_tiles,
-                processor=processor,
-            )
-            if feat_size > largest_feature_size:
-                largest_feature_size = feat_size
-                largest_feature_pinpoint = ImageSize(width=width, height=height)
-
-        if largest_feature_size == 0 or largest_feature_pinpoint is None:
-            raise ValueError("Cannot have a largest feature size of 0!")
-
-        return largest_feature_pinpoint
-
     def get_max_image_tokens(self) -> int:
         processor = self.get_hf_processor()
         # Use default max_num_tiles for max tokens calculation
@@ -1277,7 +1307,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
     @property
     def supports_video(self):
-        return self.get_hf_processor().supports_video
+        return False  # TODO(nhaber): add video support
 
     def get_supported_mm_limits(self):
         video_limit = {"video": None} if self.supports_video else {}
@@ -1300,8 +1330,10 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
         processor = self.get_hf_processor()  # we get the CustomProcessor here
         max_image_tokens = self.get_max_image_tokens() * max_images
-        max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token(
-            image_width=256, image_height=256
+        max_total_frames = (
+            seq_len - max_image_tokens
+        ) // processor.num_image_token_per_tile(
+            tile_width=256, tile_height=256
         )  # TODO(nhaber): get 256 dynamically
         max_frames_per_video = max_total_frames // max(max_videos, 1)
         return max(max_frames_per_video, 1)
@@ -1313,6 +1345,7 @@
             tokenizer=self.get_tokenizer(),
             video_token=self.get_video_token(),
             video_pruning_rate=self.get_video_pruning_rate(),
+            max_model_len=self.ctx.model_config.max_model_len,
             **kwargs,
         )
@@ -1362,17 +1395,8 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             if isinstance(images, ImageEmbeddingItems):
                 feature_size = images.get_feature_size(item_idx)
             else:
-                image_size = images.get_image_size(item_idx)
-                # Extract max_num_tiles from kwargs, default to 12
-                max_num_tiles = hf_processor_mm_kwargs.get(
-                    "max_num_tiles", hf_processor.max_num_tiles
-                )
-                feature_size = self.info.get_num_image_tokens(
-                    image_width=image_size.width,
-                    image_height=image_size.height,
-                    max_num_tiles=max_num_tiles,
-                    processor=hf_processor,
-                )
+                image = images.get(item_idx)
+                feature_size = hf_processor.get_cached_feature_size(image)
 
             num_patches = None
             local_image_num_patches = image_num_patches
@@ -1447,8 +1471,8 @@ class NanoNemotronVLMultiModalProcessor(
         video_num_patches = []
 
         def get_video_replacement_internvl(item_idx: int):
-            feature_size = hf_processor.num_image_token(
-                image_width=256, image_height=256
+            feature_size = hf_processor.num_image_token_per_tile(
+                tile_width=256, tile_height=256
             )  # TODO(nhaber): get 256 dynamically
             video, metadata = mm_items["video"][item_idx]
             num_patches = video_num_patches[item_idx]
@@ -1510,19 +1534,20 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
         mm_counts: Mapping[str, int],
         mm_options: Mapping[str, BaseDummyOptions] | None = None,
     ) -> MultiModalDataDict:
-        # Use default max_num_tiles for dummy data generation
-        max_num_tiles = 12
-        target_width, target_height = self.info.get_image_size_with_most_features(
-            max_num_tiles
-        )
         num_images = mm_counts.get("image", 0)
+        processor = self.info.get_hf_processor()
+        B = processor.max_num_tokens_available(text_prompt_length=num_images)
+        target_dims = width_and_height_for_max_num_tokens_available(
+            target_num_tokens_post_shuffle=B,
+            patch_size=processor.patch_size,
+        )
 
         image_overrides = mm_options.get("image") if mm_options else None
 
         return {
             "image": self._get_dummy_images(
-                width=target_width,
-                height=target_height,
+                width=target_dims.width,
+                height=target_dims.height,
                 num_images=num_images,
                 overrides=image_overrides,
             )
@@ -1672,33 +1697,36 @@ class NemotronH_Nano_VL_V2(
             IMG_CONTEXT, add_special_tokens=False
         )
 
-    def pixel_shuffle(self, x, scale_factor=0.5):
-        n, w, h, c = x.size()
-        # N, W, H, C --> N, W, H * scale, C // scale
-        x = x.view(
-            n,
-            w,
-            int(h * scale_factor),
-            int(c / scale_factor),
-        )
-        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
-        x = x.permute(0, 2, 1, 3).contiguous()
-        # N, H * scale, W, C // scale -->
-        # N, H * scale, W * scale, C // (scale ** 2)
-        x = x.view(
-            n,
-            int(h * scale_factor),
-            int(w * scale_factor),
-            int(c / (scale_factor * scale_factor)),
-        )
-        if self.ps_version == "v1":
-            warnings.warn(
-                "In ps_version 'v1', the height and width have not "
-                "been swapped back, which results in a transposed image.",
-                stacklevel=2,
+    def pixel_shuffle_dynamic_res(self, x, *, imgs_sizes):
+        scale_factor = self.downsample_ratio
+        patch_dim = self.patch_size
+        seq_lens = torch.prod(imgs_sizes // patch_dim, dim=-1)
+        splits = torch.split(x, seq_lens.tolist(), dim=-2)
+        out = []
+        for i, sv in enumerate(splits):
+            h = imgs_sizes[i][0] // patch_dim
+            w = imgs_sizes[i][1] // patch_dim
+            sv = sv.reshape(sv.shape[0], h, w, -1)
+
+            n, h, w, c = sv.size()
+
+            sv = sv.view(n, h, int(w * scale_factor), int(c / scale_factor))
+            sv = sv.permute(0, 2, 1, 3).contiguous()
+            sv = sv.view(
+                n,
+                int(w * scale_factor),
+                int(h * scale_factor),
+                int(c / (scale_factor * scale_factor)),
             )
-        else:
-            x = x.permute(0, 2, 1, 3).contiguous()
+
+            if self.ps_version == "v2":
+                sv = sv.permute(0, 2, 1, 3).contiguous()
+
+            sv = sv.reshape(sv.shape[0], -1, sv.shape[-1])
+            out.append(sv)
+
+        x = torch.cat(out, dim=-2)
+
         return x
 
     def extract_feature(self, pixel_values):
@@ -1710,16 +1738,22 @@ class NemotronH_Nano_VL_V2(
         n = pixel_values.shape[0]
         vit_embeds_list = []
         for i in range(0, n, micro_batch_size):
-            vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size])
+            current = pixel_values[i : i + micro_batch_size]
+            vit_embeds = self.vision_model(current)
             vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
-            h = w = int(vit_embeds.shape[1] ** 0.5)
-            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-            vit_embeds = self.pixel_shuffle(
-                vit_embeds, scale_factor=self.downsample_ratio
-            )
-            vit_embeds = vit_embeds.reshape(
-                vit_embeds.shape[0], -1, vit_embeds.shape[-1]
-            )
+
+            # pixel_shuffle_dynamic_res expects patches concatenated along dim=-2,
+            # but vision model outputs (batch, patches, hidden). Process each image
+            # individually to handle this correctly.
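+            # current is (batch, channels, H, W), so every image in this
+            # micro-batch shares the same H and W.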
+            _, _, h, w = current.shape
+            shuffled_embeds = []
+            for j in range(vit_embeds.shape[0]):
+                single_embed = vit_embeds[j : j + 1]  # (1, patches, hidden)
+                single_shuffled = self.pixel_shuffle_dynamic_res(
+                    single_embed, imgs_sizes=torch.tensor([(h, w)])
+                )
+                shuffled_embeds.append(single_shuffled)
+            vit_embeds = torch.cat(shuffled_embeds, dim=0)
             vit_embeds = self.mlp1(vit_embeds)
             vit_embeds_list.append(vit_embeds)