mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 17:04:43 +08:00
import image processing into model_executor/models/nano_nemotron_vl.py
This commit is contained in:
parent
50ffea9826
commit
1bceb28678
@ -7,10 +7,14 @@
|
||||
# LICENSE is in root directory.
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from collections.abc import Iterable, Mapping, Sequence, Callable
|
||||
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
|
||||
|
||||
import numpy.typing as npt
|
||||
@ -20,6 +24,7 @@ import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
import einops
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||
@ -84,6 +89,488 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
|
||||
# Alternative: Set a specific higher limit
|
||||
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
|
||||
|
||||
# Per-encoder channel normalization statistics (RGB order).
IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]


# Maps a vision_model_type string to the (mean, std) pair used by the
# preprocessing transform in DynamicResolutionImageTilingStrategy.__init__.
pixel_statistics = {
    "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
    "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "radio_siglip_move": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "cradio-v1": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
}
||||
@dataclass
class DynamicResolutionParams:
    """Resolved tiling decision for a single image."""

    # The original (unresized) image; resizing happens in apply_params.
    image: Image.Image
    # Number of tiles produced: 1 base tile, plus 1 if a thumbnail is added.
    num_tiles: int
    # Total vision embeddings contributed by all tiles of this image.
    num_embeddings: int
    # Target grid as (patches_wide, patches_high); multiplied by the model's
    # patch size (in pixels) to obtain the resize target.
    patch_size: tuple[int, int]
||||
class DynamicResolutionImageTilingStrategy:
    """Chooses a per-image patch grid that fits a token budget.

    The grid preserves the image's aspect ratio, optionally appends a
    fixed-size thumbnail tile for small images, and normalizes pixels with
    statistics matching the configured vision encoder.
    """

    def __init__(
        self,
        vision_model_type: str,
        min_num_patches: int,
        patch_size: int,
        get_num_embeddings: Callable[[int, int], int],
        factor_max: float = 1.0,
        pixel_shuffle: bool = False,
        min_side: int | None = None,
        conv_merging: bool = False,
        use_thumbnail: bool = False,
        thumbnail_size: int = 448,
        thumbnail_area_threshold: float = 0.8,
        max_num_patches: int = 0,
        apply_data_augment: bool = False,
    ):
        """Configure the tiling strategy.

        Args:
            vision_model_type: Encoder family; must contain "radio" and be a
                key of ``pixel_statistics``.
            min_num_patches: Lower bound on patches granted per image.
            patch_size: Patch side length in pixels.
            get_num_embeddings: Maps (width_px, height_px) to the number of
                vision embeddings produced for an image of that size.
            factor_max: Upper bound on the uniform grid scaling factor.
            pixel_shuffle: Whether a 2x pixel-shuffle reduction is applied.
            min_side: Minimum pixel size for the shorter image side, or None.
            conv_merging: Whether a 2x conv-merging reduction is applied.
            use_thumbnail: Whether to append a thumbnail tile for small images.
            thumbnail_size: Thumbnail side length in pixels.
            thumbnail_area_threshold: Add the thumbnail only when the resized
                image covers less than this fraction of the thumbnail's area.
            max_num_patches: Upper bound on patches per image; 0 means no bound.
            apply_data_augment: Whether tiling augmentation may be applied.
        """
        # NOTE(review): `assert` is stripped under `python -O`, so this is
        # best-effort validation only.
        assert "radio" in vision_model_type, (
            "Dynamic resolution is only supported for radio models"
        )
        self._vision_model_type = vision_model_type
        self._min_num_patches = min_num_patches
        # 0 (the default) is mapped to "unbounded".
        self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
        self._patch_size = patch_size
        self._get_num_embeddings = get_num_embeddings
        self._factor_max = factor_max
        self._pixel_shuffle = pixel_shuffle
        self._min_side = min_side
        self._conv_merging = conv_merging
        self._use_thumbnail = use_thumbnail
        self._thumbnail_size = thumbnail_size
        self._thumbnail_area_threshold = thumbnail_area_threshold
        # Raises KeyError for model types missing from pixel_statistics.
        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),  # T.Lambda(lambda img: _fast_to_tensor(img)),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ]
        )
        self._apply_data_augment = apply_data_augment
||||
def apply_params(
|
||||
self, params: DynamicResolutionParams, **kwargs
|
||||
) -> list[torch.Tensor]:
|
||||
# resize the image
|
||||
resized_img = params.image.resize(
|
||||
(
|
||||
params.patch_size[0] * self._patch_size,
|
||||
params.patch_size[1] * self._patch_size,
|
||||
)
|
||||
)
|
||||
processed_images = [resized_img]
|
||||
|
||||
# Add thumbnail if enabled and image area is below threshold
|
||||
if self._use_thumbnail:
|
||||
# Calculate areas
|
||||
resized_area = resized_img.size[0] * resized_img.size[1]
|
||||
thumbnail_area = self._thumbnail_size * self._thumbnail_size
|
||||
area_ratio = resized_area / thumbnail_area
|
||||
|
||||
# Only add thumbnail if resized image area is less than threshold % of thumbnail area
|
||||
if area_ratio < self._thumbnail_area_threshold:
|
||||
thumbnail_img = params.image.resize(
|
||||
(self._thumbnail_size, self._thumbnail_size)
|
||||
)
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
return [self._transform(img) for img in processed_images]
|
||||
|
||||
    def process_media(
        self,
        image: Image.Image,
        num_tokens_available: int,
        data_augment: bool = False,
        tiling_augment_prob: float = 0.4,
    ) -> tuple[DynamicResolutionParams, int]:
        """Process a single media item and return its parameters.

        Args:
            image: The image to process.
            num_tokens_available: Number of patch tokens available for this media.
            data_augment: Whether to apply data augmentation to the image.
                Defaults to False.
            tiling_augment_prob: Probability of jittering the chosen grid when
                augmentation is enabled.

        Returns:
            Tuple of (DynamicResolutionParams for the media, number of patch
            tokens the chosen grid consumes, thumbnail included).
        """
        current_num_tokens_available = num_tokens_available
        assert isinstance(image, Image.Image), (
            "Dynamic resolution is only supported for image media"
        )
        orig_width, orig_height = image.width, image.height

        # Ceil-divide each side by the patch size (round(x + 0.5) behaves as
        # ceil for the values produced here).
        closest_patch_height = round(orig_height / self._patch_size + 0.5)
        closest_patch_width = round(orig_width / self._patch_size + 0.5)
        patches = closest_patch_height * closest_patch_width

        # Uniformly scale the grid so its area fits the token budget, never
        # scaling up beyond factor_max.
        factor = min(
            math.sqrt(current_num_tokens_available / patches), self._factor_max
        )
        target_patch_height = math.floor(factor * closest_patch_height)
        target_patch_width = math.floor(factor * closest_patch_width)

        # Scale back up to at least min_num_patches, but only when the budget
        # itself exceeds min_num_patches.
        if (
            current_num_tokens_available > self._min_num_patches
            and target_patch_height * target_patch_width < self._min_num_patches
        ):
            up_factor = math.sqrt(
                self._min_num_patches / (target_patch_height * target_patch_width)
            )
            target_patch_height = math.ceil(up_factor * target_patch_height)
            target_patch_width = math.ceil(up_factor * target_patch_width)

        # Enforce a minimum pixel size on the shorter side, rescaling while
        # trying to stay within the token budget.
        if (
            self._min_side is not None
            and min(target_patch_width, target_patch_height) * self._patch_size
            < self._min_side
        ):
            if target_patch_width <= target_patch_height:
                # Width is the limiting side.
                up_factor = self._min_side / (target_patch_width * self._patch_size)
                new_patch_height = math.ceil(up_factor * target_patch_height)
                new_patch_width = math.ceil(up_factor * target_patch_width)

                if new_patch_height * new_patch_width > current_num_tokens_available:
                    # If only one side can be min_side, make as big as possible
                    # at native aspect ratio while staying below max_patches.
                    if (
                        max(current_num_tokens_available // new_patch_width, 1)
                        * self._patch_size
                        < self._min_side
                    ):
                        up_factor = math.sqrt(
                            current_num_tokens_available
                            / (target_patch_height * target_patch_width)
                        )
                        target_patch_height = math.floor(
                            up_factor * target_patch_height
                        )
                        target_patch_width = math.floor(up_factor * target_patch_width)
                    # NOTE(review): unlike the symmetric height-limited branch
                    # below, the next two assignments are unconditional and
                    # clobber the sqrt-scaled values computed just above —
                    # looks like a missing `else:`; confirm intent upstream.
                    target_patch_width = new_patch_width
                    target_patch_height = max(
                        current_num_tokens_available // new_patch_width, 1
                    )
                else:
                    target_patch_height = new_patch_height
                    target_patch_width = new_patch_width
            else:
                # Height is the limiting side.
                up_factor = self._min_side / (target_patch_height * self._patch_size)
                new_patch_height = math.ceil(up_factor * target_patch_height)
                new_patch_width = math.ceil(up_factor * target_patch_width)

                if new_patch_height * new_patch_width > current_num_tokens_available:
                    # If only one side can be min_side, make as big as possible
                    # at native aspect ratio while staying below max_patches.
                    if (
                        max(current_num_tokens_available // new_patch_height, 1)
                        * self._patch_size
                        < self._min_side
                    ):
                        up_factor = math.sqrt(
                            current_num_tokens_available
                            / (target_patch_height * target_patch_width)
                        )
                        target_patch_height = math.floor(
                            up_factor * target_patch_height
                        )
                        target_patch_width = math.floor(up_factor * target_patch_width)
                    else:
                        target_patch_height = new_patch_height
                        target_patch_width = max(
                            current_num_tokens_available // new_patch_height, 1
                        )
                else:
                    target_patch_height = new_patch_height
                    target_patch_width = new_patch_width

        # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
        # or by 4 when BOTH are enabled (two successive 2x reductions).
        if self._pixel_shuffle or self._conv_merging:
            required_divisor = 4 if (self._pixel_shuffle and self._conv_merging) else 2

            rem_h = target_patch_height % required_divisor
            if rem_h != 0:
                inc_h = required_divisor - rem_h
                # Round up when it still fits the budget, otherwise round down
                # (never below one divisor-sized step).
                if (
                    target_patch_height + inc_h
                ) * target_patch_width <= current_num_tokens_available:
                    target_patch_height += inc_h
                else:
                    target_patch_height = max(
                        required_divisor, target_patch_height - rem_h
                    )

            rem_w = target_patch_width % required_divisor
            if rem_w != 0:
                inc_w = required_divisor - rem_w
                if (
                    target_patch_height * (target_patch_width + inc_w)
                    <= current_num_tokens_available
                ):
                    target_patch_width += inc_w
                else:
                    target_patch_width = max(
                        required_divisor, target_patch_width - rem_w
                    )

        # Optionally jitter the grid (training-style augmentation).
        if (
            data_augment
            and self._apply_data_augment
            and random.random() < tiling_augment_prob
        ):
            target_patch_width, target_patch_height = self.augment_resolution(
                target_patch_width, target_patch_height, current_num_tokens_available
            )

        assert isinstance(image, Image.Image), (
            "Dynamic resolution is only supported for image media"
        )

        # Calculate embeddings for the main dynamic resolution image.
        num_embeddings = self._get_num_embeddings(
            target_patch_width * self._patch_size,
            target_patch_height * self._patch_size,
        )

        token_count = target_patch_width * target_patch_height

        # Add thumbnail embeddings if enabled and image area is below threshold.
        num_tiles = 1  # Base dynamic resolution image
        if self._use_thumbnail:
            # Calculate areas
            resized_area = (target_patch_width * self._patch_size) * (
                target_patch_height * self._patch_size
            )
            thumbnail_area = self._thumbnail_size * self._thumbnail_size
            area_ratio = resized_area / thumbnail_area

            # Only add thumbnail if resized image area is less than
            # threshold % of thumbnail area.
            if area_ratio < self._thumbnail_area_threshold:
                num_tiles += 1  # Add 1 for thumbnail
                # Add embeddings for thumbnail (thumbnail_size x thumbnail_size)
                num_embeddings += self._get_num_embeddings(
                    self._thumbnail_size, self._thumbnail_size
                )
                # Evaluates left-to-right:
                # ((thumbnail_size // patch_size) * thumbnail_size) // patch_size
                token_count += (
                    self._thumbnail_size
                    // self._patch_size
                    * self._thumbnail_size
                    // self._patch_size
                )

        return DynamicResolutionParams(
            image=image,
            num_tiles=num_tiles,
            num_embeddings=num_embeddings,
            patch_size=(target_patch_width, target_patch_height),
        ), token_count
def augment_resolution(
|
||||
self,
|
||||
target_patch_width: int,
|
||||
target_patch_height: int,
|
||||
current_num_tokens_available: int,
|
||||
) -> tuple[int, int]:
|
||||
min_num_patch_one_side = 32
|
||||
|
||||
if random.random() < 0.5:
|
||||
# Minus one
|
||||
if (
|
||||
target_patch_width <= min_num_patch_one_side
|
||||
and target_patch_height <= min_num_patch_one_side
|
||||
):
|
||||
return target_patch_width, target_patch_height
|
||||
elif target_patch_width <= min_num_patch_one_side:
|
||||
return target_patch_width, target_patch_height - min_num_patch_one_side
|
||||
elif target_patch_height <= min_num_patch_one_side:
|
||||
return target_patch_width - min_num_patch_one_side, target_patch_height
|
||||
else:
|
||||
if random.random() < 0.5:
|
||||
return (
|
||||
target_patch_width - min_num_patch_one_side,
|
||||
target_patch_height,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
target_patch_width,
|
||||
target_patch_height - min_num_patch_one_side,
|
||||
)
|
||||
else:
|
||||
# Plus one
|
||||
if target_patch_width * target_patch_height < current_num_tokens_available:
|
||||
if random.random() < 0.5:
|
||||
return (
|
||||
target_patch_width + min_num_patch_one_side,
|
||||
target_patch_height,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
target_patch_width,
|
||||
target_patch_height + min_num_patch_one_side,
|
||||
)
|
||||
return target_patch_width, target_patch_height
|
||||
|
||||
    def compute_params(
        self,
        media_list: list[Image.Image],
        num_tokens_available: int | None = None,
        max_num_tiles: int | None = None,
        data_augment: bool = False,
        **kwargs,
    ) -> list[DynamicResolutionParams]:
        """Compute parameters for all media with iterative token budgeting.

        Args:
            media_list: List of media items to process.
            num_tokens_available: Total number of tokens available across all
                media.  NOTE(review): the declared default is None, but the
                arithmetic below raises TypeError for None — callers
                apparently always pass an int; confirm.
            max_num_tiles: Maximum number of tiles (unused in this
                implementation).
            data_augment: Whether to apply data augmentation to the image.
                Defaults to False.

        Returns:
            List of DynamicResolutionParams, one per media item.
        """
        # Each enabled 2x reduction (pixel-shuffle, conv-merging) quarters the
        # final token count, so the patch budget is correspondingly larger.
        num_tokens_available = (
            num_tokens_available
            * (4 if self._pixel_shuffle else 1)
            * (4 if self._conv_merging else 1)
        )
        # When the number of available tokens is too small, allow
        # self._min_num_patches per media and let the sample be truncated.
        num_tokens_available = max(
            num_tokens_available, self._min_num_patches * len(media_list)
        )

        # Clip the number of tokens available per media to be between min and
        # max patches.
        num_tokens_available_per_media = [
            max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
            for _ in range(len(media_list))
        ]

        # In theory this could be a while True loop, but in case the
        # process_media method slightly changes, I want to make sure we don't
        # get stuck in an infinite loop.
        for _ in range(10):
            # Step 1: Process each media with current token budget
            params = []
            token_counts = []

            for media, tokens_for_media in zip(
                media_list, num_tokens_available_per_media
            ):
                param, token_count = self.process_media(
                    media, tokens_for_media, data_augment=data_augment
                )
                params.append(param)
                token_counts.append(token_count)

            # Step 2: Check if total tokens is within budget
            total_tokens = sum(token_counts)

            if total_tokens <= num_tokens_available:
                # We're within budget, return the params
                return params

            # Step 3: We're over budget, need to scale down.
            # Calculate scaling factor to get under budget.
            scaling_factor = num_tokens_available / total_tokens

            # Recalculate token budgets for each media based on scaling.
            # Each media gets a proportional share of the total budget.
            scaled_down_num_tokens_available_per_media = [
                max(self._min_num_patches, int(token_count * scaling_factor))
                for token_count in token_counts
            ]
            scaled_down = any(
                [
                    scaled_down_num_tokens_available_per_media[i]
                    < num_tokens_available_per_media[i]
                    for i in range(len(num_tokens_available_per_media))
                ]
            )
            # If there was no scaling down, we're stuck: just use
            # min_num_patches per media; else try with the scaled-down
            # num_tokens_available_per_media.
            if not scaled_down:
                num_tokens_available_per_media = [self._min_num_patches] * len(
                    media_list
                )
            else:
                num_tokens_available_per_media = (
                    scaled_down_num_tokens_available_per_media
                )
        # Fallback after 10 attempts: return the last (possibly over-budget)
        # parameters rather than looping forever.
        return params
    def stack(
        self, images: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Flatten a batch of image tensors into one packed patch sequence.

        Args:
            images: Per-image CHW tensors whose spatial dims are assumed to be
                multiples of the patch size (TODO confirm — remainders are
                silently truncated by the integer division below).

        Returns:
            Tuple of:
              - packed patches of shape (1, total_patches, c * p * p), or a
                (1, 1) zero placeholder when ``images`` is empty;
              - per-image (height, width) sizes as an int32 tensor;
              - cumulative patch offsets per image (int32), or None if empty;
              - max patch count over images (0-dim int32), or None if empty.
        """
        imgs_sizes = torch.tensor(
            [[img.shape[1], img.shape[2]] for img in images], dtype=torch.int32
        )

        def rearrange_img(x):
            # Split a CHW tensor into (num_patches, patch_vector) layout.
            py = x.shape[-2] // self._patch_size
            px = x.shape[-1] // self._patch_size
            x = einops.rearrange(
                x,
                "c (py yy) (px xx) -> (py px) (c yy xx)",
                py=py,
                yy=self._patch_size,
                px=px,
                xx=self._patch_size,
            )
            return x

        if len(images) > 0:
            imgs = [rearrange_img(img) for img in images]

            # Track cumulative offsets and the longest per-image sequence.
            current_length = 0
            max_length = 0
            vision_cu_lengths = [0]
            for img in imgs:
                if max_length < img.shape[0]:
                    max_length = img.shape[0]
                current_length += img.shape[0]
                vision_cu_lengths.append(current_length)

            vision_cu_lengths = torch.tensor(vision_cu_lengths, dtype=torch.int32)
            vision_max_lengths = torch.tensor(max_length, dtype=torch.int32)

            return (
                torch.cat(imgs, dim=0).unsqueeze(0),
                imgs_sizes,
                vision_cu_lengths,
                vision_max_lengths,
            )
        else:
            # Empty batch: return placeholders so callers can still unpack
            # four values.
            return (
                torch.tensor([[0]], dtype=torch.float32),
                torch.tensor([[0, 0]], dtype=torch.int32),
                None,
                None,
            )
def __str__(self):
|
||||
return f"DynamicResolutionImageTransform(vision_model_type={self._vision_model_type}, min_num_patches={self._min_num_patches}, patch_size={self._patch_size}, pixel_shuffle={self._pixel_shuffle}, conv_merging={self._conv_merging}, use_thumbnail={self._use_thumbnail}, thumbnail_size={self._thumbnail_size}, thumbnail_area_threshold={self._thumbnail_area_threshold})"
|
||||
|
||||
|
||||
|
||||
# Module-level default tiling strategy (radio encoder, 16-px patches,
# 4..64-patch budget).  NOTE: constructing this at import time builds the
# torchvision transform as a side effect.
image_tiling_strategy = DynamicResolutionImageTilingStrategy(
    vision_model_type="radio",
    min_num_patches=4,
    patch_size=16,
    # Placeholder embedding count: width * height * 2 — TODO confirm this
    # matches the real encoder's output size.
    get_num_embeddings=lambda x, y: x * y * 2,
    max_num_patches=64,
)
||||
# Special tokens delimiting and representing image content in the prompt.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
|
||||
Loading…
x
Reference in New Issue
Block a user