vllm/image_processing.py
Netanel Haber 6979edb575 import
2025-12-15 07:48:55 -08:00

509 lines
25 KiB
Python

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
import math
import torch
from einops import rearrange
from torchvision import transforms as T
from torchvision.transforms import Compose
from torchvision.transforms.functional import InterpolationMode
# Per-channel normalization statistics (three values each). They are looked up
# through `pixel_statistics` below and handed to `T.Normalize(mean=..., std=...)`.
IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
# NOTE(review): the first std component is 0.2230 here but 0.229 in
# IMAGENET_PIXEL_STD, while the means are identical - confirm this is intended.
RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]

# Maps a vision model type name to its (mean, std) pair.
# "radio" and "cradio-g" reuse the CLIP statistics; "huggingface" reuses SigLIP's.
pixel_statistics = {
    "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
    "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "internvit300M": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
}
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
# Copyright (c) 2023 OpenGVLab.
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the candidate tile grid whose aspect ratio is nearest to the image's.

    Args:
        aspect_ratio: Aspect ratio of the source image (callers in this file
            pass ``width / height``).
        target_ratios: Iterable of candidate grids ``(a, b)``; each candidate is
            scored by ``a / b``.
        width: Source image width.
        height: Source image height.
        image_size: Side length of one square tile.

    Returns:
        The candidate minimizing ``abs(aspect_ratio - a / b)``. When a later
        candidate ties exactly with the current best, it replaces it only if
        the image area is larger than half the pixel area of that candidate's
        grid. ``(1, 1)`` is returned when there are no candidates.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # Exact tie: prefer the later (larger) grid only for big enough images.
        is_tie = gap == smallest_gap
        if is_tie and image_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate

    return chosen
def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick a tile grid by jointly scoring area coverage and aspect-ratio match.

    Each candidate ``(a, b)`` gets the score ``coverage * similarity`` where

    * ``coverage`` is the grid's pixel area (``a * b * image_size**2``) divided
      by the image area, capped at 0.6, and
    * ``similarity`` is the smaller of ``(a / b) / aspect_ratio`` and its
      reciprocal, so it is 1 for a perfect match and shrinks otherwise.

    Args:
        aspect_ratio: Aspect ratio of the source image.
        target_ratios: Iterable of candidate grids ``(a, b)``.
        width: Source image width.
        height: Source image height.
        image_size: Side length of one square tile.

    Returns:
        The first candidate reaching the highest score, or ``(1, 1)`` when
        there are no candidates.
    """
    image_area = width * height
    winner = (1, 1)
    top_score = float('-inf')

    for candidate in target_ratios:
        candidate_aspect = candidate[0] / candidate[1]
        coverage = min((candidate[0] * candidate[1] * image_size * image_size) / image_area, 0.6)
        similarity = min(candidate_aspect / aspect_ratio, aspect_ratio / candidate_aspect)
        score = coverage * similarity
        # Strict comparison keeps the earliest candidate on ties.
        if score > top_score:
            top_score, winner = score, candidate

    return winner
def process_images(sample_imgs, patch_dim, dynamic_resolution, batch_mode=False):
    """Collate the image tensors of one microbatch for the vision encoder.

    Args:
        sample_imgs (List[torch.Tensor]): Image tensors, each of shape (C, H, W).
        patch_dim (int): Side length of a square patch (e.g. 14 for 14x14 patches).
            Only used when ``dynamic_resolution`` is True.
        dynamic_resolution (bool): If True, every image is cut into patches and all
            patch sequences are concatenated into one packed sequence. If False,
            the images are stacked into a regular batch tensor.
        batch_mode (bool, optional): If True, the length tensors get an extra leading
            dimension and an empty input yields a dummy image tensor. Defaults to False.

    Returns:
        tuple: ``(images, imgs_sizes, vision_cu_lengths, vision_max_lengths)``

            - images (torch.Tensor):
              dynamic resolution: shape (1, total_patches, C * patch_dim * patch_dim);
              static resolution: shape (len(sample_imgs), C, H, W);
              empty input with ``batch_mode=True``: ``[[0.]]``.
            - imgs_sizes (torch.Tensor): int32 tensor of shape (N, 2) holding
              ``[img.shape[1], img.shape[2]]`` per image, i.e. [height, width] for
              (C, H, W) inputs, or ``[[0, 0]]`` when there are no images.
            - vision_cu_lengths (torch.Tensor or None): int32 cumulative patch counts,
              starting at 0, shape (N + 1,) or (1, N + 1) in batch mode.
              None unless dynamic resolution is used.
            - vision_max_lengths (torch.Tensor or None): int32 largest per-image patch
              count, scalar or shape (1,) in batch mode. None unless dynamic
              resolution is used.

    Note:
        For dynamic resolution this is meant to be called on one microbatch at a time.
    """
    if len(sample_imgs) == 0:
        empty_sizes = torch.tensor([[0, 0]], dtype=torch.int32)
        if batch_mode:
            # For batch mode when no images, use appropriate dummy tensor
            placeholder = torch.tensor([[0]], dtype=torch.float32)
        else:
            # NOTE(review): stacking an empty sequence is expected to raise in torch;
            # confirm that failing here is the intended behavior outside batch mode.
            placeholder = torch.stack(sample_imgs)
        return placeholder, empty_sizes, None, None

    imgs_sizes = torch.tensor(
        [[img.shape[1], img.shape[2]] for img in sample_imgs], dtype=torch.int32
    )

    if not dynamic_resolution:
        return torch.stack(sample_imgs), imgs_sizes, None, None

    def to_patch_rows(image):
        # One row per patch, each row holding the flattened (c, yy, xx) pixels.
        return rearrange(
            image,
            'c (py yy) (px xx) -> (py px) (c yy xx)',
            py=image.shape[-2] // patch_dim, yy=patch_dim,
            px=image.shape[-1] // patch_dim, xx=patch_dim,
        )

    patch_seqs = [to_patch_rows(img) for img in sample_imgs]
    seq_lens = [seq.shape[0] for seq in patch_seqs]

    # Running totals mark where each image starts/ends in the packed sequence.
    boundaries = [0]
    for seq_len in seq_lens:
        boundaries.append(boundaries[-1] + seq_len)

    vision_cu_lengths = torch.tensor(boundaries, dtype=torch.int32)
    vision_max_lengths = torch.tensor(max(seq_lens), dtype=torch.int32)
    if batch_mode:
        # Extra leading dimension for consistency with the batch format.
        vision_cu_lengths = vision_cu_lengths.unsqueeze(0)  # Shape: (1, N + 1)
        vision_max_lengths = vision_max_lengths.unsqueeze(0)  # Shape: (1,)

    images = torch.cat(patch_seqs, dim=0).unsqueeze(0)
    return images, imgs_sizes, vision_cu_lengths, vision_max_lengths
class ImageTransform:
    """Convert one image into a list of normalized tensors for the vision encoder.

    Calling an instance picks exactly one preprocessing mode, checked in this order:

    1. ``use_tiling`` (call argument): split into square tiles with
       ``dynamic_preprocess`` and run each tile through the fixed-size transform.
    2. masked tiling dynamic resolution: same tile grid and cropping, but tiles are
       only converted to RGB tensors and normalized (no further resize).
    3. match tiling dynamic resolution: resize to the tile-grid size but keep the
       result as a single image.
    4. dynamic resolution: resize with ``dynamic_res_preprocess``.
    5. otherwise: a single fixed-size transform of the whole image.
    """

    def __init__(
        self,
        input_size,
        vision_model_type,
        *,
        dynamic_resolution=False,
        res_step=16,
        min_num_patches=1,
        max_num_patches=128,
        pixel_shuffle=False,
        min_side=None,
        conv_merging=False,
        match_tiling_dynamic_resolution=False,
        masked_tiling_dynamic_resolution=False,
        thumbnail_area_threshold=0.8,
    ):
        """Store the configuration and build the fixed-size transform.

        Args:
            input_size: Side length used by the fixed-size transform.
            vision_model_type: Key selecting the transform and pixel statistics.
            dynamic_resolution: Enable mode 4 described on the class.
            res_step, min_num_patches, max_num_patches, pixel_shuffle, min_side,
                conv_merging: Forwarded to ``dynamic_res_preprocess``.
            match_tiling_dynamic_resolution: Enable mode 3.
            masked_tiling_dynamic_resolution: Enable mode 2.
            thumbnail_area_threshold: In mode 4, a thumbnail is added only when the
                resized image area divided by the thumbnail area is below this value.
        """
        self._transform = _build_transform(input_size, vision_model_type)
        self._vision_model_type = vision_model_type
        self._dynamic_resolution = dynamic_resolution
        self._res_step = res_step
        self._min_num_patches = min_num_patches
        self._max_num_patches = max_num_patches
        self._pixel_shuffle = pixel_shuffle
        self._min_side = min_side
        self._conv_merging = conv_merging
        self._match_tiling_dynamic_resolution = match_tiling_dynamic_resolution
        self._masked_tiling_dynamic_resolution = masked_tiling_dynamic_resolution
        self._thumbnail_area_threshold = thumbnail_area_threshold

    def _normalize_only_transform(self):
        """Build the RGB -> tensor -> normalize pipeline (no resizing)."""
        pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.ToTensor(),
            T.Normalize(mean=pixel_mean, std=pixel_std),
        ])

    @staticmethod
    def _choose_tile_grid(img, max_num_tiles, tile_size, find_closest_aspect_ratio_fn):
        """Return the grid chosen by the selector among grids of up to ``max_num_tiles`` tiles."""
        src_width, src_height = img.size
        src_aspect = src_width / src_height
        grid_options = set(
            (i, j)
            for n in range(1, max_num_tiles + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if 1 <= i * j <= max_num_tiles
        )
        # Smallest grids first, as the selector's tie-breaking depends on order.
        grid_options = sorted(grid_options, key=lambda grid: grid[0] * grid[1])
        return find_closest_aspect_ratio_fn(
            src_aspect, grid_options, src_width, src_height, tile_size
        )

    def __call__(
        self,
        img,
        img_h,
        img_w,
        use_tiling=False,
        max_num_tiles=1,
        use_thumbnail=False,
        augment=False,
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
        is_video=False,
    ):
        """Preprocess ``img`` and return a list of tensors.

        Args:
            img: Image object exposing ``size``, ``mode``, ``resize``, ``crop`` and
                ``convert`` (PIL-style).
            img_h: Tile / thumbnail height. The tiling modes require ``img_h == img_w``.
            img_w: Tile width.
            use_tiling: Select mode 1 (takes precedence over the instance flags).
            max_num_tiles: Upper bound on the number of tiles in the tiling modes.
            use_thumbnail: Append an ``img_h x img_h`` resize of the whole image when
                the mode's condition for it holds.
            augment: Must be falsy; augmentation is not implemented.
            find_closest_aspect_ratio_fn: Grid selection function.
            is_video: Forwarded to ``dynamic_res_preprocess`` in mode 4.

        Returns:
            list: The transformed tensors, main image/tiles first, thumbnail last.
        """
        assert not augment, "Image augmentation not implemented."

        if use_tiling:
            assert img_h == img_w, "dynamic tiling expects equal tile height and width"
            tiles = dynamic_preprocess(
                img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail,
                find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn)
            return [self._transform(tile) for tile in tiles]

        if self._masked_tiling_dynamic_resolution:
            assert img_h == img_w, "masked tiling dynamic resolution expects equal tile height and width"
            assert "radio" in self._vision_model_type, "Masked tiling dynamic resolution is only supported for radio models"
            grid = self._choose_tile_grid(img, max_num_tiles, img_h, find_closest_aspect_ratio_fn)

            # Resize to the full grid, then cut it into img_h x img_h tiles row by row.
            canvas_width = img_h * grid[0]
            canvas_height = img_w * grid[1]
            tile_count = grid[0] * grid[1]
            canvas = img.resize((canvas_width, canvas_height))
            pieces = []
            for index in range(tile_count):
                row, col = divmod(index, canvas_width // img_h)
                pieces.append(canvas.crop((
                    col * img_h,
                    row * img_h,
                    (col + 1) * img_h,
                    (row + 1) * img_h,
                )))
            assert len(pieces) == tile_count

            # Optional thumbnail, only when the image was actually split.
            if use_thumbnail and tile_count != 1:
                pieces.append(img.resize((img_h, img_h)))

            to_tensor = self._normalize_only_transform()
            return [to_tensor(piece) for piece in pieces]

        if self._match_tiling_dynamic_resolution:
            assert img_h == img_w, "match tiling dynamic resolution expects equal tile height and width"
            assert "radio" in self._vision_model_type, "Match tiling dynamic resolution is only supported for radio models"
            grid = self._choose_tile_grid(img, max_num_tiles, img_h, find_closest_aspect_ratio_fn)

            # Same target size as tiling, but the resized image is kept whole.
            canvas = img.resize((img_h * grid[0], img_w * grid[1]))
            to_tensor = self._normalize_only_transform()
            outputs = [canvas]
            if use_thumbnail and grid[0] * grid[1] != 1:
                outputs.append(img.resize((img_h, img_h)))
            return [to_tensor(item) for item in outputs]

        if self._dynamic_resolution:
            to_tensor = self._normalize_only_transform()
            resized = dynamic_res_preprocess(
                img,
                min_patches=self._min_num_patches,
                max_patches=self._max_num_patches,
                res_step=self._res_step,
                pixel_shuffle=self._pixel_shuffle,
                min_side=self._min_side,
                conv_merging=self._conv_merging,
                is_video=is_video,
            )
            outputs = [resized]
            if use_thumbnail:
                # Add the square thumbnail only if the resized image is small
                # relative to it (area ratio below the configured threshold).
                resized_width, resized_height = resized.size
                thumbnail_area = img_h * img_h
                if (resized_width * resized_height) / thumbnail_area < self._thumbnail_area_threshold:
                    outputs.append(img.resize((img_h, img_h)))
            return [to_tensor(item) for item in outputs]

        return [self._transform(img)]
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702
# Copyright (c) 2023 OpenGVLab.
def dynamic_preprocess(
        image, min_num=1, max_num=6, image_size=448, use_thumbnail=False,
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio):
    """Resize an image to a tile grid and cut it into square tiles.

    Args:
        image: Image object exposing ``size``, ``resize`` and ``crop`` (PIL-style).
        min_num: Minimum number of tiles a candidate grid may have.
        max_num: Maximum number of tiles a candidate grid may have.
        image_size: Side length of each square tile.
        use_thumbnail: If True and more than one tile was produced, append an
            ``image_size x image_size`` resize of the whole image.
        find_closest_aspect_ratio_fn: Function choosing the grid; called as
            ``fn(aspect_ratio, candidates, width, height, image_size)``.

    Returns:
        list: The tiles in row-major order, optionally followed by the thumbnail.
    """
    src_width, src_height = image.size
    src_aspect = src_width / src_height

    # Every (a, b) grid whose tile count lies within [min_num, max_num],
    # ordered by tile count because the selector's tie-breaking is order-dependent.
    grid_options = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    )
    grid_options = sorted(grid_options, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio_fn(
        src_aspect, grid_options, src_width, src_height, image_size)

    canvas_width = image_size * grid[0]
    canvas_height = image_size * grid[1]
    tile_count = grid[0] * grid[1]

    canvas = image.resize((canvas_width, canvas_height))
    tiles = []
    for index in range(tile_count):
        row, col = divmod(index, canvas_width // image_size)
        tiles.append(canvas.crop((
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
def dynamic_res_preprocess(image, min_patches=1, max_patches=128, res_step=16, factor_max=1., pixel_shuffle=False, min_side=None, conv_merging=False, is_video=False):
    """Preprocess an image with dynamic resolution for vision transformers.

    This function resizes an image to optimize the number of patches while respecting
    constraints on minimum/maximum patches, minimum side length, and compatibility
    with pixel shuffle or convolution merging operations.

    The algorithm works by:
    1. Computing the initial patch grid size based on the image dimensions and res_step
    2. Scaling the patch grid to fit within the max_patches constraint
       (each side is kept at one patch or more)
    3. Ensuring the result has at least min_patches
    4. Optionally enforcing a minimum side length constraint
    5. Rounding patch dimensions to a multiple of 2 (or 4 when both pixel_shuffle and
       conv_merging are enabled)
    6. Resizing the image to the computed target dimensions

    Args:
        image (PIL.Image): Input image to preprocess.
        min_patches (int, optional): Minimum number of patches required. Defaults to 1.
        max_patches (int, optional): Maximum number of patches allowed. Defaults to 128.
        res_step (int, optional): Resolution step size (patch dimension). Defaults to 16.
        factor_max (float, optional): Maximum scaling factor to apply. Defaults to 1.0.
        pixel_shuffle (bool, optional): Whether to ensure compatibility with pixel shuffle
            operations by rounding to even patch dimensions. Defaults to False.
        min_side (int, optional): Minimum side length in pixels. If specified, tries to
            make the shorter side at least this long. Defaults to None.
        conv_merging (bool, optional): Whether to ensure compatibility with convolution
            merging by rounding to even patch dimensions. Defaults to False.
        is_video (bool, optional): If True, the computed grid is discarded and a fixed
            32 x 32 patch grid is used instead. Defaults to False.

    Returns:
        PIL.Image: Resized image with dimensions
        (target_patch_width * res_step, target_patch_height * res_step).

    Note:
        When min_side conflicts with max_patches, the shorter side is brought to
        min_side and the longer side is cut down to fit the budget if that still
        leaves the longer side at least min_side long; otherwise the grid is scaled
        at its native aspect ratio to be as large as max_patches allows. Both
        orientations (portrait and landscape) follow the same rule.

    Example:
        >>> from PIL import Image
        >>> img = Image.open("example.jpg")  # 800x600 image
        >>> resized_img = dynamic_res_preprocess(img, min_patches=4, max_patches=64, res_step=14)
        >>> resized_img.size
        (126, 84)
    """
    orig_width, orig_height = image.size

    closest_patch_height = round(orig_height / res_step + 0.5)
    closest_patch_width = round(orig_width / res_step + 0.5)
    patches = closest_patch_height * closest_patch_width

    factor = min(math.sqrt(max_patches / patches), factor_max)
    # Flooring can collapse the short side of a very elongated image to 0 patches,
    # which would divide by zero below; keep every side at one patch or more.
    target_patch_height = max(1, math.floor(factor * closest_patch_height))
    target_patch_width = max(1, math.floor(factor * closest_patch_width))

    # The budget can only be exceeded here when a side was clamped up to 1,
    # so trim the longer side back into the budget.
    if target_patch_height * target_patch_width > max_patches:
        if target_patch_height >= target_patch_width:
            target_patch_height = max(1, max_patches // target_patch_width)
        else:
            target_patch_width = max(1, max_patches // target_patch_height)

    if target_patch_height * target_patch_width < min_patches:
        up_factor = math.sqrt(min_patches / (target_patch_height * target_patch_width))
        target_patch_height = math.ceil(up_factor * target_patch_height)
        target_patch_width = math.ceil(up_factor * target_patch_width)

    if min_side is not None and min(target_patch_width, target_patch_height) * res_step < min_side:
        if target_patch_width <= target_patch_height:
            # Width is the short side: scale so that it reaches min_side.
            up_factor = min_side / (target_patch_width * res_step)
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)

            if new_patch_height * new_patch_width > max_patches:
                # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
                if max(max_patches // new_patch_width, 1) * res_step < min_side:
                    up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
                    target_patch_height = math.floor(up_factor * target_patch_height)
                    target_patch_width = math.floor(up_factor * target_patch_width)
                else:
                    # Mirrors the height branch below; without this `else` the
                    # native-aspect result above was always overwritten.
                    target_patch_width = new_patch_width
                    target_patch_height = max(max_patches // new_patch_width, 1)
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width
        else:
            # Height is the short side: scale so that it reaches min_side.
            up_factor = min_side / (target_patch_height * res_step)
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)

            if new_patch_height * new_patch_width > max_patches:
                # If only one side can be min_side, make as big as possible at native aspect ratio while staying below max_patches
                if max(max_patches // new_patch_height, 1) * res_step < min_side:
                    up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
                    target_patch_height = math.floor(up_factor * target_patch_height)
                    target_patch_width = math.floor(up_factor * target_patch_width)
                else:
                    target_patch_height = new_patch_height
                    target_patch_width = max(max_patches // new_patch_height, 1)
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width

    # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
    # or by 4 when BOTH are enabled (two successive 2x reductions)
    if pixel_shuffle or conv_merging:
        required_divisor = 4 if (pixel_shuffle and conv_merging) else 2

        # Round up when the budget allows it, otherwise round down (never below 1).
        rem_h = target_patch_height % required_divisor
        if rem_h != 0:
            inc_h = required_divisor - rem_h
            if (target_patch_height + inc_h) * target_patch_width <= max_patches:
                target_patch_height += inc_h
            else:
                target_patch_height = max(1, target_patch_height - rem_h)

        rem_w = target_patch_width % required_divisor
        if rem_w != 0:
            inc_w = required_divisor - rem_w
            if target_patch_height * (target_patch_width + inc_w) <= max_patches:
                target_patch_width += inc_w
            else:
                target_patch_width = max(1, target_patch_width - rem_w)

        assert target_patch_height * target_patch_width <= max_patches

    # TEMP: hacky way to process video same as in training
    if is_video:
        # max_patches = 1024
        # min_patches = 512
        target_patch_width = 32
        target_patch_height = 32

    # resize the image
    resized_img = image.resize((target_patch_width * res_step, target_patch_height * res_step))

    return resized_img
# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
def _build_transform(input_size, vision_model_type):
    """Build the fixed-size preprocessing callable for a vision model type.

    Args:
        input_size: Target side length; images are resized to
            ``input_size x input_size``.
        vision_model_type: One of the built-in type names, ``"clip"``, or an
            ``"hf://..."`` identifier whose Hugging Face model type contains "siglip".

    Returns:
        A callable mapping an image to a normalized tensor.

    Raises:
        NotImplementedError: If no preprocessing is defined for the model type.
    """
    def to_rgb(image):
        return image if image.mode == 'RGB' else image.convert('RGB')

    if vision_model_type in ("siglip", "internvit", "internvit300M", "radio", "radio-g", "cradio-g"):
        mean, std = pixel_statistics[vision_model_type]
        # RGB conversion happens before the resize for these models.
        return T.Compose([
            T.Lambda(to_rgb),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    if vision_model_type == "clip":
        mean, std = pixel_statistics[vision_model_type]
        # For CLIP the resize comes first, then the RGB conversion.
        return Compose([
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.Lambda(to_rgb),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    if vision_model_type.startswith("hf://"):
        from megatron.core.models.huggingface.module import get_hf_model_type

        hf_model_type = get_hf_model_type(vision_model_type)
        if "siglip" not in hf_model_type:
            raise NotImplementedError(f"image processing not defined for huggingface model {vision_model_type}")

        from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor

        processor = SiglipImageProcessor(size={"height": input_size, "width": input_size})

        def transform(x):
            x = x.convert("RGB") if x.mode != "RGB" else x
            x = processor(x, return_tensors="pt")
            return x["pixel_values"][0]

        return transform

    raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}")