mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 19:25:52 +08:00
509 lines
25 KiB
Python
509 lines
25 KiB
Python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
|
|
import math
|
|
|
|
import torch
|
|
from einops import rearrange
|
|
from torchvision import transforms as T
|
|
from torchvision.transforms import Compose
|
|
from torchvision.transforms.functional import InterpolationMode
|
|
|
|
|
|
# Per-channel RGB normalization statistics for the supported vision encoders.
# Each (mean, std) pair matches what the corresponding backbone was
# pretrained with.
IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
RADIO_G_PIXEL_MEAN = [0.4850, 0.4560, 0.4060]
RADIO_G_PIXEL_STD = [0.2230, 0.2240, 0.2250]

# Maps a vision model type name to the (mean, std) pair used to normalize
# its input pixels. Note that "radio" and "cradio-g" reuse the CLIP
# statistics, while "radio-g" has its own set.
pixel_statistics = {
    "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "radio-g": (RADIO_G_PIXEL_MEAN, RADIO_G_PIXEL_STD),
    "cradio-g": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "internvit300M": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
}
|
|
|
|
|
|
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
|
|
# Copyright (c) 2023 OpenGVLab.
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tile grid from ``target_ratios`` whose aspect ratio is closest.

    Args:
        aspect_ratio: Aspect ratio (width / height) of the source image.
        target_ratios: Iterable of (columns, rows) candidate grids.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The (columns, rows) grid minimizing the aspect-ratio difference. On a
        tie, the later candidate wins only when the image covers more than
        half of that grid's total tile area.
    """
    img_area = width * height
    best = (1, 1)
    best_diff = float('inf')
    for cols, rows in target_ratios:
        diff = abs(aspect_ratio - cols / rows)
        if diff < best_diff:
            best_diff, best = diff, (cols, rows)
        elif diff == best_diff and img_area > 0.5 * image_size * image_size * cols * rows:
            # Tie-break: prefer the larger grid when the image is big enough
            # to benefit from more tiles.
            best = (cols, rows)
    return best
|
|
|
|
|
|
def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tile grid that best trades off covered area and aspect-ratio fit.

    Each candidate grid is scored by the product of:
      * the fraction of the image area covered by the grid, capped at 0.6 so
        that very large grids stop gaining from area alone, and
      * the symmetric ratio between grid aspect and image aspect (1.0 = exact
        match, smaller = worse fit).

    Args:
        aspect_ratio: Aspect ratio (width / height) of the source image.
        target_ratios: Iterable of (columns, rows) candidate grids.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The first (columns, rows) grid achieving the highest score, or (1, 1)
        when ``target_ratios`` is empty.
    """
    img_area = width * height

    def score(grid):
        cols, rows = grid
        grid_ratio = cols / rows
        coverage = min(cols * rows * image_size * image_size / img_area, 0.6)
        ratio_match = min(grid_ratio / aspect_ratio, aspect_ratio / grid_ratio)
        return coverage * ratio_match

    best_ratio = (1, 1)
    best_score = float('-inf')
    for grid in target_ratios:
        grid_score = score(grid)
        if grid_score > best_score:
            best_score, best_ratio = grid_score, grid
    return best_ratio
|
|
|
|
|
|
def process_images(sample_imgs, patch_dim, dynamic_resolution, batch_mode=False):
    """Process one microbatch of images for multimodal training or evaluation.

    For static resolution the images are simply stacked. For dynamic
    resolution each image is flattened into a sequence of patch features and
    cumulative sequence lengths are computed so variable-size images can be
    packed into one tensor.

    Args:
        sample_imgs (List[torch.Tensor]): Image tensors of shape (C, H, W).
        patch_dim (int): Patch side length (e.g. 14 for 14x14 patches).
        dynamic_resolution (bool): If True, rearrange images into patch
            sequences with variable lengths; if False, stack them.
        batch_mode (bool, optional): Training-batch calling convention. Adds a
            leading dimension to the length tensors and substitutes a dummy
            image tensor when there are no images. Defaults to False.

    Returns:
        tuple: (images, imgs_sizes, vision_cu_lengths, vision_max_lengths)
            - images: (1, total_patches, patch_features) for dynamic
              resolution, or (batch, C, H, W) for static resolution.
            - imgs_sizes: int32 tensor of [height, width] per image
              (as read from img.shape[1], img.shape[2]), or [[0, 0]]
              when there are no images.
            - vision_cu_lengths: int32 cumulative patch counts,
              (batch + 1,) in eval mode or (1, batch + 1) in batch mode;
              None for static resolution.
            - vision_max_lengths: int32 scalar max patch count, shape (1,)
              in batch mode; None for static resolution.
    """
    vision_cu_lengths = None
    vision_max_lengths = None

    if not sample_imgs:
        imgs_sizes = torch.tensor([[0, 0]], dtype=torch.int32)
        if batch_mode:
            # Dummy placeholder so the batch structure stays uniform.
            images = torch.tensor([[0]], dtype=torch.float32)
        else:
            # Preserves historical behavior: stacking an empty list raises.
            images = torch.stack(sample_imgs)
        return images, imgs_sizes, vision_cu_lengths, vision_max_lengths

    imgs_sizes = torch.tensor(
        [[img.shape[1], img.shape[2]] for img in sample_imgs], dtype=torch.int32)

    if not dynamic_resolution:
        images = torch.stack(sample_imgs)
        return images, imgs_sizes, vision_cu_lengths, vision_max_lengths

    def to_patch_sequence(x):
        # Flatten (C, H, W) into (num_patches, patch_features), row-major
        # over the patch grid.
        rows = x.shape[-2] // patch_dim
        cols = x.shape[-1] // patch_dim
        return rearrange(x, 'c (py yy) (px xx) -> (py px) (c yy xx)',
                         py=rows, yy=patch_dim,
                         px=cols, xx=patch_dim,
                         )

    patch_seqs = [to_patch_sequence(img) for img in sample_imgs]

    # Cumulative and maximum patch-sequence lengths for packed attention.
    lengths = [seq.shape[0] for seq in patch_seqs]
    cu_lengths = [0]
    for length in lengths:
        cu_lengths.append(cu_lengths[-1] + length)

    vision_cu_lengths = torch.tensor(cu_lengths, dtype=torch.int32)
    vision_max_lengths = torch.tensor(max(lengths), dtype=torch.int32)

    if batch_mode:
        # Wrap in an extra dimension for consistency with batch format.
        vision_cu_lengths = vision_cu_lengths.unsqueeze(0)   # (1, batch + 1)
        vision_max_lengths = vision_max_lengths.unsqueeze(0)  # (1,)

    images = torch.cat(patch_seqs, dim=0).unsqueeze(0)
    return images, imgs_sizes, vision_cu_lengths, vision_max_lengths
|
|
|
|
|
|
class ImageTransform:
    """Image transformation.

    Converts a PIL image into a list of normalized tensors using one of four
    strategies selected by constructor flags and call arguments: InternVL-style
    tiling, masked tiling (radio models), match tiling (radio models, resized
    but not split), dynamic resolution, or a plain fixed-size resize.
    """

    def __init__(self, input_size, vision_model_type, *, dynamic_resolution=False, res_step=16, min_num_patches=1, max_num_patches=128, pixel_shuffle=False, min_side=None, conv_merging=False, match_tiling_dynamic_resolution=False, masked_tiling_dynamic_resolution=False, thumbnail_area_threshold=0.8):
        """Configure the transform.

        Args:
            input_size: Side length used by the fixed-size fallback transform.
            vision_model_type: Key into ``pixel_statistics`` selecting the
                normalization mean/std (e.g. "clip", "siglip", "radio").
            dynamic_resolution: Use dynamic-resolution preprocessing instead
                of the fixed-size resize.
            res_step: Patch side length in pixels for dynamic resolution.
            min_num_patches: Lower bound on the dynamic-resolution patch count.
            max_num_patches: Upper bound on the dynamic-resolution patch count.
            pixel_shuffle: Round the dynamic patch grid for pixel-shuffle
                compatibility.
            min_side: Optional minimum pixel side length for dynamic resolution.
            conv_merging: Round the dynamic patch grid for conv-merging
                compatibility.
            match_tiling_dynamic_resolution: Resize to the tiling-derived size
                without splitting into tiles (radio models only).
            masked_tiling_dynamic_resolution: Tile like InternVL tiling
                (radio models only).
            thumbnail_area_threshold: In dynamic-resolution mode, append a
                thumbnail only when resized_area / thumbnail_area is below
                this fraction.
        """
        self._transform = _build_transform(input_size, vision_model_type)
        self._vision_model_type = vision_model_type
        self._dynamic_resolution = dynamic_resolution
        self._res_step = res_step
        self._min_num_patches = min_num_patches
        self._max_num_patches = max_num_patches
        self._pixel_shuffle = pixel_shuffle
        self._min_side = min_side
        self._conv_merging = conv_merging
        self._match_tiling_dynamic_resolution = match_tiling_dynamic_resolution
        self._masked_tiling_dynamic_resolution = masked_tiling_dynamic_resolution
        self._thumbnail_area_threshold = thumbnail_area_threshold

    def __call__(self, img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, find_closest_aspect_ratio_fn=find_closest_aspect_ratio, is_video=False):
        """Transform one PIL image into a list of normalized tensors.

        Args:
            img: Input PIL image.
            img_h: Target tile height in pixels (tiling modes require
                img_h == img_w).
            img_w: Target tile width in pixels.
            use_tiling: Split the image into tiles via ``dynamic_preprocess``.
            max_num_tiles: Upper bound on the tile count for tiling modes.
            use_thumbnail: Append a square (img_h x img_h) thumbnail when
                there is more than one tile (or, for dynamic resolution,
                when the area threshold test passes).
            augment: Unsupported; must be False.
            find_closest_aspect_ratio_fn: Strategy for picking the tile grid.
            is_video: Forwarded to ``dynamic_res_preprocess`` (video hack).

        Returns:
            list of transformed image tensors, one per tile/thumbnail.
        """
        assert not augment, "Image augmentation not implemented."
        if use_tiling:
            # InternVL-style tiling: split into square tiles, then apply the
            # fixed transform to each tile.
            assert img_h == img_w, "dynamic tiling expects equal tile height and width"
            imgs = dynamic_preprocess(
                img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail,
                find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn)
            imgs = [self._transform(img) for img in imgs]
        elif self._masked_tiling_dynamic_resolution:
            assert img_h == img_w, "masked tiling dynamic resolution expects equal tile height and width"
            assert "radio" in self._vision_model_type, "Masked tiling dynamic resolution is only supported for radio models"

            # Use tiling logic to determine tile grid (nx, ny)
            orig_width, orig_height = img.size
            aspect_ratio = orig_width / orig_height

            # All (cols, rows) grids with at most max_num_tiles tiles.
            target_ratios = set(
                (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1)
                if i * j <= max_num_tiles and i * j >= 1
            )
            target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

            tiling = find_closest_aspect_ratio_fn(
                aspect_ratio, target_ratios, orig_width, orig_height, img_h
            )

            # Resize and split into tiles of size (img_h x img_h)
            target_width = img_h * tiling[0]
            # NOTE: uses img_w here, but img_h == img_w is asserted above.
            target_height = img_w * tiling[1]
            blocks = tiling[0] * tiling[1]

            resized_img = img.resize((target_width, target_height))
            processed_images = []
            for i in range(blocks):
                # Row-major crop box for tile i of the grid.
                box = (
                    (i % (target_width // img_h)) * img_h,
                    (i // (target_width // img_h)) * img_h,
                    ((i % (target_width // img_h)) + 1) * img_h,
                    ((i // (target_width // img_h)) + 1) * img_h,
                )
                tile_img = resized_img.crop(box)
                processed_images.append(tile_img)
            assert len(processed_images) == blocks

            # Optional thumbnail
            if use_thumbnail and blocks != 1:
                thumbnail_img = img.resize((img_h, img_h))
                processed_images.append(thumbnail_img)

            # Normalize tiles without resizing (tiles are already img_h x img_h).
            pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ])
            imgs = [transform(im) for im in processed_images]
        elif self._match_tiling_dynamic_resolution:
            assert img_h == img_w, "match tiling dynamic resolution expects equal tile height and width"
            assert "radio" in self._vision_model_type, "Match tiling dynamic resolution is only supported for radio models"

            # Use tiling logic to determine optimal dimensions
            orig_width, orig_height = img.size
            aspect_ratio = orig_width / orig_height

            # Calculate target ratios (same logic as tiling)
            target_ratios = set(
                (i, j) for n in range(1, max_num_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
                i * j <= max_num_tiles and i * j >= 1)
            target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

            # Find the closest aspect ratio to the target
            target_aspect_ratio = find_closest_aspect_ratio_fn(
                aspect_ratio, target_ratios, orig_width, orig_height, img_h)

            # Calculate the target width and height using tiling logic
            target_width = img_h * target_aspect_ratio[0]
            # NOTE: uses img_w here, but img_h == img_w is asserted above.
            target_height = img_w * target_aspect_ratio[1]

            # Resize image to target dimensions (same as tiling, but don't split)
            resized_img = img.resize((target_width, target_height))

            # Process as single dynamic resolution image
            pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ])
            processed_images = [resized_img]

            # Add thumbnail if use_thumbnail=True and there's more than 1 tile
            blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
            if use_thumbnail and blocks != 1:
                thumbnail_img = img.resize((img_h, img_h))
                processed_images.append(thumbnail_img)

            imgs = [transform(img) for img in processed_images]
        elif self._dynamic_resolution:
            pixel_mean, pixel_std = pixel_statistics[self._vision_model_type]
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.ToTensor(),
                T.Normalize(mean=pixel_mean, std=pixel_std),
            ])
            # Resize to a patch-grid-friendly resolution (no splitting).
            processed_img = dynamic_res_preprocess(img, min_patches=self._min_num_patches, max_patches=self._max_num_patches, res_step=self._res_step, pixel_shuffle=self._pixel_shuffle, min_side=self._min_side, conv_merging=self._conv_merging, is_video=is_video)
            processed_images = [processed_img]

            # Add thumbnail if enabled and image area is below threshold
            if use_thumbnail:
                # Calculate areas
                processed_width, processed_height = processed_img.size
                resized_area = processed_width * processed_height
                thumbnail_area = img_h * img_h  # img_h should be square thumbnail size
                area_ratio = resized_area / thumbnail_area

                # Only add thumbnail if resized image area is less than threshold % of thumbnail area
                if area_ratio < self._thumbnail_area_threshold:
                    thumbnail_img = img.resize((img_h, img_h))  # Use square thumbnail with img_h size
                    processed_images.append(thumbnail_img)

            imgs = [transform(img) for img in processed_images]
        else:
            # Fixed-size path: single resize + normalize built in __init__.
            imgs = [self._transform(img)]

        return imgs
|
|
|
|
|
|
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702
|
|
# Copyright (c) 2023 OpenGVLab.
|
|
def dynamic_preprocess(
        image, min_num=1, max_num=6, image_size=448, use_thumbnail=False,
        find_closest_aspect_ratio_fn=find_closest_aspect_ratio):
    """Split a PIL image into a grid of square tiles (InternVL-style tiling).

    Args:
        image: Input PIL image.
        min_num: Minimum allowed number of tiles.
        max_num: Maximum allowed number of tiles.
        image_size: Side length of each square tile in pixels.
        use_thumbnail: Append a square thumbnail when more than one tile
            is produced.
        find_closest_aspect_ratio_fn: Strategy for picking the tile grid.

    Returns:
        list of PIL images: the tiles in row-major order, plus an optional
        trailing thumbnail.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate every (columns, rows) grid whose tile count is within bounds,
    # sorted by total tile count.
    candidate_grids = set(
        (cols, rows)
        for total in range(min_num, max_num + 1)
        for cols in range(1, total + 1)
        for rows in range(1, total + 1)
        if cols * rows <= max_num and cols * rows >= min_num
    )
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    # Pick the grid that best matches the image's aspect ratio.
    grid = find_closest_aspect_ratio_fn(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size)

    cols, rows = grid[0], grid[1]
    target_width = image_size * cols
    target_height = image_size * rows
    num_tiles = cols * rows

    # Resize to the exact grid size, then crop out each tile row-major.
    resized = image.resize((target_width, target_height))
    tiles = []
    tiles_per_row = target_width // image_size
    for idx in range(num_tiles):
        col = idx % tiles_per_row
        row = idx // tiles_per_row
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size
        )
        tiles.append(resized.crop(box))
    assert len(tiles) == num_tiles
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
|
|
|
|
|
|
def dynamic_res_preprocess(image, min_patches=1, max_patches=128, res_step=16, factor_max=1., pixel_shuffle=False, min_side=None, conv_merging=False, is_video=False):
    """Preprocess an image with dynamic resolution for vision transformers.

    Resizes the image so its dimensions are integer multiples of ``res_step``
    (the patch size), maximizing area while keeping the patch count within
    [min_patches, max_patches], optionally enforcing a minimum pixel side
    length and patch-grid divisibility for pixel shuffle / conv merging.

    Algorithm:
        1. Compute the covering patch grid: ceil(side / res_step) per side.
        2. Scale the grid down (never above ``factor_max``) to fit max_patches.
        3. Scale back up if the grid fell below min_patches.
        4. Optionally enforce ``min_side`` on the shorter side, clamping the
           longer side to respect max_patches; if that is impossible, fall
           back to the largest grid at the native aspect ratio.
        5. Round grid sides to the divisor required by pixel_shuffle and/or
           conv_merging.
        6. Resize the image to (grid_width * res_step, grid_height * res_step).

    Args:
        image (PIL.Image): Input image to preprocess.
        min_patches (int, optional): Minimum number of patches. Defaults to 1.
        max_patches (int, optional): Maximum number of patches. Defaults to 128.
        res_step (int, optional): Patch side length in pixels. Defaults to 16.
        factor_max (float, optional): Maximum initial scaling factor. Defaults to 1.0.
        pixel_shuffle (bool, optional): Require even (or, with conv_merging,
            divisible-by-4) patch grid sides. Defaults to False.
        min_side (int, optional): Minimum side length in pixels. Defaults to None.
        conv_merging (bool, optional): Require even (or, with pixel_shuffle,
            divisible-by-4) patch grid sides. Defaults to False.
        is_video (bool, optional): Force a fixed 32x32 patch grid (temporary
            hack matching training-time video processing). Defaults to False.

    Returns:
        PIL.Image: Resized image of size
        (target_patch_width * res_step, target_patch_height * res_step).
    """
    orig_width, orig_height = image.size

    # Patch grid that (over-)covers the original image: effectively
    # ceil(side / res_step) for non-multiples of res_step.
    closest_patch_height = round(orig_height / res_step + 0.5)
    closest_patch_width = round(orig_width / res_step + 0.5)
    patches = closest_patch_height * closest_patch_width

    # Scale down (capped at factor_max) so the grid fits within max_patches.
    factor = min(math.sqrt(max_patches / patches), factor_max)
    target_patch_height = math.floor(factor * closest_patch_height)
    target_patch_width = math.floor(factor * closest_patch_width)

    # Scale back up if we fell below min_patches.
    if target_patch_height * target_patch_width < min_patches:
        up_factor = math.sqrt(min_patches / (target_patch_height * target_patch_width))
        target_patch_height = math.ceil(up_factor * target_patch_height)
        target_patch_width = math.ceil(up_factor * target_patch_width)

    # Optionally enforce a minimum pixel length on the shorter side.
    if min_side is not None and min(target_patch_width, target_patch_height) * res_step < min_side:
        if target_patch_width <= target_patch_height:
            # Width is the short side: scale so the width reaches min_side.
            up_factor = min_side / (target_patch_width * res_step)
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)

            if new_patch_height * new_patch_width > max_patches:
                if max(max_patches // new_patch_width, 1) * res_step < min_side:
                    # Even clamping the height cannot keep min_side within
                    # max_patches: make the image as big as possible at the
                    # native aspect ratio while staying below max_patches.
                    # (Bug fix: previously these scaled values were
                    # unconditionally overwritten by the clamped assignments
                    # below, leaving this computation dead.)
                    up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
                    target_patch_height = math.floor(up_factor * target_patch_height)
                    target_patch_width = math.floor(up_factor * target_patch_width)
                else:
                    # Keep the min_side width; clamp the height to fit.
                    target_patch_width = new_patch_width
                    target_patch_height = max(max_patches // new_patch_width, 1)
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width
        else:
            # Height is the short side: scale so the height reaches min_side.
            up_factor = min_side / (target_patch_height * res_step)
            new_patch_height = math.ceil(up_factor * target_patch_height)
            new_patch_width = math.ceil(up_factor * target_patch_width)

            if new_patch_height * new_patch_width > max_patches:
                if max(max_patches // new_patch_height, 1) * res_step < min_side:
                    # If only one side can be min_side, make as big as possible
                    # at native aspect ratio while staying below max_patches.
                    up_factor = math.sqrt(max_patches / (target_patch_height * target_patch_width))
                    target_patch_height = math.floor(up_factor * target_patch_height)
                    target_patch_width = math.floor(up_factor * target_patch_width)
                else:
                    # Keep the min_side height; clamp the width to fit.
                    target_patch_height = new_patch_height
                    target_patch_width = max(max_patches // new_patch_height, 1)
            else:
                target_patch_height = new_patch_height
                target_patch_width = new_patch_width

    # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
    # or by 4 when BOTH are enabled (two successive 2x reductions).
    if pixel_shuffle or conv_merging:
        required_divisor = 4 if (pixel_shuffle and conv_merging) else 2

        rem_h = target_patch_height % required_divisor
        if rem_h != 0:
            inc_h = required_divisor - rem_h
            # Prefer rounding up; round down when that would exceed max_patches.
            if (target_patch_height + inc_h) * target_patch_width <= max_patches:
                target_patch_height += inc_h
            else:
                target_patch_height = max(1, target_patch_height - rem_h)

        rem_w = target_patch_width % required_divisor
        if rem_w != 0:
            inc_w = required_divisor - rem_w
            if target_patch_height * (target_patch_width + inc_w) <= max_patches:
                target_patch_width += inc_w
            else:
                target_patch_width = max(1, target_patch_width - rem_w)
        assert target_patch_height * target_patch_width <= max_patches

    # TEMP: hacky way to process video same as in training
    if is_video:
        target_patch_width = 32
        target_patch_height = 32

    # Resize the image to the pixel dimensions implied by the patch grid.
    resized_img = image.resize((target_patch_width * res_step, target_patch_height * res_step))

    return resized_img
|
|
|
|
|
|
|
|
# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
|
|
# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
|
|
def _build_transform(input_size, vision_model_type):
    """Return a callable mapping a PIL image to a normalized image tensor.

    Args:
        input_size: Target square side length in pixels for the resize.
        vision_model_type: One of the keys in ``pixel_statistics``, or an
            "hf://..." identifier for a huggingface-hosted model.

    Returns:
        A callable (torchvision transform or plain function) producing a
        normalized tensor of shape (C, input_size, input_size).

    Raises:
        NotImplementedError: For unknown vision model types or unsupported
            huggingface model families.
    """
    if vision_model_type in ("siglip", "internvit", "internvit300M", "radio", "radio-g", "cradio-g"):
        pixel_mean, pixel_std = pixel_statistics[vision_model_type]

        # RGB conversion first, then bicubic resize, then normalize.
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=pixel_mean, std=pixel_std)
        ])
    elif vision_model_type == "clip":
        pixel_mean, pixel_std = pixel_statistics[vision_model_type]

        # NOTE: the CLIP branch resizes BEFORE converting to RGB, matching the
        # ordering of the CLIP reference implementation.
        transform = Compose([
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.ToTensor(),
            T.Normalize(mean=pixel_mean, std=pixel_std),
        ])
    elif vision_model_type.startswith("hf://"):
        # Deferred imports: megatron/transformers are only needed on this path.
        from megatron.core.models.huggingface.module import get_hf_model_type

        model_type = get_hf_model_type(vision_model_type)
        if "siglip" in model_type:
            from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor

            processor = SiglipImageProcessor(size={"height": input_size, "width": input_size})

            def transform(x):
                # Processor returns a batch dict; take the single image tensor.
                x = x.convert("RGB") if x.mode != "RGB" else x
                x = processor(x, return_tensors="pt")
                return x["pixel_values"][0]
        else:
            raise NotImplementedError(f"image processing not defined for huggingface model {vision_model_type}")
    else:
        raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}")

    return transform
|