[Model][VLM] Support multi-images inputs for InternVL2 models (#8201)
This commit is contained in:
parent 9f68e00d27
commit e807125936
@@ -214,7 +214,7 @@ Multimodal Language Models
     -
   * - :code:`InternVLChatModel`
     - InternVL2
-    - Image\ :sup:`E`
+    - Image\ :sup:`E+`
     - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
     -
   * - :code:`LlavaForConditionalGeneration`
@@ -6,7 +6,9 @@ by the model.
 from argparse import Namespace
 from typing import List
 
-from vllm import LLM
+from transformers import AutoTokenizer
+
+from vllm import LLM, SamplingParams
 from vllm.multimodal.utils import fetch_image
 from vllm.utils import FlexibleArgumentParser
 
@@ -17,36 +19,84 @@ IMAGE_URLS = [
 ]
 
 
-def _load_phi3v(image_urls: List[str]):
-    return LLM(
+def load_phi3v(question, image_urls: List[str]):
+    llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
         max_model_len=4096,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
-
-
-def run_phi3v_generate(question: str, image_urls: List[str]):
-    llm = _load_phi3v(image_urls)
-
     placeholders = "\n".join(f"<|image_{i}|>"
                              for i, _ in enumerate(image_urls, start=1))
     prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
-    outputs = llm.generate({
-        "prompt": prompt,
-        "multi_modal_data": {
-            "image": [fetch_image(url) for url in image_urls]
-        },
-    })
+
+def load_internvl(question, image_urls: List[str]):
+    model_name = "OpenGVLab/InternVL2-2B"
+
+    llm = LLM(
+        model=model_name,
+        trust_remote_code=True,
+        max_num_seqs=5,
+        max_model_len=4096,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = "\n".join(f"Image-{i}: <image>\n"
+                             for i, _ in enumerate(image_urls, start=1))
+    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    # Stop tokens for InternVL
+    # models variants may have different stop tokens
+    # please refer to the model card for the correct "stop words":
+    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
+    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+    return llm, prompt, stop_token_ids
+
+
+model_example_map = {
+    "phi3_v": load_phi3v,
+    "internvl_chat": load_internvl,
+}
+
+
+def run_generate(model, question: str, image_urls: List[str]):
+    llm, prompt, stop_token_ids = model_example_map[model](question,
+                                                           image_urls)
+
+    sampling_params = SamplingParams(temperature=0.0,
+                                     max_tokens=128,
+                                     stop_token_ids=stop_token_ids)
+
+    outputs = llm.generate(
+        {
+            "prompt": prompt,
+            "multi_modal_data": {
+                "image": [fetch_image(url) for url in image_urls]
+            },
+        },
+        sampling_params=sampling_params)
 
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
 
 
-def run_phi3v_chat(question: str, image_urls: List[str]):
-    llm = _load_phi3v(image_urls)
+def run_chat(model: str, question: str, image_urls: List[str]):
+    llm, _, stop_token_ids = model_example_map[model](question, image_urls)
+
+    sampling_params = SamplingParams(temperature=0.0,
+                                     max_tokens=128,
+                                     stop_token_ids=stop_token_ids)
 
     outputs = llm.chat([{
         "role":
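For quick reference, the InternVL2 multi-image path added above condenses to the standalone sketch below. The model name, prompt format, stop tokens, and vLLM/transformers calls are taken from the example code in this diff; the image URLs and the question text are placeholders.

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

# Placeholder URLs; substitute any two reachable images.
IMAGE_URLS = [
    "https://example.com/image_1.jpg",
    "https://example.com/image_2.jpg",
]

model_name = "OpenGVLab/InternVL2-2B"
llm = LLM(
    model=model_name,
    trust_remote_code=True,
    max_model_len=4096,
    # Allow as many images per prompt as we intend to pass.
    limit_mm_per_prompt={"image": len(IMAGE_URLS)},
)

# One "<image>" placeholder per image, numbered the way the example does.
placeholders = "\n".join(f"Image-{i}: <image>\n"
                         for i, _ in enumerate(IMAGE_URLS, start=1))
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": f"{placeholders}\nDescribe both images."}],
    tokenize=False,
    add_generation_prompt=True)

stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "image": [fetch_image(url) for url in IMAGE_URLS]
        },
    },
    sampling_params=SamplingParams(
        temperature=0.0,
        max_tokens=128,
        stop_token_ids=[tokenizer.convert_tokens_to_ids(t)
                        for t in stop_tokens]))
print(outputs[0].outputs[0].text)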
@@ -63,7 +113,8 @@ def run_phi3v_chat(question: str, image_urls: List[str]):
             },
         } for image_url in image_urls),
         ],
-    }])
+    }],
+                       sampling_params=sampling_params)
 
     for o in outputs:
         generated_text = o.outputs[0].text
@@ -71,12 +122,13 @@ def run_phi3v_chat(question: str, image_urls: List[str]):
 
 
 def main(args: Namespace):
+    model = args.model_type
     method = args.method
 
     if method == "generate":
-        run_phi3v_generate(QUESTION, IMAGE_URLS)
+        run_generate(model, QUESTION, IMAGE_URLS)
     elif method == "chat":
-        run_phi3v_chat(QUESTION, IMAGE_URLS)
+        run_chat(model, QUESTION, IMAGE_URLS)
     else:
         raise ValueError(f"Invalid method: {method}")
 
@@ -85,6 +137,12 @@ if __name__ == "__main__":
     parser = FlexibleArgumentParser(
         description='Demo on using vLLM for offline inference with '
         'vision language models that support multi-image input')
+    parser.add_argument('--model-type',
+                        '-m',
+                        type=str,
+                        default="phi3_v",
+                        choices=model_example_map.keys(),
+                        help='Huggingface "model_type".')
     parser.add_argument("--method",
                         type=str,
                         default="generate",
@@ -1,5 +1,5 @@
 import types
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type, Union
 
 import pytest
 import torch
@@ -9,7 +9,8 @@ from transformers import AutoConfig
 from vllm.multimodal.utils import rescale_image_size
 from vllm.utils import is_cpu
 
-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                        _ImageAssets)
 from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
@@ -20,6 +21,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "cherry_blossom":
     "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
 })
+HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501
 
 models = [
     "OpenGVLab/InternVL2-1B",
@@ -64,13 +66,13 @@ def generate(
 def run_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
-    image_assets: _ImageAssets,
+    inputs: List[Tuple[List[str], PromptImageInput]],
     model: str,
     *,
-    size_factors: List[float],
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
+    mm_limit: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
 ):
@@ -83,12 +85,6 @@ def run_test(
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
     """
-    images = [asset.pil_image for asset in image_assets]
-
-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
 
     # NOTE: take care of the order. run vLLM first, and then run HF.
     # vLLM needs a fresh new process without cuda initialization.
@@ -110,13 +106,21 @@ def run_test(
             self.max_num = self.config.max_dynamic_patch
             self.image_size = self.vision_config.image_size
 
-        def __call__(self, text: str, images: Image, **kwargs):
+        def __call__(self, text: str, images: Union[Image, List[Image]],
+                     **kwargs):
             from vllm.model_executor.models.internvl import (
                 IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
-            pixel_values = image_to_pixel_values(
-                images, self.image_size, self.min_num, self.max_num,
-                self.use_thumbnail).to(self.dtype)
-            num_patches_list = [pixel_values.shape[0]]
+            images = [images] if isinstance(images, Image) else images
+            pixel_values = [
+                image_to_pixel_values(image, self.image_size, self.min_num,
+                                      self.max_num,
+                                      self.use_thumbnail).to(self.dtype)
+                for image in images
+            ]
+            num_patches_list = [
+                pixel_value.shape[0] for pixel_value in pixel_values
+            ]
+            pixel_values = torch.cat(pixel_values, dim=0)
             for num_patches in num_patches_list:
                 context_tokens = IMG_CONTEXT * self.num_image_token \
                     * num_patches
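The reworked __call__ above concatenates per-image patch tensors into one batch while remembering how many patches each image contributed; a dummy-tensor sketch of that bookkeeping (the patch counts and the 448x448 tile size are illustrative assumptions):

import torch

# Pretend image_to_pixel_values returned these per-image patch tensors:
# 7 tiles for the first image, 3 for the second.
pixel_values = [torch.zeros(7, 3, 448, 448), torch.zeros(3, 3, 448, 448)]

# One entry per image; used to repeat IMG_CONTEXT once per patch later on.
num_patches_list = [pv.shape[0] for pv in pixel_values]

# All patches of all images then share a single leading dimension.
pixel_values = torch.cat(pixel_values, dim=0)
print(num_patches_list, pixel_values.shape)  # [7, 3] torch.Size([10, 3, 448, 448])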
@@ -130,6 +134,7 @@ def run_test(
     with vllm_runner(model,
                      max_model_len=4096,
                      dtype=dtype,
+                     limit_mm_per_prompt={"image": mm_limit},
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True) as vllm_model:
@@ -138,7 +143,7 @@ def run_test(
                                             max_tokens,
                                             num_logprobs=num_logprobs,
                                             images=images)
-        for prompts, images in inputs_per_image
+        for prompts, images in inputs
     ]
 
     with hf_runner(model, dtype=dtype) as hf_model:
@@ -156,7 +161,7 @@ def run_test(
                                                 num_logprobs=num_logprobs,
                                                 images=hf_images,
                                                 eos_token_id=eos_token_id)
-        for prompts, hf_images in inputs_per_image
+        for prompts, hf_images in inputs
     ]
 
     for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -264,15 +269,64 @@ if is_cpu():
 @torch.inference_mode()
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                 dtype: str, max_tokens: int, num_logprobs: int) -> None:
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
     run_test(
         hf_runner,
         vllm_runner,
-        image_assets,
+        inputs_per_image,
         model,
-        size_factors=size_factors,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
+        mm_limit=1,
+        tensor_parallel_size=1,
+    )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.5, 0.75, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+@torch.inference_mode()
+def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
+                             size_factors, dtype: str, max_tokens: int,
+                             num_logprobs: int) -> None:
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_case = [
+        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+         [[rescale_image_size(image, factor) for image in images]
+          for factor in size_factors])
+    ]
+
+    run_test(
+        hf_runner,
+        vllm_runner,
+        inputs_per_case,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=2,
         tensor_parallel_size=1,
     )
 
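The new test builds one input case per size factor, pairing each copy of the multi-image prompt with a list of rescaled images; a small sketch of that data layout (the stand-in images and factors are assumptions, rescale_image_size is the helper the test already imports):

from PIL import Image

from vllm.multimodal.utils import rescale_image_size

# Stand-ins for the two image assets used by the test suite.
images = [Image.new("RGB", (640, 480)), Image.new("RGB", (512, 512))]
size_factors = [0.5, 0.75, 1.0]
PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\n..."

inputs_per_case = [
    ([PROMPT for _ in size_factors],
     [[rescale_image_size(image, factor) for image in images]
      for factor in size_factors])
]

# One case: three prompt copies, each paired with a two-image list.
prompts, image_lists = inputs_per_case[0]
print(len(prompts), [len(imgs) for imgs in image_lists])  # 3 [2, 2, 2]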
@@ -1,16 +1,15 @@
 import os
 import re
-from typing import List, Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Type
 
 import pytest
-from PIL import Image
 from transformers import AutoTokenizer
 
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 from vllm.utils import is_cpu, is_hip
 
-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner
+from ..conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
 from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
@@ -60,8 +59,7 @@ if is_hip():
 def run_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
-    inputs: List[Tuple[List[str], Union[List[Image.Image],
-                                        List[List[Image.Image]]]]],
+    inputs: List[Tuple[List[str], PromptImageInput]],
     model: str,
     *,
     dtype: str,
@@ -5,6 +5,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 import itertools
+import re
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -26,6 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors
+from vllm.utils import is_list_of
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
@@ -95,8 +97,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
 
 
 def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
-                         max_num: int,
-                         image_size: int) -> Tuple[int, int, int]:
+                         max_num: int, image_size: int,
+                         use_thumbnail: bool) -> Tuple[int, int, int]:
     aspect_ratio = orig_width / orig_height
 
     # calculate the existing image aspect ratio
@@ -114,17 +116,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
     target_width = image_size * target_aspect_ratio[0]
     target_height = image_size * target_aspect_ratio[1]
     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+    # add thumbnail image if num_blocks > 1
+    if use_thumbnail and blocks > 1:
+        blocks += 1
     return blocks, target_width, target_height
 
 
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
 def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                        image_size: int,
-                       use_thumbnail: int) -> List[Image.Image]:
+                       use_thumbnail: bool) -> List[Image.Image]:
     orig_width, orig_height = image.size
 
+    # calculate the number of blocks without thumbnail
     blocks, target_width, target_height = calculate_num_blocks(
-        orig_width, orig_height, min_num, max_num, image_size)
+        orig_width,
+        orig_height,
+        min_num,
+        max_num,
+        image_size,
+        use_thumbnail=False)
     # resize the image
     resized_img = image.resize((target_width, target_height))
     processed_images = []
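With use_thumbnail threaded into calculate_num_blocks, the reported block count now includes the extra thumbnail tile whenever the image is split; a hedged usage sketch (it imports straight from the module changed here; the 3x2 tiling and the resulting count of 7 are expectations for this aspect ratio, not verified output):

from vllm.model_executor.models.internvl import calculate_num_blocks

# A 3:2 image with the usual InternVL2 settings: 1-12 tiles of 448x448.
blocks, target_w, target_h = calculate_num_blocks(
    orig_width=1344,
    orig_height=896,
    min_num=1,
    max_num=12,
    image_size=448,
    use_thumbnail=True)

# A 3x2 grid gives 6 tiles, plus 1 thumbnail => 7;
# the same call with use_thumbnail=False should report 6.
print(blocks, target_w, target_h)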
@@ -197,17 +208,23 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
                                       downsample_ratio)
 
     image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        width, height = image_data.size
     min_num = hf_config.min_dynamic_patch
     max_num = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    if isinstance(image_data, Image.Image):
+        width, height = image_data.size
         num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
-                                                max_num, image_size)
-        # add thumbnail image if num_blocks > 1
-        if hf_config.use_thumbnail and num_blocks > 1:
-            num_blocks += 1
-        image_feature_size = num_blocks * num_patches
+                                                max_num, image_size,
+                                                use_thumbnail)
+        image_feature_size = [num_blocks * num_patches]
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = []
+        for image in image_data:
+            width, height = image.size
+            num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
+                                                    max_num, image_size,
+                                                    use_thumbnail)
+            image_feature_size.append(num_blocks * num_patches)
     elif isinstance(image_data, torch.Tensor):
         num_images, image_feature_size, hidden_size = image_data.shape
     else:
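The input processor now distinguishes a single PIL image from a homogeneous list of them with is_list_of, exactly as written above; a minimal sketch of that branch (the sample images are assumptions):

from PIL import Image

from vllm.utils import is_list_of

image_data = [Image.new("RGB", (448, 448)) for _ in range(2)]

if isinstance(image_data, Image.Image):
    print("single-image branch")
elif is_list_of(image_data, Image.Image):
    # Every element is a PIL image, so the new multi-image branch runs
    # and collects one feature size per image.
    print("multi-image branch with", len(image_data), "images")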
@@ -220,8 +237,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
-    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
-    new_prompt = prompt.replace('<image>', image_prompt, 1)
+
+    new_prompt = prompt
+    image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
+    for idx, feature_size in enumerate(image_feature_size, start=1):
+        image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
+        if not image_idx:
+            image_prompt = f"Image-{idx}: {image_prompt}"
+        new_prompt = new_prompt.replace('<image>', image_prompt, 1)
     new_prompt_token_ids = tokenizer.encode(new_prompt)
 
     return LLMInputs(prompt=prompt,
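The loop above gives every "<image>" placeholder its own expansion and only prepends an "Image-N:" label when the prompt did not already number its images; a toy reproduction of that logic (the IMG_* strings below are stand-ins for the module's real special tokens, and the feature sizes are far smaller than real ones):

import re

IMG_START, IMG_CONTEXT, IMG_END = "<img>", "<CTX>", "</img>"

def expand(prompt: str, image_feature_size: list) -> str:
    new_prompt = prompt
    image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
    for idx, feature_size in enumerate(image_feature_size, start=1):
        image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
        if not image_idx:
            # Bare "<image>" tags: add explicit numbering.
            image_prompt = f"Image-{idx}: {image_prompt}"
        new_prompt = new_prompt.replace('<image>', image_prompt, 1)
    return new_prompt

# Two images with 2 and 3 context tokens each.
print(expand("<image>\n<image>\nCompare the images.", [2, 3]))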
@@ -245,6 +268,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
                                      use_thumbnail=use_thumbnail)
         # Add an N dimension for number of images per prompt (currently 1).
         data = data.unsqueeze(0)
+    elif is_list_of(data, Image.Image):
+        data = [
+            image_to_pixel_values(img,
+                                  image_size,
+                                  min_num,
+                                  max_num,
+                                  use_thumbnail=use_thumbnail) for img in data
+        ]
+        data = torch.stack(data)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
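In the list branch just added, torch.stack introduces the leading per-image dimension, which requires every image in the prompt to produce an identically shaped patch tensor; a dummy-tensor sketch of the resulting layout (the patch count and tile size are illustrative assumptions):

import torch

# Per-image pixel values as image_to_pixel_values would return them;
# here both images happen to yield 7 tiles of 3x448x448.
data = [torch.zeros(7, 3, 448, 448), torch.zeros(7, 3, 448, 448)]

data = torch.stack(data)
print(data.shape)  # torch.Size([2, 7, 3, 448, 448]): (images, patches, C, H, W)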