mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 12:19:49 +08:00
[Bugfix] Fix Nemotron VL image processing (#22739)
Co-authored-by: ducviet00-h2 <viet.d.hoang@h2corporation.jp>
This commit is contained in:
parent
9e7e5baaa8
commit
a01e0018b5
@ -23,15 +23,15 @@ def _get_expected_num_patches(
|
|||||||
min_num: int,
|
min_num: int,
|
||||||
max_num: int,
|
max_num: int,
|
||||||
):
|
):
|
||||||
from vllm.model_executor.models.internvl import (
|
from vllm.model_executor.models.nemotron_vl import (
|
||||||
calculate_internvl_targets, get_internvl_target_ratios)
|
calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios)
|
||||||
|
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
|
|
||||||
blocks, _, _ = calculate_internvl_targets(
|
blocks, _, _ = calculate_nemotron_vl_targets(
|
||||||
orig_width=width,
|
orig_width=width,
|
||||||
orig_height=height,
|
orig_height=height,
|
||||||
target_ratios=get_internvl_target_ratios(
|
target_ratios=get_nemotron_vl_target_ratios(
|
||||||
min_num,
|
min_num,
|
||||||
max_num,
|
max_num,
|
||||||
),
|
),
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torchvision.transforms as T
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import AutoModel, PretrainedConfig
|
from transformers import AutoModel, PretrainedConfig
|
||||||
from transformers.image_processing_utils_fast import BaseImageProcessorFast
|
from transformers.image_processing_utils_fast import BaseImageProcessorFast
|
||||||
@ -27,6 +28,7 @@ from vllm.model_executor.models.internvl import (
|
|||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.image import convert_image_mode
|
||||||
from vllm.multimodal.inputs import NestedTensors
|
from vllm.multimodal.inputs import NestedTensors
|
||||||
from vllm.multimodal.processing import PromptUpdateDetails
|
from vllm.multimodal.processing import PromptUpdateDetails
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -44,6 +46,146 @@ IMG_END = '</img>'
|
|||||||
IMG_CONTEXT = '<image>'
|
IMG_CONTEXT = '<image>'
|
||||||
|
|
||||||
|
|
||||||
|
def build_transform(input_size: int):
    """Build the preprocessing pipeline applied to every image tile.

    The composed transform runs three steps in order:

    1. ``convert_image_mode(img, 'RGB')`` on the incoming image,
    2. a resize to ``input_size`` x ``input_size`` using bicubic
       interpolation,
    3. ``ToTensor``.

    No normalization step is part of this pipeline.

    Args:
        input_size: Side length, in pixels, of the square output.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    return T.Compose([
        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
    ])
|
||||||
|
|
||||||
|
|
||||||
|
# adapted from https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1
|
||||||
|
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the tiling grid that best fits an image.

    Every candidate ``(rw, rh)`` is scored as the product of two terms:

    * a size factor: the area of the tiled canvas
      (``rw * rh * image_size ** 2``) divided by the image area, capped at
      ``0.6``;
    * a ratio closeness: the smaller of ``target / actual`` and
      ``actual / target`` aspect ratio, which is ``1.0`` for an exact match.

    The candidate with the highest score wins. The comparison is strict, so
    on a tie the candidate that appears first in ``target_ratios`` is kept.

    Args:
        aspect_ratio: Width divided by height of the source image.
        target_ratios: Candidate grids as ``(columns, rows)`` pairs.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of a single square tile.

    Returns:
        The best-scoring ``(columns, rows)`` pair, or ``(1, 1)`` when
        ``target_ratios`` is empty.
    """
    best_factor = float('-inf')
    best_ratio = (1, 1)
    area = width * height
    # Loop-invariant: area of one tile.
    tile_area = image_size * image_size

    for rw, rh in target_ratios:
        target_aspect_ratio = rw / rh
        # Cap the coverage term so that grids much larger than the image do
        # not keep gaining score from size alone.
        size_factor = min((rw * rh * tile_area) / area, 0.6)
        ratio_closeness = min(target_aspect_ratio / aspect_ratio,
                              aspect_ratio / target_aspect_ratio)
        factor = size_factor * ratio_closeness

        if factor > best_factor:
            best_factor = factor
            best_ratio = (rw, rh)

    return best_ratio
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_nemotron_vl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Compute the tile count and resize target for an image.

    The grid is chosen by ``find_closest_aspect_ratio``; the resize target is
    that grid scaled by ``image_size``.

    Args:
        orig_width: Source image width in pixels.
        orig_height: Source image height in pixels.
        target_ratios: Candidate grids as ``(columns, rows)`` pairs.
        image_size: Side length of a single square tile.
        use_thumbnail: When true and the grid has more than one tile, one
            extra block is counted for the thumbnail.

    Returns:
        ``(blocks, target_width, target_height)``.
    """
    aspect_ratio = orig_width / orig_height

    # Find the grid whose shape is closest to the image.
    cols, rows = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # Derive the resize target and the number of tiles from the grid.
    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    # A thumbnail is only added when the image is split into several tiles.
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_preprocess_nemotron_vl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Resize an image to its target grid and cut it into square tiles.

    The image is resized to the dimensions returned by
    ``calculate_nemotron_vl_targets`` (no resampling filter is passed to
    ``resize``), then cropped into ``image_size`` x ``image_size`` tiles in
    row-major order. When ``use_thumbnail`` is true and more than one tile
    was produced, the whole image resized to a single tile is appended last.

    Args:
        image: Source image.
        target_ratios: Candidate grids as ``(columns, rows)`` pairs.
        image_size: Side length of a single square tile.
        use_thumbnail: Whether to append a thumbnail for multi-tile images.

    Returns:
        The tiles, followed by the thumbnail if one was added.
    """
    orig_width, orig_height = image.size

    # Tile count here excludes the thumbnail; it is appended separately below.
    blocks, target_width, target_height = calculate_nemotron_vl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    resized_img = image.resize((target_width, target_height))

    # Number of tile columns; invariant across the tiling loop.
    cols = target_width // image_size

    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, cols)
        left = col * image_size
        top = row * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size,
                              top + image_size)))

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))

    return processed_images
|
||||||
|
|
||||||
|
|
||||||
|
def get_nemotron_vl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate every grid whose tile count lies in ``[min_num, max_num]``.

    Args:
        min_num: Smallest allowed number of tiles.
        max_num: Largest allowed number of tiles.

    Returns:
        Unique ``(i, j)`` pairs with ``min_num <= i * j <= max_num``, sorted
        by ascending ``i * j``. The sort key is the tile count only, so pairs
        with the same count keep the iteration order of the intermediate set.
        The list is empty when ``min_num > max_num``.
    """
    # A set removes the duplicates produced by overlapping values of ``n``.
    target_ratios = {(i, j)
                     for n in range(min_num, max_num + 1)
                     for i in range(1, n + 1)
                     for j in range(1, n + 1) if min_num <= i * j <= max_num}
    return sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_pixel_values_nemotron_vl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Tile an image and convert the tiles into one stacked tensor.

    Args:
        image: Source image.
        input_size: Side length of a single square tile.
        min_num: Smallest allowed number of tiles.
        max_num: Largest allowed number of tiles.
        use_thumbnail: Whether a thumbnail is appended for multi-tile images.

    Returns:
        The result of ``torch.stack`` over the transformed tiles, so the
        first dimension indexes the tiles (thumbnail last, if present).
    """
    target_ratios = get_nemotron_vl_target_ratios(min_num, max_num)

    transform = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_nemotron_vl(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    pixel_values = torch.stack([transform(tile) for tile in tiles])
    return pixel_values
|
||||||
|
|
||||||
|
|
||||||
class NemotronVLProcessor(InternVLProcessor):
|
class NemotronVLProcessor(InternVLProcessor):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -87,6 +229,50 @@ class NemotronVLProcessor(InternVLProcessor):
|
|||||||
def image_token_id(self) -> int:
|
def image_token_id(self) -> int:
|
||||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||||
|
|
||||||
|
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    """Return the number of image tokens for an image of the given size.

    Args:
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        The patch count from ``calculate_nemotron_vl_targets`` multiplied by
        ``self.num_image_token`` (presumably the token count per patch;
        confirm against the base class).
    """
    target_ratios = self.resolve_target_ratios(
        use_thumbnail=False,  # Applied in calculate_targets
    )

    num_patches, _, _ = calculate_nemotron_vl_targets(
        orig_width=image_width,
        orig_height=image_height,
        image_size=self.image_size,
        target_ratios=target_ratios,
        use_thumbnail=self.use_thumbnail,
    )

    return num_patches * self.num_image_token
|
||||||
|
|
||||||
|
def _images_to_pixel_values_lst(
    self,
    images: list[Image.Image],
    min_dynamic_patch: Optional[int] = None,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
) -> list[torch.Tensor]:
    """Convert each image into its stacked tile tensor.

    Args:
        images: Images to process.
        min_dynamic_patch: Forwarded to ``self.resolve_min_max_num``.
        max_dynamic_patch: Forwarded to ``self.resolve_min_max_num``.
        dynamic_image_size: Forwarded to ``self.resolve_min_max_num``.

    Returns:
        One tensor per input image, in the same order, each produced by
        ``image_to_pixel_values_nemotron_vl``.
    """
    min_num, max_num = self.resolve_min_max_num(
        min_dynamic_patch=min_dynamic_patch,
        max_dynamic_patch=max_dynamic_patch,
        dynamic_image_size=dynamic_image_size,
        use_thumbnail=False,  # Applied in image_to_pixel_values
    )

    return [
        image_to_pixel_values_nemotron_vl(
            image,
            input_size=self.image_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=self.use_thumbnail,
        ) for image in images
    ]
|
||||||
|
|
||||||
def _preprocess_image(
|
def _preprocess_image(
|
||||||
self,
|
self,
|
||||||
text: list[str],
|
text: list[str],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user