Merge 6f1249fdb28615972f3b89234d11f3b19a0bd72a into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
Netanel Haber 2025-12-25 00:06:33 +00:00 committed by GitHub
commit 8c6ab2daef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,11 +8,14 @@
# --------------------------------------------------------
import copy
import warnings
import math
import random
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import einops
import numpy.typing as npt
import regex as re
import torch
@ -20,6 +23,7 @@ import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
from typing_extensions import assert_never
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@ -58,7 +62,6 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalDataParser,
)
@ -84,6 +87,11 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
# TODO(nhaber): does use_thumbnail=True work?
# TODO(nhaber): mixing images and videos will mess up the "text_prompt_length" calculation.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
@ -93,6 +101,58 @@ IMG_CONTEXT = "<image>"
DEFAULT_NUM_TILES = 12
def num_image_token_per_tile(
*, width: int, height: int, patch_size: int, downsample_ratio: int
) -> int:
tile_size = math.sqrt((width // patch_size) * (height // patch_size))
num_tokens = int(tile_size**2 // (downsample_ratio**2))
return num_tokens
def width_and_height_for_max_num_tokens_available(
*,
target_num_tokens_post_shuffle: int,
patch_size: int,
downsample_ratio: int,
) -> tuple[int, int]:
"""
TODO(nhaber): optimize this so it squeezes closer to target number of tokens.
Calculate image dimensions that produce approximately `target` tokens after
pixel_shuffle.
With pixel_shuffle enabled, each 2x2 patch grid becomes 1 token, so we
need 4*B patches to get B tokens.
Examples:
>>> width, height = width_and_height_for_max_num_tokens_available(
... target_num_tokens_post_shuffle=8192,
... patch_size=16,
... downsample_ratio=2,
... )
>>> assert width, height == (2880, 2880)
>>> assert (width // 16) * (height // 16) // 2**2 == 8100 # tokens post-shuffle
>>> assert (
... num_image_token_per_tile(
... width=width, height=height, patch_size=16, downsample_ratio=2
... )
... == 8100
... )
"""
side_pixels = (
math.isqrt(target_num_tokens_post_shuffle) * downsample_ratio * patch_size
)
assert isinstance(side_pixels, int) and side_pixels % patch_size == 0
return side_pixels, side_pixels
@dataclass
class DynamicResolutionParams:
media: Image.Image
num_tiles: int
num_embeddings: int
patch_size: tuple[int, int]
class NanoNemotronVLImagePixelInputs(TensorSchema):
"""
Dimensions:
@ -252,7 +312,12 @@ def video_to_pixel_values(
return torch.stack(frames_tensors)
def input_conditioner(x, norm_mean, norm_std):
def input_conditioner(
x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.Tensor
) -> torch.Tensor:
assert isinstance(x, torch.Tensor), "x must be a tensor"
assert isinstance(norm_mean, torch.Tensor), "norm_mean must be a tensor"
assert isinstance(norm_std, torch.Tensor), "norm_std must be a tensor"
return (x - norm_mean) / norm_std
@ -291,15 +356,13 @@ class BaseNanoNemotronVLProcessor(ABC):
self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES
image_size: int = config.force_image_size
patch_size: int = config.patch_size
self.patch_size: int = getattr(config, "patch_size", 16)
# self.downsample_ratio: float = self.config.downsample_ratio
self.num_image_token = int(
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
)
self.image_size = image_size
self.use_thumbnail: bool = config.use_thumbnail
self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1)
self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1)
self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1)
@property
@abstractmethod
@ -331,13 +394,19 @@ class BaseNanoNemotronVLProcessor(ABC):
use_thumbnail=self.use_thumbnail,
)
return num_patches * self.num_image_token
return num_patches * num_image_token_per_tile(
width=image_width,
height=image_height,
patch_size=self.patch_size,
downsample_ratio=self.downsample_ratio,
)
def _images_to_pixel_values_lst(
self,
text_prompt_length: int,
images: list[Image.Image],
max_num_tiles: int,
) -> list[torch.Tensor]:
) -> tuple[list[torch.Tensor], list[int]]:
return [
image_to_pixel_values(
image,
@ -358,7 +427,20 @@ class BaseNanoNemotronVLProcessor(ABC):
if len(images) == 0:
image_inputs = {}
else:
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
assert len(text) == 1, (
"hf_processor is called on the output of get_dummy_text, "
"which should be a single string"
)
text_prompt_length = len(
self.tokenizer(
text[0].replace("<image>", ""), add_special_tokens=False
)["input_ids"]
)
pixel_values_lst, token_counts = self._images_to_pixel_values_lst(
text_prompt_length=text_prompt_length,
images=images,
max_num_tiles=max_num_tiles,
)
image_inputs = {
"pixel_values_flat": input_conditioner(
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
@ -378,9 +460,10 @@ class BaseNanoNemotronVLProcessor(ABC):
"same as the number of images"
)
for i, pixel_values in enumerate(pixel_values_lst):
for i, (pixel_values, feature_size) in enumerate(
zip(pixel_values_lst, token_counts, strict=True)
):
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
parts[i] = parts[i].replace("<image>", image_repl.full)
text = ["".join(parts)]
@ -393,6 +476,7 @@ class BaseNanoNemotronVLProcessor(ABC):
input_item = [input_item]
return input_item
@abstractmethod
def __call__(
self,
text: str | list[str] | None = None,
@ -400,26 +484,494 @@ class BaseNanoNemotronVLProcessor(ABC):
return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None,
) -> BatchFeature:
# Use default if not provided
if max_num_tiles is None:
max_num_tiles = self.max_num_tiles
raise NotImplementedError
text, images = [self._make_batch_input(x) for x in (text, images)]
text, image_inputs = self._preprocess_image(
text=text,
images=images,
max_num_tiles=max_num_tiles,
class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
def __init__(
self,
config: PretrainedConfig,
tokenizer: TokenizerLike,
*args,
max_model_len: int,
max_num_tiles: int | None = None,
min_num_patches: int = 4,
factor_max: float = 1.0,
pixel_shuffle: bool = True,
min_side: int | None = None,
conv_merging: bool = False,
use_thumbnail: bool = False,
thumbnail_size: int = 448,
thumbnail_area_threshold: float = 0.8,
apply_data_augment: bool = False,
**kwargs,
) -> None:
super().__init__(
config=config, tokenizer=tokenizer, max_num_tiles=max_num_tiles, **kwargs
)
text_inputs = self.tokenizer(text, add_special_tokens=False)
self._patch_size: int = getattr(config, "patch_size", 16)
self.max_model_len = max_model_len
self._min_num_patches = min_num_patches
self._factor_max = factor_max
self._pixel_shuffle = pixel_shuffle
self._min_side = min_side
self._conv_merging = conv_merging
self._use_thumbnail = use_thumbnail
self._thumbnail_size = thumbnail_size
self._thumbnail_area_threshold = thumbnail_area_threshold
self.norm_mean = torch.tensor(self.CLIP_PIXEL_MEAN).reshape(1, 3, 1, 1)
self.norm_std = torch.tensor(self.CLIP_PIXEL_STD).reshape(1, 3, 1, 1)
self._transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.ToTensor(), # T.Lambda(lambda img: _fast_to_tensor(img)),
# T.Normalize(mean=pixel_mean, std=pixel_std), - This is done down below with input_conditioner
]
)
self._apply_data_augment = apply_data_augment
reduction_factor = 1 / self.config.downsample_ratio
assert reduction_factor == 2.0, (
"I don't understand what's going on if this isn't 4"
)
self.downsample_ratio = int(reduction_factor) ** (pixel_shuffle + conv_merging)
assert self.downsample_ratio == 2, (
f"I don't understand what's going on if {self.downsample_ratio=} isn't 2"
)
combined_outputs = {**text_inputs, **image_inputs}
def _get_num_embeddings(self, width: int, height: int) -> int:
return num_image_token_per_tile(
width=width,
height=height,
patch_size=self._patch_size,
downsample_ratio=self.downsample_ratio,
)
return BatchFeature(combined_outputs, tensor_type=return_tensors)
def max_num_tokens_available(self, text_prompt_length: int) -> int:
return self.max_model_len - text_prompt_length - 4
def _images_to_pixel_values_lst(
self,
text_prompt_length: int,
images: list[Image.Image],
max_num_tiles: int,
) -> tuple[list[torch.Tensor], list[int]]:
num_tokens_available = self.max_num_tokens_available(text_prompt_length)
params_per_image = self.compute_params(images, num_tokens_available)
feature_sizes = []
images = []
for param in params_per_image:
for t in self.apply_params(param):
if t.ndim == 3:
t = t.unsqueeze(0)
images.append(t)
feature_sizes.append(param.num_embeddings)
print(f"{feature_sizes=}")
print(f"{params_per_image=}")
return images, feature_sizes
feature_size_cache: dict[
Image.Image, int
] = {} # TODO(nhaber): Find a less silly way of doing this... Why can't this be an instance variable?
def get_cached_feature_size(self, image: Image.Image) -> int:
feature_size = self.feature_size_cache[id(image)]
del self.feature_size_cache[id(image)]
return feature_size
def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
resized_img = params.media.resize(
(
params.patch_size[0] * self._patch_size,
params.patch_size[1] * self._patch_size,
)
)
processed_images = [resized_img]
# Add thumbnail if enabled and image area is below threshold
if self._use_thumbnail:
# Calculate areas
resized_area = resized_img.size[0] * resized_img.size[1]
thumbnail_area = self._thumbnail_size * self._thumbnail_size
area_ratio = resized_area / thumbnail_area
# Only add thumbnail if resized image area is less than threshold % of thumbnail area
if area_ratio < self._thumbnail_area_threshold:
thumbnail_img = params.media.resize(
(self._thumbnail_size, self._thumbnail_size)
)
processed_images.append(thumbnail_img)
return [self._transform(img) for img in processed_images]
def process_media(
self,
media: Image.Image,
num_tokens_available: int,
data_augment: bool = False,
tiling_augment_prob: float = 0.4,
) -> tuple[DynamicResolutionParams, int]:
"""Process a single media item and return its parameters.
Args:
media: The media item to process
num_tokens_available: Number of tokens available for this media
data_augment: Whether to apply data augmentation to the image. Defaults to False.
Returns:
DynamicResolutionParams for the media
"""
current_num_tokens_available = num_tokens_available
assert isinstance(media, Image.Image), (
"Dynamic resolution is only supported for image media"
)
orig_width, orig_height = media.width, media.height
# TODO(nhaber): Ask Tyler - the round + 0.5 code is dangerous [banker's rounding], no?
closest_patch_height = round(orig_height / self._patch_size + 0.5)
closest_patch_width = round(orig_width / self._patch_size + 0.5)
patches = closest_patch_height * closest_patch_width
factor = min(
math.sqrt(current_num_tokens_available / patches), self._factor_max
)
target_patch_height = math.floor(factor * closest_patch_height)
target_patch_width = math.floor(factor * closest_patch_width)
# We only consider self._min_num_patches if it is greater than current_num_tokens_available.
if (
current_num_tokens_available > self._min_num_patches
and target_patch_height * target_patch_width < self._min_num_patches
):
up_factor = math.sqrt(
self._min_num_patches / (target_patch_height * target_patch_width)
)
target_patch_height = math.ceil(up_factor * target_patch_height)
target_patch_width = math.ceil(up_factor * target_patch_width)
if (
self._min_side is not None
and min(target_patch_width, target_patch_height) * self._patch_size
< self._min_side
):
if target_patch_width <= target_patch_height:
up_factor = self._min_side / (target_patch_width * self._patch_size)
new_patch_height = math.ceil(up_factor * target_patch_height)
new_patch_width = math.ceil(up_factor * target_patch_width)
if new_patch_height * new_patch_width > current_num_tokens_available:
# If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
if (
max(current_num_tokens_available // new_patch_width, 1)
* self._patch_size
< self._min_side
):
up_factor = math.sqrt(
current_num_tokens_available
/ (target_patch_height * target_patch_width)
)
target_patch_height = math.floor(
up_factor * target_patch_height
)
target_patch_width = math.floor(up_factor * target_patch_width)
target_patch_width = new_patch_width
target_patch_height = max(
current_num_tokens_available // new_patch_width, 1
)
else:
target_patch_height = new_patch_height
target_patch_width = new_patch_width
else:
up_factor = self._min_side / (target_patch_height * self._patch_size)
new_patch_height = math.ceil(up_factor * target_patch_height)
new_patch_width = math.ceil(up_factor * target_patch_width)
if new_patch_height * new_patch_width > current_num_tokens_available:
# If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
if (
max(current_num_tokens_available // new_patch_height, 1)
* self._patch_size
< self._min_side
):
up_factor = math.sqrt(
current_num_tokens_available
/ (target_patch_height * target_patch_width)
)
target_patch_height = math.floor(
up_factor * target_patch_height
)
target_patch_width = math.floor(up_factor * target_patch_width)
else:
target_patch_height = new_patch_height
target_patch_width = max(
current_num_tokens_available // new_patch_height, 1
)
else:
target_patch_height = new_patch_height
target_patch_width = new_patch_width
# Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
# or by 4 when BOTH are enabled (two successive 2x reductions)
if self._pixel_shuffle or self._conv_merging:
required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2
rem_h = target_patch_height % required_divisor
if rem_h != 0:
inc_h = required_divisor - rem_h
if (
target_patch_height + inc_h
) * target_patch_width <= current_num_tokens_available:
target_patch_height += inc_h
else:
target_patch_height = max(
required_divisor, target_patch_height - rem_h
)
rem_w = target_patch_width % required_divisor
if rem_w != 0:
inc_w = required_divisor - rem_w
if (
target_patch_height * (target_patch_width + inc_w)
<= current_num_tokens_available
):
target_patch_width += inc_w
else:
target_patch_width = max(
required_divisor, target_patch_width - rem_w
)
if (
data_augment
and self._apply_data_augment
and random.random() < tiling_augment_prob
):
target_patch_width, target_patch_height = self.augment_resolution(
target_patch_width, target_patch_height, current_num_tokens_available
)
# Calculate embeddings for the main dynamic resolution image
num_embeddings = self._get_num_embeddings(
target_patch_width * self._patch_size,
target_patch_height * self._patch_size,
)
token_count = target_patch_width * target_patch_height
# Add thumbnail embeddings if enabled and image area is below threshold
num_tiles = 1 # Base dynamic resolution image
if self._use_thumbnail:
# Calculate areas
resized_area = (target_patch_width * self._patch_size) * (
target_patch_height * self._patch_size
)
thumbnail_area = self._thumbnail_size * self._thumbnail_size
area_ratio = resized_area / thumbnail_area
# Only add thumbnail if resized image area is less than threshold % of thumbnail area
if area_ratio < self._thumbnail_area_threshold:
num_tiles += 1 # Add 1 for thumbnail
# Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
num_embeddings += self._get_num_embeddings(
self._thumbnail_size, self._thumbnail_size
)
token_count += (
self._thumbnail_size
// self._patch_size
* self._thumbnail_size
// self._patch_size
)
return DynamicResolutionParams(
media=media,
num_tiles=num_tiles,
num_embeddings=num_embeddings,
patch_size=(target_patch_width, target_patch_height),
), token_count
def augment_resolution(
self,
target_patch_width: int,
target_patch_height: int,
current_num_tokens_available: int,
) -> tuple[int, int]:
min_num_patch_one_side = 32
if random.random() < 0.5:
# Minus one
if (
target_patch_width <= min_num_patch_one_side
and target_patch_height <= min_num_patch_one_side
):
return target_patch_width, target_patch_height
elif target_patch_width <= min_num_patch_one_side:
return target_patch_width, target_patch_height - min_num_patch_one_side
elif target_patch_height <= min_num_patch_one_side:
return target_patch_width - min_num_patch_one_side, target_patch_height
else:
if random.random() < 0.5:
return (
target_patch_width - min_num_patch_one_side,
target_patch_height,
)
else:
return (
target_patch_width,
target_patch_height - min_num_patch_one_side,
)
else:
# Plus one
if target_patch_width * target_patch_height < current_num_tokens_available:
if random.random() < 0.5:
return (
target_patch_width + min_num_patch_one_side,
target_patch_height,
)
else:
return (
target_patch_width,
target_patch_height + min_num_patch_one_side,
)
return target_patch_width, target_patch_height
def compute_params(
self,
media_list: list[Image.Image],
num_tokens_available: int | None = None,
data_augment: bool = False,
) -> list[DynamicResolutionParams]:
"""Compute parameters for all media with iterative token budgeting.
Args:
media_list: List of media items to process
num_tokens_available: Total number of tokens available across all media
data_augment: Whether to apply data augmentation to the image. Defaults to
False.
Returns:
List of ImageTilingParams for each media item
"""
num_tokens_available = (
num_tokens_available
* (4 if self._pixel_shuffle else 1)
* (4 if self._conv_merging else 1)
)
# When the number of available token is too small, allow self._min_num_patches per media and
# let the sample be truncated.
num_tokens_available = max(
num_tokens_available, self._min_num_patches * len(media_list)
)
# Clip the number of tokens available per media to be between min and max patches.
num_tokens_available_per_media = [
max(num_tokens_available, self._min_num_patches)
for _ in range(len(media_list))
]
# In theory this could be a while True loop, but in case the process_media method slightly
# changes, I want to make sure we don't get stuck in an infinite loop.
for _ in range(10):
# Step 1: Process each media with current token budget
params = []
token_counts = []
for media, tokens_for_media in zip(
media_list, num_tokens_available_per_media
):
param, token_count = self.process_media(
media, tokens_for_media, data_augment=data_augment
)
params.append(param)
token_counts.append(token_count)
self.feature_size_cache[id(param.media)] = param.num_embeddings
# Step 2: Check if total tokens is within budget
total_tokens = sum(token_counts)
if total_tokens <= num_tokens_available:
# We're within budget, return the params
return params
# Step 3: We're over budget, need to scale down
# Calculate scaling factor to get under budget
scaling_factor = num_tokens_available / total_tokens
# Recalculate token budgets for each media based on scaling
# Each media gets a proportional share of the total budget
scaled_down_num_tokens_available_per_media = [
max(self._min_num_patches, int(token_count * scaling_factor))
for token_count in token_counts
]
scaled_down = any(
[
scaled_down_num_tokens_available_per_media[i]
< num_tokens_available_per_media[i]
for i in range(len(num_tokens_available_per_media))
]
)
# If there was not scaling down, we're stuck just use min_num_patches per media, else
# try with the scaled down num_tokens_available_per_media.
if not scaled_down:
num_tokens_available_per_media = [self._min_num_patches] * len(
media_list
)
else:
num_tokens_available_per_media = (
scaled_down_num_tokens_available_per_media
)
assert_never(num_tokens_available_per_media)
def stack(
self, images: list[torch.Tensor]
) -> tuple[torch.Tensor, list[tuple[int, int]], list[int] | None, list[int] | None]:
imgs_sizes = torch.tensor(
[[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
)
def rearrange_img(x):
py = x.shape[-2] // self._patch_size
px = x.shape[-1] // self._patch_size
x = einops.rearrange(
x,
"c (py yy) (px xx) -> (py px) (c yy xx)",
py=py,
yy=self._patch_size,
px=px,
xx=self._patch_size,
)
return x
if len(images) > 0:
imgs = [rearrange_img(img) for img in images]
current_length = 0
max_length = 0
vision_cu_lengths = [0]
for img in imgs:
if max_length < img.shape[0]:
max_length = img.shape[0]
current_length += img.shape[0]
vision_cu_lengths.append(current_length)
vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)
return (
torch.cat(imgs, dim=0).unsqueeze(0),
imgs_sizes,
vision_cu_lengths,
vision_max_lengths,
)
else:
return (
torch.tensor([[0]], dtype=torch.float32),
torch.tensor([[0, 0]], dtype=torch.int32),
None,
None,
)
class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
class NanoNemotronVLProcessor(DynamicResolutionImageTiler):
"""
HF Processor with extended video processing logic.
Code for video processing is adapted from video example:
@ -431,6 +983,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
config: PretrainedConfig,
tokenizer: TokenizerLike,
*,
max_model_len: int,
max_num_tiles: int | None = None,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
@ -441,6 +994,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
super().__init__(
config=config,
tokenizer=tokenizer,
max_model_len=max_model_len,
max_num_tiles=max_num_tiles,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
@ -716,7 +1270,7 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
**kwargs: object,
) -> BaseNanoNemotronVLProcessor:
) -> DynamicResolutionImageTiler:
raise NotImplementedError
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
@ -739,31 +1293,6 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
max_num_tiles=max_num_tiles,
)
def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize:
processor = self.get_hf_processor()
base_size = processor.image_size
target_ratios = get_internvl_target_ratios(1, max_num_tiles)
largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in target_ratios:
width, height = base_size * wr, base_size * hr
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
max_num_tiles=max_num_tiles,
processor=processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width, height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
def get_max_image_tokens(self) -> int:
processor = self.get_hf_processor()
# Use default max_num_tiles for max tokens calculation
@ -788,7 +1317,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
@property
def supports_video(self):
return self.get_hf_processor().supports_video
return False # TODO(nhaber): add video support
def get_supported_mm_limits(self):
video_limit = {"video": None} if self.supports_video else {}
@ -811,7 +1340,12 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
processor = self.get_hf_processor() # we get the CustomProcessor here
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
max_total_frames = (seq_len - max_image_tokens) // num_image_token_per_tile(
width=256,
height=256,
patch_size=processor._patch_size,
downsample_ratio=processor.downsample_ratio,
) # TODO(nhaber): get 256 dynamically
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
@ -822,6 +1356,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
tokenizer=self.get_tokenizer(),
video_token=self.get_video_token(),
video_pruning_rate=self.get_video_pruning_rate(),
max_model_len=self.ctx.model_config.max_model_len,
**kwargs,
)
@ -871,17 +1406,8 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
# Extract max_num_tiles from kwargs, default to 12
max_num_tiles = hf_processor_mm_kwargs.get(
"max_num_tiles", hf_processor.max_num_tiles
)
feature_size = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
max_num_tiles=max_num_tiles,
processor=hf_processor,
)
image = images.get(item_idx)
feature_size = hf_processor.get_cached_feature_size(image)
num_patches = None
local_image_num_patches = image_num_patches
@ -956,7 +1482,12 @@ class NanoNemotronVLMultiModalProcessor(
video_num_patches = []
def get_video_replacement_internvl(item_idx: int):
feature_size = hf_processor.num_image_token
feature_size = num_image_token_per_tile(
width=256,
height=256,
patch_size=hf_processor._patch_size,
downsample_ratio=hf_processor.downsample_ratio,
) # TODO(nhaber): get 256 dynamically
video, metadata = mm_items["video"][item_idx]
num_patches = video_num_patches[item_idx]
if num_patches is not None:
@ -1017,12 +1548,14 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
# Use default max_num_tiles for dummy data generation
max_num_tiles = 12
target_width, target_height = self.info.get_image_size_with_most_features(
max_num_tiles
)
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
B = processor.max_num_tokens_available(text_prompt_length=num_images)
target_width, target_height = width_and_height_for_max_num_tokens_available(
target_num_tokens_post_shuffle=B,
patch_size=processor._patch_size,
downsample_ratio=processor.downsample_ratio,
)
image_overrides = mm_options.get("image") if mm_options else None
@ -1132,9 +1665,6 @@ class NemotronH_Nano_VL_V2(
patch_size = config.patch_size
self.patch_size = patch_size
self.template = config.template
self.num_image_token = int(
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
)
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.image_tag_type = config.image_tag_type
@ -1182,33 +1712,36 @@ class NemotronH_Nano_VL_V2(
IMG_CONTEXT, add_special_tokens=False
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(
n,
w,
int(h * scale_factor),
int(c / scale_factor),
)
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale -->
# N, H * scale, W * scale, C // (scale ** 2)
x = x.view(
n,
int(h * scale_factor),
int(w * scale_factor),
int(c / (scale_factor * scale_factor)),
)
if self.ps_version == "v1":
warnings.warn(
"In ps_version 'v1', the height and width have not "
"been swapped back, which results in a transposed image.",
stacklevel=2,
def pixel_shuffle_dynamic_res(self, x, *, imgs_sizes):
scale_factor = self.downsample_ratio
patch_dim = self.patch_size
seq_lens = torch.prod(imgs_sizes // patch_dim, dim=-1)
splits = torch.split(x, seq_lens.tolist(), dim=-2)
out = []
for i, sv in enumerate(splits):
h = imgs_sizes[i][0] // patch_dim
w = imgs_sizes[i][1] // patch_dim
sv = sv.reshape(sv.shape[0], h, w, -1)
n, h, w, c = sv.size()
sv = sv.view(n, h, int(w * scale_factor), int(c / scale_factor))
sv = sv.permute(0, 2, 1, 3).contiguous()
sv = sv.view(
n,
int(w * scale_factor),
int(h * scale_factor),
int(c / (scale_factor * scale_factor)),
)
else:
x = x.permute(0, 2, 1, 3).contiguous()
if self.ps_version == "v2":
sv = sv.permute(0, 2, 1, 3).contiguous()
sv = sv.reshape(sv.shape[0], -1, sv.shape[-1])
out.append(sv)
x = torch.cat(out, dim=-2)
return x
def extract_feature(self, pixel_values):
@ -1220,16 +1753,22 @@ class NemotronH_Nano_VL_V2(
n = pixel_values.shape[0]
vit_embeds_list = []
for i in range(0, n, micro_batch_size):
vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size])
current = pixel_values[i : i + micro_batch_size]
vit_embeds = self.vision_model(current)
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(
vit_embeds, scale_factor=self.downsample_ratio
)
vit_embeds = vit_embeds.reshape(
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
)
# pixel_shuffle_dynamic_res expects patches concatenated along dim=-2,
# but vision model outputs (batch, patches, hidden). Process each image
# individually to handle this correctly.
_, _, h, w = current.shape
shuffled_embeds = []
for j in range(vit_embeds.shape[0]):
single_embed = vit_embeds[j : j + 1] # (1, patches, hidden)
single_shuffled = self.pixel_shuffle_dynamic_res(
single_embed, imgs_sizes=torch.tensor([(h, w)])
)
shuffled_embeds.append(single_shuffled)
vit_embeds = torch.cat(shuffled_embeds, dim=0)
vit_embeds = self.mlp1(vit_embeds)
vit_embeds_list.append(vit_embeds)
@ -1652,33 +2191,6 @@ class NemotronH_Nano_VL_V2(
if save_to_file and sys.stdout != original_stdout:
sys.stdout = original_stdout
def get_model_info(self):
"""
Get basic model information as a dictionary.
"""
total_params = sum(p.numel() for p in self.parameters())
component_info = {}
for name, param in self.named_parameters():
component = name.split(".")[0]
if component not in component_info:
component_info[component] = {"params": 0, "size": 0}
component_info[component]["params"] += 1
component_info[component]["size"] += param.numel()
return {
"model_name": "NemotronH_Nano_VL_V2",
"total_parameters": total_params,
"memory_estimate_mb": total_params * 2 / (1024**2), # bfloat16
"components": component_info,
"config": {
"image_size": getattr(self.config, "force_image_size", None),
"patch_size": getattr(self.config, "patch_size", None),
"num_image_token": self.num_image_token,
"downsample_ratio": self.downsample_ratio,
},
}
def get_vit_model_from_radio_config(self, hf_config):
hf_config_vision = hf_config.vision_config
model_name = hf_config_vision.args.get("model")