[Model][VLM] Support multi-images inputs for InternVL2 models (#8201)
This commit is contained in: parent 9f68e00d27, commit e807125936
@@ -214,7 +214,7 @@ Multimodal Language Models
      -
    * - :code:`InternVLChatModel`
      - InternVL2
      - Image\ :sup:`E`
      - Image\ :sup:`E+`
      - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
      -
    * - :code:`LlavaForConditionalGeneration`

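The :sup:`E+` marker added above means InternVL2 checkpoints now accept more than one image per prompt. A minimal sketch of that usage, condensing the example script changed below; the image URLs are placeholders, and stop tokens are omitted for brevity:

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

# Hypothetical inputs; any two reachable image URLs would do.
IMAGE_URLS = ["https://example.com/a.jpg", "https://example.com/b.jpg"]
QUESTION = "What is the content of each image?"

model_name = "OpenGVLab/InternVL2-2B"
llm = LLM(model=model_name,
          trust_remote_code=True,
          max_model_len=4096,
          # Allow as many images per prompt as we plan to pass.
          limit_mm_per_prompt={"image": len(IMAGE_URLS)})

# One "Image-N: <image>" placeholder per input image.
placeholders = "\n".join(f"Image-{i}: <image>\n"
                         for i, _ in enumerate(IMAGE_URLS, start=1))
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": f"{placeholders}\n{QUESTION}"}],
    tokenize=False,
    add_generation_prompt=True)

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": [fetch_image(u) for u in IMAGE_URLS]},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=128))
print(outputs[0].outputs[0].text)
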
@@ -6,7 +6,9 @@ by the model.
from argparse import Namespace
from typing import List

from vllm import LLM
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser

@@ -17,36 +19,84 @@ IMAGE_URLS = [
]


def _load_phi3v(image_urls: List[str]):
    return LLM(
def load_phi3v(question, image_urls: List[str]):
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )


def run_phi3v_generate(question: str, image_urls: List[str]):
    llm = _load_phi3v(image_urls)

    placeholders = "\n".join(f"<|image_{i}|>"
                             for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {
            "image": [fetch_image(url) for url in image_urls]

def load_internvl(question, image_urls: List[str]):
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL
    # model variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
}


def run_generate(model, question: str, image_urls: List[str]):
    llm, prompt, stop_token_ids = model_example_map[model](question,
                                                           image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=stop_token_ids)

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {
                "image": [fetch_image(url) for url in image_urls]
            },
        },
    })
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_phi3v_chat(question: str, image_urls: List[str]):
    llm = _load_phi3v(image_urls)
def run_chat(model: str, question: str, image_urls: List[str]):
    llm, _, stop_token_ids = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=stop_token_ids)

    outputs = llm.chat([{
        "role":
@@ -63,7 +113,8 @@ def run_phi3v_chat(question: str, image_urls: List[str]):
            },
            } for image_url in image_urls),
        ],
    }])
    }],
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
@@ -71,12 +122,13 @@ def run_phi3v_chat(question: str, image_urls: List[str]):


def main(args: Namespace):
    model = args.model_type
    method = args.method

    if method == "generate":
        run_phi3v_generate(QUESTION, IMAGE_URLS)
        run_generate(model, QUESTION, IMAGE_URLS)
    elif method == "chat":
        run_phi3v_chat(QUESTION, IMAGE_URLS)
        run_chat(model, QUESTION, IMAGE_URLS)
    else:
        raise ValueError(f"Invalid method: {method}")

@@ -85,6 +137,12 @@ if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="phi3_v",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--method",
                        type=str,
                        default="generate",

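With these arguments in place, the example can be invoked with, e.g., --model-type internvl_chat --method generate (or --method chat); the script's own path is not shown in this view, so it is not repeated here.
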
@@ -1,5 +1,5 @@
import types
from typing import List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type, Union

import pytest
import torch
@@ -9,7 +9,8 @@ from transformers import AutoConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import is_cpu

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                        _ImageAssets)
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm
@@ -20,6 +21,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "cherry_blossom":
    "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
})
HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501

models = [
    "OpenGVLab/InternVL2-1B",
@@ -64,13 +66,13 @@ def generate(
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
@@ -83,12 +85,6 @@ def run_test(
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
@@ -110,13 +106,21 @@ def run_test(
            self.max_num = self.config.max_dynamic_patch
            self.image_size = self.vision_config.image_size

        def __call__(self, text: str, images: Image, **kwargs):
        def __call__(self, text: str, images: Union[Image, List[Image]],
                     **kwargs):
            from vllm.model_executor.models.internvl import (
                IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
            pixel_values = image_to_pixel_values(
                images, self.image_size, self.min_num, self.max_num,
                self.use_thumbnail).to(self.dtype)
            num_patches_list = [pixel_values.shape[0]]
            images = [images] if isinstance(images, Image) else images
            pixel_values = [
                image_to_pixel_values(image, self.image_size, self.min_num,
                                      self.max_num,
                                      self.use_thumbnail).to(self.dtype)
                for image in images
            ]
            num_patches_list = [
                pixel_value.shape[0] for pixel_value in pixel_values
            ]
            pixel_values = torch.cat(pixel_values, dim=0)
            for num_patches in num_patches_list:
                context_tokens = IMG_CONTEXT * self.num_image_token \
                    * num_patches
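A shape-level sketch of what the updated __call__ above does for two images, using random tensors of hypothetical sizes in place of real image_to_pixel_values outputs:

import torch

# Hypothetical per-image outputs of image_to_pixel_values:
# shape (num_patches, 3, image_size, image_size); num_patches may differ per image.
pixel_values = [torch.rand(7, 3, 448, 448), torch.rand(13, 3, 448, 448)]

num_patches_list = [pv.shape[0] for pv in pixel_values]  # [7, 13]
pixel_values = torch.cat(pixel_values, dim=0)            # (20, 3, 448, 448)
print(num_patches_list, tuple(pixel_values.shape))
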
@@ -130,6 +134,7 @@ def run_test(
    with vllm_runner(model,
                     max_model_len=4096,
                     dtype=dtype,
                     limit_mm_per_prompt={"image": mm_limit},
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
@@ -138,7 +143,7 @@ def run_test(
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs_per_image
            for prompts, images in inputs
        ]

    with hf_runner(model, dtype=dtype) as hf_model:
@@ -156,7 +161,7 @@ def run_test(
                                                    num_logprobs=num_logprobs,
                                                    images=hf_images,
                                                    eos_token_id=eos_token_id)
            for prompts, hf_images in inputs_per_image
            for prompts, hf_images in inputs
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -264,15 +269,64 @@ if is_cpu():
@torch.inference_mode()
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        inputs_per_image,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.5, 0.75, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                             size_factors, dtype: str, max_tokens: int,
                             num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case = [
        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
         [[rescale_image_size(image, factor) for image in images]
          for factor in size_factors])
    ]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )


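The difference between the two tests is the shape of the inputs they hand to run_test: the single-image test pairs each prompt with one image, while the new multi-image test pairs each prompt with a list of images, which is why it passes mm_limit=2. A schematic with stand-in images:

from PIL import Image

# Stand-ins for the test's image assets.
image_a = Image.new("RGB", (448, 448))
image_b = Image.new("RGB", (448, 448))

# Single-image case: one image per prompt (mm_limit=1).
inputs_per_image = [(["prompt about image A"], [image_a]),
                    (["prompt about image B"], [image_b])]

# Multi-image case: a list of images per prompt (mm_limit=2).
inputs_per_case = [(["prompt about both images"], [[image_a, image_b]])]
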
@@ -1,16 +1,15 @@
import os
import re
from typing import List, Optional, Tuple, Type, Union
from typing import List, Optional, Tuple, Type

import pytest
from PIL import Image
from transformers import AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu, is_hip

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner
from ..conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm
@@ -60,8 +59,7 @@ if is_hip():
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], Union[List[Image.Image],
                                        List[List[Image.Image]]]]],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,

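PromptImageInput, as used above, replaces the inline Union[List[Image.Image], List[List[Image.Image]]] annotation, so both test files describe single-image and multi-image inputs through the same alias from conftest.
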
@@ -5,6 +5,7 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
import re
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

@@ -26,6 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
@@ -95,8 +97,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,


def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int,
                         image_size: int) -> Tuple[int, int, int]:
                         max_num: int, image_size: int,
                         use_thumbnail: bool) -> Tuple[int, int, int]:
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
@@ -114,17 +116,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    # add thumbnail image if num_blocks > 1
    if use_thumbnail and blocks > 1:
        blocks += 1
    return blocks, target_width, target_height


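With use_thumbnail folded into calculate_num_blocks, the thumbnail tile is counted in the returned block count itself: a 3x2 tiling, for example, gives blocks = 6, or 7 when use_thumbnail is set, since blocks > 1.
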
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                       image_size: int,
                       use_thumbnail: int) -> List[Image.Image]:
                       use_thumbnail: bool) -> List[Image.Image]:
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_num_blocks(
        orig_width, orig_height, min_num, max_num, image_size)
        orig_width,
        orig_height,
        min_num,
        max_num,
        image_size,
        use_thumbnail=False)
    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
@@ -197,17 +208,23 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
                                           downsample_ratio)

    image_data = multi_modal_data["image"]
    min_num = hf_config.min_dynamic_patch
    max_num = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        min_num = hf_config.min_dynamic_patch
        max_num = hf_config.max_dynamic_patch
        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
                                                max_num, image_size)
        # add thumbnail image if num_blocks > 1
        if hf_config.use_thumbnail and num_blocks > 1:
            num_blocks += 1
        image_feature_size = num_blocks * num_patches

                                                max_num, image_size,
                                                use_thumbnail)
        image_feature_size = [num_blocks * num_patches]
    elif is_list_of(image_data, Image.Image):
        image_feature_size = []
        for image in image_data:
            width, height = image.size
            num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
                                                    max_num, image_size,
                                                    use_thumbnail)
            image_feature_size.append(num_blocks * num_patches)
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    else:
@@ -220,8 +237,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)
    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
    new_prompt = prompt.replace('<image>', image_prompt, 1)

    new_prompt = prompt
    image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
    for idx, feature_size in enumerate(image_feature_size, start=1):
        image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
        if not image_idx:
            image_prompt = f"Image-{idx}: {image_prompt}"
        new_prompt = new_prompt.replace('<image>', image_prompt, 1)
    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return LLMInputs(prompt=prompt,
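A self-contained sketch of the replacement loop above; the token strings and per-image feature sizes are made-up stand-ins for the module's IMG_START / IMG_CONTEXT / IMG_END constants and the computed image_feature_size list:

import re

# Hypothetical stand-ins for the real special tokens.
IMG_START, IMG_CONTEXT, IMG_END = "<img>", "<IMG_CONTEXT>", "</img>"

prompt = "Image-1: <image>\nImage-2: <image>\nDescribe both images."
image_feature_size = [4, 6]  # made-up per-image feature sizes

new_prompt = prompt
# Indices already written as "Image-N:" in the prompt, if any.
image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
for idx, feature_size in enumerate(image_feature_size, start=1):
    image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
    if not image_idx:
        # Prompt used bare "<image>" tags, so add the "Image-N:" prefix.
        image_prompt = f"Image-{idx}: {image_prompt}"
    new_prompt = new_prompt.replace('<image>', image_prompt, 1)

print(new_prompt)
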
@@ -245,6 +268,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
                                     use_thumbnail=use_thumbnail)
        # Add an N dimension for number of images per prompt (currently 1).
        data = data.unsqueeze(0)
    elif is_list_of(data, Image.Image):
        data = [
            image_to_pixel_values(img,
                                  image_size,
                                  min_num,
                                  max_num,
                                  use_thumbnail=use_thumbnail) for img in data
        ]
        data = torch.stack(data)
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
