This commit is contained in:
Netanel Haber 2025-12-22 06:27:32 -08:00
parent 52e5e55a19
commit 2a7ea9ba37

View File

@ -112,6 +112,13 @@ CONV_MERGING = False # This is assumed to be False for now
PIXEL_SHUFFLE = True # This is assumed to be True for now
REDUCTION_FACTOR = 2 ** (PIXEL_SHUFFLE + CONV_MERGING)
def num_image_token_per_tile(*, tile_dims: "Dims", patch_size: int, downsample_ratio: int) -> int:
    """Return the number of image tokens one tile contributes.

    The tile is treated as a square whose side is the geometric mean of its
    width and height; the token count is (patches per side)^2 scaled by the
    square of the downsample ratio, truncated to an int.
    """
    # Geometric-mean side length of the (possibly non-square) tile.
    side_pixels = math.sqrt(tile_dims.width * tile_dims.height)
    patches_per_side = side_pixels // patch_size
    return int(patches_per_side ** 2 * downsample_ratio ** 2)
def width_and_height_for_max_num_tokens_available(
*,
target_num_tokens_post_shuffle: int,
@ -129,6 +136,7 @@ def width_and_height_for_max_num_tokens_available(
>>> dims = width_and_height_for_max_num_tokens_available(B=8192, patch_size=16)
>>> assert (dims.width, dims.height) == (2880, 2880)
>>> assert ((dims.width // 16) * (dims.height // 16) // 4) == 8100 # tokens after shuffle
>>> assert num_image_token_per_tile(tile_dims=dims, patch_size=16, downsample_ratio=0.5) == 8100
"""
side_pixels = math.isqrt(target_num_tokens_post_shuffle) * REDUCTION_FACTOR * patch_size
assert isinstance(side_pixels, int) and side_pixels % patch_size == 0
@ -353,13 +361,6 @@ class BaseNanoNemotronVLProcessor(ABC):
self.norm_mean = torch.tensor(config.norm_mean).reshape(1, 3, 1, 1)
self.norm_std = torch.tensor(config.norm_std).reshape(1, 3, 1, 1)
def num_image_token_per_tile(self, *, tile_width: int, tile_height: int) -> int:
tile_size = math.sqrt(tile_width * tile_height)
num_tokens = int(
(tile_size // self.patch_size) ** 2 * (self.downsample_ratio**2)
)
return num_tokens
@property
@abstractmethod
def image_token_id(self) -> int:
@ -390,8 +391,10 @@ class BaseNanoNemotronVLProcessor(ABC):
use_thumbnail=self.use_thumbnail,
)
return num_tiles * self.num_image_token_per_tile(
tile_width=image_width, tile_height=image_height
return num_tiles * num_image_token_per_tile(
tile_dims=Dims(width=image_width, height=image_height),
patch_size=self.patch_size,
downsample_ratio=self.downsample_ratio
)
def _images_to_pixel_values_lst(
@ -710,8 +713,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
)
# Calculate embeddings for the main dynamic resolution image
num_embeddings_per_tile = self.num_image_token_per_tile(
tile_width=target_patch_width, tile_height=target_patch_height
num_embeddings_per_tile = num_image_token_per_tile(
tile_dims=Dims(width=target_patch_width, height=target_patch_height),
patch_size=self.patch_size,
downsample_ratio=self.downsample_ratio
)
token_count = target_patch_width * target_patch_height
@ -731,8 +736,10 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
if area_ratio < self._thumbnail_area_threshold:
num_tiles += 1 # Add 1 for thumbnail
# Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
num_embeddings += self.num_image_token_per_tile(
tile_width=self._thumbnail_size, tile_height=self._thumbnail_size
num_embeddings += num_image_token_per_tile(
tile_dims=Dims(width=self._thumbnail_size, height=self._thumbnail_size),
patch_size=self.patch_size,
downsample_ratio=self.downsample_ratio
)
token_count += (
self._thumbnail_size
@ -947,6 +954,8 @@ class DynamicResolutionImageTiler(BaseNanoNemotronVLProcessor):
params_per_image, feature_sizes = self.compute_params(
images, num_tokens_available
)
print(f"{feature_sizes=}")
print(f"{params_per_image=}")
images = []
for param in params_per_image:
t = self.apply_params(param)
@ -1332,8 +1341,10 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = (
seq_len - max_image_tokens
) // processor.num_image_token_per_tile(
tile_width=256, tile_height=256
) // num_image_token_per_tile(
tile_dims=Dims(width=256, height=256),
patch_size=processor.patch_size,
downsample_ratio=processor.downsample_ratio
) # TODO(nhaber): get 256 dynamically
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
@ -1471,8 +1482,10 @@ class NanoNemotronVLMultiModalProcessor(
video_num_patches = []
def get_video_replacement_internvl(item_idx: int):
feature_size = hf_processor.num_image_token_per_tile(
tile_width=256, tile_height=256
feature_size = num_image_token_per_tile(
tile_dims=Dims(width=256, height=256),
patch_size=hf_processor.patch_size,
downsample_ratio=hf_processor.downsample_ratio
) # TODO(nhaber): get 256 dynamically
video, metadata = mm_items["video"][item_idx]
num_patches = video_num_patches[item_idx]