[Model] Expose InternVL2 max_dynamic_patch as a mm_processor_kwarg (#8946)
This commit is contained in: parent 8e60afa15e, commit 2ae25f79cf
@@ -115,6 +115,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
         trust_remote_code=True,
         max_model_len=4096,
         limit_mm_per_prompt={"image": len(image_urls)},
+        mm_processor_kwargs={"max_dynamic_patch": 4},
     )
 
     placeholders = "\n".join(f"Image-{i}: <image>\n"
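This example change is the user-facing side of the new kwarg: anything passed under mm_processor_kwargs is forwarded to the InternVL input-processing hooks below. A minimal standalone sketch of the same usage (the model name and the value 4 are illustrative, not requirements):

from vllm import LLM

# Cap InternVL2's dynamic tiling at 4 patches per image instead of the
# checkpoint default, which shrinks the per-image token budget.
llm = LLM(
    model="OpenGVLab/InternVL2-2B",
    trust_remote_code=True,
    max_model_len=4096,
    limit_mm_per_prompt={"image": 2},
    mm_processor_kwargs={"max_dynamic_patch": 4},
)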
@@ -5,8 +5,9 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 import re
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
-                    TypedDict, Union)
+from functools import partial
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -122,6 +123,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
     return blocks, target_width, target_height
 
 
+def calculate_num_blocks_wrapper(hf_config: Dict[str, Any],
+                                 max_dynamic_patch: Optional[int] = None):
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    min_num = hf_config.min_dynamic_patch
+    image_size = hf_config.vision_config.image_size
+    use_thumbnail = hf_config.use_thumbnail
+    return partial(calculate_num_blocks,
+                   min_num=min_num,
+                   max_num=max_dynamic_patch,
+                   image_size=image_size,
+                   use_thumbnail=use_thumbnail)
+
+
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
 def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                        image_size: int,
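The new wrapper is a thin factory: it reads the tiling parameters from the HF config once, optionally overriding max_dynamic_patch, and returns a functools.partial that only needs an image's width and height at the call site. A small self-contained sketch of that pattern, using a stand-in config object and a stand-in calculator rather than the real InternVL classes:

from functools import partial
from types import SimpleNamespace

# Stand-in for the HF config fields the wrapper reads (values are illustrative).
hf_config = SimpleNamespace(min_dynamic_patch=1,
                            max_dynamic_patch=12,
                            use_thumbnail=True,
                            vision_config=SimpleNamespace(image_size=448))

def describe_tiling(width, height, min_num, max_num, image_size, use_thumbnail):
    # Stand-in for calculate_num_blocks: just report the bound arguments.
    return f"{width}x{height} tiled with {min_num}..{max_num} blocks of {image_size}px"

def make_calculator(cfg, max_dynamic_patch=None):
    if max_dynamic_patch is None:
        max_dynamic_patch = cfg.max_dynamic_patch
    return partial(describe_tiling,
                   min_num=cfg.min_dynamic_patch,
                   max_num=max_dynamic_patch,
                   image_size=cfg.vision_config.image_size,
                   use_thumbnail=cfg.use_thumbnail)

calculator = make_calculator(hf_config, max_dynamic_patch=4)
print(calculator(1920, 1080))  # call sites only pass width and height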
@@ -168,62 +183,85 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
     return pixel_values
 
 
-def get_internvl_num_patches(image_size: int, patch_size: int,
-                             downsample_ratio: float):
+def image_to_pixel_values_wrapper(hf_config: Dict[str, Any],
+                                  max_dynamic_patch: Optional[int] = None):
+    image_size = hf_config.vision_config.image_size
+    min_num = hf_config.min_dynamic_patch
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    return partial(image_to_pixel_values,
+                   input_size=image_size,
+                   min_num=min_num,
+                   max_num=max_dynamic_patch,
+                   use_thumbnail=use_thumbnail)
+
+
+def get_internvl_num_patches(hf_config: Dict[str, Any]):
+    vision_config = hf_config.vision_config
+    downsample_ratio = hf_config.downsample_ratio
+    image_size = vision_config.image_size
+    patch_size = vision_config.patch_size
     return int(
         get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
         (downsample_ratio**2))
 
 
-def get_max_internvl_image_tokens(ctx: InputContext):
+def get_max_internvl_image_tokens(ctx: InputContext,
+                                  *,
+                                  max_dynamic_patch: Optional[int] = None):
     hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config
 
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
     use_thumbnail = hf_config.use_thumbnail
-    max_dynamic_patch = hf_config.max_dynamic_patch
-    if use_thumbnail:
+    if use_thumbnail and max_dynamic_patch > 1:
         max_dynamic_patch += 1
-    downsample_ratio = hf_config.downsample_ratio
 
-    image_size = vision_config.image_size
-    patch_size = vision_config.patch_size
-    num_patches = get_internvl_num_patches(image_size, patch_size,
-                                           downsample_ratio)
+    num_patches = get_internvl_num_patches(hf_config)
     return num_patches * max_dynamic_patch
 
 
-def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
+def get_max_internvl_image_size(ctx: InputContext,
+                                *,
+                                max_dynamic_patch: Optional[int] = None):
+    hf_config = ctx.get_hf_config()
+    image_size = hf_config.vision_config.image_size
+
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    if use_thumbnail and max_dynamic_patch > 1:
+        max_dynamic_patch += 1
+    width = image_size * max_dynamic_patch
+    height = image_size
+    return width, height
+
+
+def input_processor_for_internvl(ctx: InputContext,
+                                 llm_inputs: LLMInputs,
+                                 *,
+                                 max_dynamic_patch: Optional[int] = None):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return llm_inputs
 
     model_config = ctx.model_config
     hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config
-
-    image_size = vision_config.image_size
-    patch_size = vision_config.patch_size
-    downsample_ratio = hf_config.downsample_ratio
-    num_patches = get_internvl_num_patches(image_size, patch_size,
-                                           downsample_ratio)
 
     image_data = multi_modal_data["image"]
-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch
-    use_thumbnail = hf_config.use_thumbnail
+    num_patches = get_internvl_num_patches(hf_config)
+    num_blocks_calculator = calculate_num_blocks_wrapper(
+        hf_config, max_dynamic_patch)
     if isinstance(image_data, Image.Image):
         width, height = image_data.size
-        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
-                                                max_num, image_size,
-                                                use_thumbnail)
+        num_blocks, _, _ = num_blocks_calculator(width, height)
         image_feature_size = [num_blocks * num_patches]
     elif is_list_of(image_data, Image.Image):
         image_feature_size = []
         for image in image_data:
             width, height = image.size
-            num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
-                                                    max_num, image_size,
-                                                    use_thumbnail)
+            num_blocks, _, _ = num_blocks_calculator(width, height)
             image_feature_size.append(num_blocks * num_patches)
     elif isinstance(image_data, torch.Tensor):
         num_images, image_feature_size, hidden_size = image_data.shape
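The refactored helpers make the token accounting easy to check by hand. Assuming the usual InternVL2 vision settings (448-pixel tiles, 14-pixel patches, 0.5 downsample ratio, thumbnail enabled; assumed defaults, not values read from any particular checkpoint), the override used in the example above works out as:

# Tokens contributed by one 448x448 tile after pixel-shuffle downsampling.
image_size, patch_size, downsample_ratio = 448, 14, 0.5
num_patches = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)
assert num_patches == 256

# Worst case with mm_processor_kwargs={"max_dynamic_patch": 4}:
# 4 dynamic tiles plus the thumbnail tile -> 5 * 256 = 1280 image tokens.
max_dynamic_patch, use_thumbnail = 4, True
if use_thumbnail and max_dynamic_patch > 1:
    max_dynamic_patch += 1
assert num_patches * max_dynamic_patch == 1280

# With a default of max_dynamic_patch=12 the bound would be 13 * 256 = 3328
# tokens per image, which is hard to fit alongside text in max_model_len=4096.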
@@ -253,31 +291,21 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
                      multi_modal_data=multi_modal_data)
 
 
-def input_mapper_for_internvl(ctx: InputContext, data: object):
+def input_mapper_for_internvl(ctx: InputContext,
+                              data: object,
+                              *,
+                              max_dynamic_patch: Optional[int] = None):
     hf_config = ctx.get_hf_config()
 
-    use_thumbnail = hf_config.use_thumbnail
-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch
-    image_size = hf_config.vision_config.image_size
-
+    image_pixel_values_mapper = image_to_pixel_values_wrapper(
+        hf_config, max_dynamic_patch)
     if isinstance(data, Image.Image):
-        data = image_to_pixel_values(data,
-                                     image_size,
-                                     min_num,
-                                     max_num,
-                                     use_thumbnail=use_thumbnail)
+        data = image_pixel_values_mapper(data)
         # Add an N dimension for number of images per prompt (currently 1).
         data = data.unsqueeze(0)
     elif is_list_of(data, Image.Image):
         # we can't stack here because the images may have different num_patches
-        data = [
-            image_to_pixel_values(img,
-                                  image_size,
-                                  min_num,
-                                  max_num,
-                                  use_thumbnail=use_thumbnail) for img in data
-        ]
+        data = [image_pixel_values_mapper(img) for img in data]
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(
         model_config.tokenizer,
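In the mapper, the list branch deliberately returns a list rather than a stacked tensor, because each image may tile into a different number of patches. A tiny illustration of why stacking would fail (shapes are illustrative):

import torch

# Two images that tiled into different numbers of 3x448x448 patches.
pixel_values_a = torch.randn(5, 3, 448, 448)  # e.g. 4 tiles + thumbnail
pixel_values_b = torch.randn(3, 3, 448, 448)  # e.g. 2 tiles + thumbnail

try:
    torch.stack([pixel_values_a, pixel_values_b])
except RuntimeError as exc:
    print("cannot stack:", exc)  # sizes along dim 0 differ

# So the input mapper keeps them separate, as a list of tensors.
batch = [pixel_values_a, pixel_values_b]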
@@ -292,20 +320,24 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
         })
 
 
-def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
-                            mm_counts: Mapping[str, int]):
+def dummy_data_for_internvl(ctx: InputContext,
+                            seq_len: int,
+                            mm_counts: Mapping[str, int],
+                            *,
+                            max_dynamic_patch: Optional[int] = None):
     num_images = mm_counts["image"]
 
-    image_feature_size = get_max_internvl_image_tokens(ctx)
-    model_config = ctx.model_config
     hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config
+
+    image_feature_size = get_max_internvl_image_tokens(
+        ctx, max_dynamic_patch=max_dynamic_patch)
+    model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(
         model_config.tokenizer,
         trust_remote_code=model_config.trust_remote_code)
 
     seq_data = dummy_seq_data_for_clip(
-        vision_config,
+        hf_config.vision_config,
         seq_len,
         num_images,
         image_token_id=tokenizer.encode(IMG_CONTEXT,
@@ -313,14 +345,11 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
         image_feature_size_override=image_feature_size,
     )
 
-    image_size = vision_config.image_size
-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch
-    max_image_width = max_num * image_size
-    max_image_height = min_num * image_size
+    max_image_width, max_image_height = get_max_internvl_image_size(
+        ctx, max_dynamic_patch=max_dynamic_patch)
 
     mm_data = dummy_image_for_clip(
-        vision_config,
+        hf_config.vision_config,
         num_images,
         image_width_override=max_image_width,
         image_height_override=max_image_height,
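The dummy image used for memory profiling now comes from the same get_max_internvl_image_size helper, so a max_dynamic_patch override shrinks the profiling workload too. With the assumed 448-pixel tile size and the override of 4, the worst-case dummy image is a maximally wide strip:

# Mirrors the logic of get_max_internvl_image_size with illustrative values.
image_size = 448          # assumed InternVL2 tile size
max_dynamic_patch = 4     # the mm_processor_kwargs override
use_thumbnail = True

if use_thumbnail and max_dynamic_patch > 1:
    max_dynamic_patch += 1                          # thumbnail takes one extra slot

max_image_width = image_size * max_dynamic_patch    # 448 * 5 = 2240
max_image_height = image_size                       # 448
print(max_image_width, max_image_height)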
@@ -470,7 +499,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
         self,
         image_input: InternVLImageInputs,
     ) -> torch.Tensor:
-
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 