[Model] Add multi-image input support for LLaVA-Next offline inference (#7230)

Authored by zifeitong on 2024-08-27 16:09:02 -07:00, committed by GitHub
parent 345be0e244
commit 5340a2dccf
7 changed files with 173 additions and 51 deletions

tests/conftest.py

@@ -41,6 +41,10 @@ _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
 
+PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
+PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
+                         List[List[Tuple[np.ndarray, int]]]]
+
 
 def _read_prompts(filename: str) -> List[str]:
     with open(filename, "r") as f:
@@ -578,8 +582,7 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[Union[List[Image.Image],
-                               List[List[Image.Image]]]] = None,
+        images: Optional[PromptImageInput] = None,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
             assert len(prompts) == len(images)
@@ -623,10 +626,8 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[Union[List[Image.Image],
-                               List[List[Image.Image]]]] = None,
-        audios: Optional[Union[List[Tuple[np.ndarray, int]],
-                               List[List[Tuple[np.ndarray, int]]]]] = None
+        images: Optional[PromptImageInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None
@@ -676,10 +677,8 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-        images: Optional[Union[List[Image.Image],
-                               List[List[Image.Image]]]] = None,
-        audios: Optional[Union[List[Tuple[np.ndarray, int]],
-                               List[List[Tuple[np.ndarray, int]]]]] = None,
+        images: Optional[PromptImageInput] = None,
+        audios: Optional[PromptAudioInput] = None,
         stop_token_ids: Optional[List[int]] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
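Note (not part of the diff): the new aliases accept either one item per prompt or a list of items per prompt. A minimal sketch of the two image shapes, using locally created PIL images for illustration only:

from typing import List, Union

from PIL import Image

# Mirrors the alias added above.
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]

red = Image.new("RGB", (64, 64), "red")
blue = Image.new("RGB", (64, 64), "blue")

# Flat list: one image per prompt (two prompts, one image each).
single_per_prompt: PromptImageInput = [red, blue]

# Nested list: several images per prompt (two prompts, one and two images).
multi_per_prompt: PromptImageInput = [[red], [red, blue]]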

tests/models/test_llava_next.py

@@ -6,24 +6,22 @@ from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                        _ImageAssets)
 from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
 
-_PREFACE = (
-    "A chat between a curious human and an artificial intelligence assistant. "
-    "The assistant gives helpful, detailed, and polite answers to the human's "
-    "questions.")
+_LIMIT_IMAGE_PER_PROMPT = 4
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
-    f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
+    "[INST] <image>\nWhat's the content of the image? [/INST]",
     "cherry_blossom":
-    f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
+    "[INST] <image>\nWhat is the season? [/INST]",
 })
 
-models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]
+models = ["llava-hf/llava-v1.6-mistral-7b-hf"]
 
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@@ -114,19 +112,43 @@ def run_test(
     else:
         raise ValueError("You must provide either `size_factors` or `sizes`")
 
+    _run_test(hf_runner,
+              vllm_runner,
+              inputs_per_image,
+              model,
+              dtype=dtype,
+              max_tokens=max_tokens,
+              num_logprobs=num_logprobs,
+              tensor_parallel_size=tensor_parallel_size,
+              distributed_executor_backend=distributed_executor_backend)
+
+
+def _run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    inputs: List[Tuple[List[str], PromptImageInput]],
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
-                     max_model_len=4096,
+                     max_model_len=10240,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
+                     enforce_eager=True,
+                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
+                                          }) as vllm_model:
         vllm_outputs_per_image = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
                                                 images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]
 
     with hf_runner(model, dtype=dtype,
@@ -136,7 +158,7 @@ def run_test(
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]
 
     for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
     )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
+                                      model, dtype, max_tokens,
+                                      num_logprobs) -> None:
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+
+    inputs = [(
+        [
+            "[INST] <image><image>\nDescribe 2 images. [/INST]",
+            "[INST] <image><image>\nDescribe 2 images. [/INST]",
+            "[INST] <image><image><image><image>\nDescribe 4 images. [/INST]",
+            "[INST] <image>\nWhat is the season? [/INST]"
+        ],
+        [
+            [stop_sign, cherry_blossom],
+            # Images with different sizes and aspect-ratios
+            [
+                rescale_image_size(stop_sign, 0.1),
+                stop_sign,
+            ],
+            [
+                stop_sign,
+                rescale_image_size(stop_sign, 0.25),
+                cherry_blossom.resize((183, 488)),
+                cherry_blossom.resize((488, 183))
+            ],
+            cherry_blossom,
+        ])]
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
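Aside from the test suite, the capability exercised here is multi-image prompts in offline inference. A hedged sketch of what this might look like through the LLM entrypoint; the model name, prompt format, and limits are taken from the test above, while the ImageAsset helper and the exact engine arguments are assumptions rather than part of this commit:

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset  # assumed helper backing the test image assets

stop_sign = ImageAsset("stop_sign").pil_image
cherry_blossom = ImageAsset("cherry_blossom").pil_image

llm = LLM(
    model="llava-hf/llava-v1.6-mistral-7b-hf",
    max_model_len=10240,               # room for several image feature blocks
    limit_mm_per_prompt={"image": 4},  # allow up to 4 images per prompt
    enforce_eager=True,
)

outputs = llm.generate(
    {
        "prompt": "[INST] <image><image>\nDescribe 2 images. [/INST]",
        "multi_modal_data": {"image": [stop_sign, cherry_blossom]},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)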

tests/multimodal/test_utils.py

@@ -6,8 +6,10 @@ from typing import Dict, Tuple
 import numpy as np
 import pytest
 from PIL import Image
+from transformers import AutoConfig, AutoTokenizer
 
-from vllm.multimodal.utils import async_fetch_image, fetch_image
+from vllm.multimodal.utils import (async_fetch_image, fetch_image,
+                                   repeat_and_pad_placeholder_tokens)
 
 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
 TEST_IMAGE_URLS = [
@@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
 
     data_image_async = await async_fetch_image(data_url)
     assert _image_equals(data_image_sync, data_image_async)
+
+
+@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+def test_repeat_and_pad_placeholder_tokens(model):
+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
+    test_cases = [
+        ("<image>", 2, "<image><image>", [32000, 32000]),
+        ("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
+        ("<image><image>", [3, 2], "<image><image><image><image><image>",
+         [32000, 32000, 32000, 32000, 32000]),
+        ("Image:<image>Image:<image>!", [3, 2],
+         "Image:<image><image><image>Image:<image><image>!",
+         [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
+        ("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
+    ]
+
+    for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
+        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            prompt_token_ids=tokenizer.encode(prompt,
+                                              add_special_tokens=False),
+            placeholder_token_id=image_token_id,
+            repeat_count=repeat_count,
+        )
+        assert new_prompt == expected_prompt
+        assert new_token_ids == expected_token_ids

vllm/model_executor/models/clip.py

@@ -1,7 +1,7 @@
 """Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
 from array import array
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -84,7 +84,7 @@ def input_processor_for_clip(
     llm_inputs: LLMInputs,
     *,
     image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
+    image_feature_size_override: Optional[Union[int, List[int]]] = None,
 ):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:

vllm/model_executor/models/llava_next.py

@@ -19,6 +19,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput
+from vllm.utils import is_list_of
 
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_clip_image_feature_size,
@@ -223,6 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             input_height=height,
             input_width=width,
         )
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = [
+            get_llava_next_image_feature_size(hf_config,
+                                              input_height=img.height,
+                                              input_width=img.width)
+            for img in image_data
+        ]
     elif isinstance(image_data, torch.Tensor):
         image_feature_size = image_data.shape[0]
     else:
@@ -425,7 +433,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
                 self.config.image_grid_pinpoints,
                 self.config.vision_config.image_size,
             )
-            other_patch_embeds = other_patch_embeds \
+
+            num_patches = num_patch_height * num_patch_width
+            # Image patches might be padded for batch processing
+            other_patch_embeds = other_patch_embeds[:num_patches] \
                 .view(num_patch_height, num_patch_width, height, width, -1)
 
             if "unpad" in strategy:
@@ -496,7 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
         self,
         image_input: LlavaNextImageInputs,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
-
         if image_input["type"] == "image_embeds":
             return [image_input["data"]]
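Aside (illustration, not vLLM code): with several images per request, each image's patch embeddings are padded to the largest patch count in the batch, so the rows have to be sliced back to that image's own num_patches before the grid reshape. A tiny torch sketch of the shapes involved; all sizes here are made up:

import torch

num_patch_height, num_patch_width = 2, 2  # e.g. from get_anyres_image_grid_shape
height = width = 24                       # vision feature map side length
hidden = 8                                # toy hidden size for the illustration

# Patch embeddings padded to the batch maximum of 6 patches; the trailing
# rows belong to padding, not to this image.
padded_patch_embeds = torch.randn(6, height * width, hidden)

num_patches = num_patch_height * num_patch_width   # this image only uses 4
other_patch_embeds = padded_patch_embeds[:num_patches] \
    .view(num_patch_height, num_patch_width, height, width, -1)

print(other_patch_embeds.shape)   # torch.Size([2, 2, 24, 24, 8])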

vllm/model_executor/models/siglip.py

@@ -3,7 +3,7 @@ within a vision language model."""
 import math
 from array import array
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 from PIL import Image
@@ -93,7 +93,7 @@ def input_processor_for_siglip(
     llm_inputs: LLMInputs,
     *,
     image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
+    image_feature_size_override: Optional[Union[int, List[int]]] = None,
 ):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:

vllm/multimodal/utils.py

@@ -189,10 +189,13 @@ def repeat_and_pad_placeholder_tokens(
     prompt_token_ids: List[int],
     *,
     placeholder_token_id: int,
-    repeat_count: int = 1,
+    repeat_count: Union[int, List[int]],
     pad_token_left: Optional[int] = None,
     pad_token_right: Optional[int] = None,
 ) -> Tuple[Optional[str], List[int]]:
+    if isinstance(repeat_count, int):
+        repeat_count = [repeat_count]
+
     if prompt is None:
         new_prompt = None
     else:
@@ -201,13 +204,6 @@ def repeat_and_pad_placeholder_tokens(
                               tokenizer.decode(pad_token_left))
         pad_token_str_right = (None if pad_token_right is None else
                                tokenizer.decode(pad_token_right))
-        replacement_str = "".join(
-            repeat_and_pad_token(
-                placeholder_token_str,
-                repeat_count=repeat_count,
-                pad_token_left=pad_token_str_left,
-                pad_token_right=pad_token_str_right,
-            ))
 
         placeholder_token_count = prompt.count(placeholder_token_str)
         # This is an arbitrary number to distinguish between the two cases
@@ -216,26 +212,43 @@ def repeat_and_pad_placeholder_tokens(
                 "Please follow the prompt format that is "
                 "documented on HuggingFace which does not involve "
                 "repeating %s tokens.", placeholder_token_str)
-        elif placeholder_token_count > 1:
-            logger.warning("Multiple multi-modal input is not supported yet, "
-                           "so any extra placeholder tokens will be treated "
-                           "as plain text.")
+        if placeholder_token_count < len(repeat_count):
+            logger.warning(
+                "The number of multi-modal placeholder tokens in the prompt "
+                "is less than the number of multi-modal inputs. Extra "
+                "placeholder tokens will be treated as plain text")
+            repeat_count = repeat_count[:placeholder_token_count]
+
+        prompt_parts = prompt.split(placeholder_token_str,
+                                    maxsplit=len(repeat_count))
+        new_prompt = ""
+        for i, repeat_count_item in enumerate(repeat_count):
+            replacement_str = "".join(
+                repeat_and_pad_token(
+                    placeholder_token_str,
+                    repeat_count=repeat_count_item,
+                    pad_token_left=pad_token_str_left,
+                    pad_token_right=pad_token_str_right,
+                ))
+            # The image tokens are removed to be consistent with HuggingFace
+            new_prompt += prompt_parts[i] + replacement_str
+        new_prompt += prompt_parts[-1]
 
-        # The image tokens are removed to be consistent with HuggingFace
-        new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1)
 
     new_token_ids: List[int] = []
+    placeholder_token_idx = 0
     for i, token in enumerate(prompt_token_ids):
         if token == placeholder_token_id:
             replacement_ids = repeat_and_pad_token(
                 placeholder_token_id,
-                repeat_count=repeat_count,
+                repeat_count=repeat_count[placeholder_token_idx],
                 pad_token_left=pad_token_left,
                 pad_token_right=pad_token_right,
             )
             new_token_ids.extend(replacement_ids)
+            placeholder_token_idx += 1
 
-            # No need to further scan the list since we only replace once
-            new_token_ids.extend(prompt_token_ids[i + 1:])
-            break
+            # No need to further scan the list since we replaced all tokens
+            if placeholder_token_idx >= len(repeat_count):
+                new_token_ids.extend(prompt_token_ids[i + 1:])
+                break
         else:
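As a plain-Python illustration of the prompt-string path above (tokenizer handling and pad tokens omitted; the helper name is made up for this sketch): each placeholder occurrence is expanded by its own repeat count, surplus counts are dropped, and surplus placeholders are left untouched as plain text.

from typing import List


def expand_placeholders(prompt: str, placeholder: str,
                        repeat_count: List[int]) -> str:
    # Drop surplus repeat counts, mirroring the truncation in
    # repeat_and_pad_placeholder_tokens when the prompt has fewer placeholders.
    repeat_count = repeat_count[:prompt.count(placeholder)]
    # Split only around the placeholders we will expand; any extra
    # placeholders stay inside parts[-1] and are kept as plain text.
    parts = prompt.split(placeholder, maxsplit=len(repeat_count))
    expanded = ""
    for part, count in zip(parts, repeat_count):
        expanded += part + placeholder * count
    return expanded + parts[-1]


# Matches the string expectations in the new test_repeat_and_pad_placeholder_tokens cases.
assert expand_placeholders("<image><image>", "<image>", [3, 2]) \
    == "<image><image><image><image><image>"
assert expand_placeholders("Image:<image>Image:<image>!", "<image>", [3, 2]) \
    == "Image:<image><image><image>Image:<image><image>!"
assert expand_placeholders("<image>", "<image>", [3, 2]) \
    == "<image><image><image>"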