[Model][VLM] Support multi-images inputs for InternVL2 models (#8201)

This commit is contained in:
Isotr0py 2024-09-07 16:38:23 +08:00 committed by GitHub
parent 9f68e00d27
commit e807125936
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 199 additions and 57 deletions

View File

@ -214,7 +214,7 @@ Multimodal Language Models
-
* - :code:`InternVLChatModel`
- InternVL2
- Image\ :sup:`E`
- Image\ :sup:`E+`
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
-
* - :code:`LlavaForConditionalGeneration`

View File

@ -6,7 +6,9 @@ by the model.
from argparse import Namespace
from typing import List
from vllm import LLM
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser
@ -17,36 +19,84 @@ IMAGE_URLS = [
]
def _load_phi3v(image_urls: List[str]):
return LLM(
def load_phi3v(question, image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={"image": len(image_urls)},
)
def run_phi3v_generate(question: str, image_urls: List[str]):
llm = _load_phi3v(image_urls)
placeholders = "\n".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
stop_token_ids = None
return llm, prompt, stop_token_ids
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": [fetch_image(url) for url in image_urls]
def load_internvl(question, image_urls: List[str]):
model_name = "OpenGVLab/InternVL2-2B"
llm = LLM(
model=model_name,
trust_remote_code=True,
max_num_seqs=5,
max_model_len=4096,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
# Stop tokens for InternVL
# models variants may have different stop tokens
# please refer to the model card for the correct "stop words":
# https://huggingface.co/OpenGVLab/InternVL2-2B#service
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return llm, prompt, stop_token_ids
model_example_map = {
"phi3_v": load_phi3v,
"internvl_chat": load_internvl,
}
def run_generate(model, question: str, image_urls: List[str]):
llm, prompt, stop_token_ids = model_example_map[model](question,
image_urls)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=stop_token_ids)
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {
"image": [fetch_image(url) for url in image_urls]
},
},
})
sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
def run_phi3v_chat(question: str, image_urls: List[str]):
llm = _load_phi3v(image_urls)
def run_chat(model: str, question: str, image_urls: List[str]):
llm, _, stop_token_ids = model_example_map[model](question, image_urls)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=stop_token_ids)
outputs = llm.chat([{
"role":
@ -63,7 +113,8 @@ def run_phi3v_chat(question: str, image_urls: List[str]):
},
} for image_url in image_urls),
],
}])
}],
sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
@ -71,12 +122,13 @@ def run_phi3v_chat(question: str, image_urls: List[str]):
def main(args: Namespace):
model = args.model_type
method = args.method
if method == "generate":
run_phi3v_generate(QUESTION, IMAGE_URLS)
run_generate(model, QUESTION, IMAGE_URLS)
elif method == "chat":
run_phi3v_chat(QUESTION, IMAGE_URLS)
run_chat(model, QUESTION, IMAGE_URLS)
else:
raise ValueError(f"Invalid method: {method}")
@ -85,6 +137,12 @@ if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models that support multi-image input')
parser.add_argument('--model-type',
'-m',
type=str,
default="phi3_v",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument("--method",
type=str,
default="generate",

View File

@ -1,5 +1,5 @@
import types
from typing import List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type, Union
import pytest
import torch
@ -9,7 +9,8 @@ from transformers import AutoConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import is_cpu
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
@ -20,6 +21,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"cherry_blossom":
"<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
})
HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501
models = [
"OpenGVLab/InternVL2-1B",
@ -64,13 +66,13 @@ def generate(
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
mm_limit: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
@ -83,12 +85,6 @@ def run_test(
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
@ -110,13 +106,21 @@ def run_test(
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
def __call__(self, text: str, images: Image, **kwargs):
def __call__(self, text: str, images: Union[Image, List[Image]],
**kwargs):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
pixel_values = image_to_pixel_values(
images, self.image_size, self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values(image, self.image_size, self.min_num,
self.max_num,
self.use_thumbnail).to(self.dtype)
for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
]
pixel_values = torch.cat(pixel_values, dim=0)
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
@ -130,6 +134,7 @@ def run_test(
with vllm_runner(model,
max_model_len=4096,
dtype=dtype,
limit_mm_per_prompt={"image": mm_limit},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
@ -138,7 +143,7 @@ def run_test(
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
for prompts, images in inputs
]
with hf_runner(model, dtype=dtype) as hf_model:
@ -156,7 +161,7 @@ def run_test(
num_logprobs=num_logprobs,
images=hf_images,
eos_token_id=eos_token_id)
for prompts, hf_images in inputs_per_image
for prompts, hf_images in inputs
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@ -264,15 +269,64 @@ if is_cpu():
@torch.inference_mode()
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
run_test(
hf_runner,
vllm_runner,
image_assets,
inputs_per_image,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=1,
tensor_parallel_size=1,
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.5, 0.75, 1.0],
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
size_factors, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
images = [asset.pil_image for asset in image_assets]
inputs_per_case = [
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
[[rescale_image_size(image, factor) for image in images]
for factor in size_factors])
]
run_test(
hf_runner,
vllm_runner,
inputs_per_case,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=2,
tensor_parallel_size=1,
)

View File

@ -1,16 +1,15 @@
import os
import re
from typing import List, Optional, Tuple, Type, Union
from typing import List, Optional, Tuple, Type
import pytest
from PIL import Image
from transformers import AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu, is_hip
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner
from ..conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
@ -60,8 +59,7 @@ if is_hip():
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], Union[List[Image.Image],
List[List[Image.Image]]]]],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
*,
dtype: str,

View File

@ -5,6 +5,7 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
import re
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
@ -26,6 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
@ -95,8 +97,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
max_num: int,
image_size: int) -> Tuple[int, int, int]:
max_num: int, image_size: int,
use_thumbnail: bool) -> Tuple[int, int, int]:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
@ -114,17 +116,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# add thumbnail image if num_blocks > 1
if use_thumbnail and blocks > 1:
blocks += 1
return blocks, target_width, target_height
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
image_size: int,
use_thumbnail: int) -> List[Image.Image]:
use_thumbnail: bool) -> List[Image.Image]:
orig_width, orig_height = image.size
# calculate the number of blocks without thumbnail
blocks, target_width, target_height = calculate_num_blocks(
orig_width, orig_height, min_num, max_num, image_size)
orig_width,
orig_height,
min_num,
max_num,
image_size,
use_thumbnail=False)
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
@ -197,17 +208,23 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
downsample_ratio)
image_data = multi_modal_data["image"]
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
if isinstance(image_data, Image.Image):
width, height = image_data.size
min_num = hf_config.min_dynamic_patch
max_num = hf_config.max_dynamic_patch
num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
max_num, image_size)
# add thumbnail image if num_blocks > 1
if hf_config.use_thumbnail and num_blocks > 1:
num_blocks += 1
image_feature_size = num_blocks * num_patches
max_num, image_size,
use_thumbnail)
image_feature_size = [num_blocks * num_patches]
elif is_list_of(image_data, Image.Image):
image_feature_size = []
for image in image_data:
width, height = image.size
num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
max_num, image_size,
use_thumbnail)
image_feature_size.append(num_blocks * num_patches)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
else:
@ -220,8 +237,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
new_prompt = prompt.replace('<image>', image_prompt, 1)
new_prompt = prompt
image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
for idx, feature_size in enumerate(image_feature_size, start=1):
image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
if not image_idx:
image_prompt = f"Image-{idx}: {image_prompt}"
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=prompt,
@ -245,6 +268,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
use_thumbnail=use_thumbnail)
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
elif is_list_of(data, Image.Image):
data = [
image_to_pixel_values(img,
image_size,
min_num,
max_num,
use_thumbnail=use_thumbnail) for img in data
]
data = torch.stack(data)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)