[Model] Add multi-image input support for LLaVA-Next offline inference (#7230)
parent 345be0e244
commit 5340a2dccf
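The diff below adds multi-image support to the test harness, the LLaVA-Next input processor, and the multimodal placeholder utilities. For orientation, here is a minimal offline-inference sketch of what the feature enables; it assumes a build that includes this commit, and the image paths and sampling settings are placeholders rather than part of the change:

    from vllm import LLM, SamplingParams
    from PIL import Image

    # Placeholder image files; any PIL images work.
    stop_sign = Image.open("stop_sign.jpg")
    cherry_blossom = Image.open("cherry_blossom.jpg")

    llm = LLM(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        max_model_len=10240,
        # Allow more than one image per prompt (the default cap is 1).
        limit_mm_per_prompt={"image": 4},
    )

    outputs = llm.generate(
        {
            "prompt": "[INST] <image><image>\nDescribe 2 images. [/INST]",
            "multi_modal_data": {"image": [stop_sign, cherry_blossom]},
        },
        SamplingParams(temperature=0.0, max_tokens=128),
    )
    print(outputs[0].outputs[0].text)

The prompt uses the llava-v1.6-mistral chat format adopted by the updated tests, and limit_mm_per_prompt raises the per-prompt image cap the same way the test runner does below.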
@@ -41,6 +41,10 @@ _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
 
+PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
+PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
+                         List[List[Tuple[np.ndarray, int]]]]
+
 
 def _read_prompts(filename: str) -> List[str]:
     with open(filename, "r") as f:
@@ -161,7 +165,7 @@ def example_encoder_decoder_prompts(
     decoder prompt) tuple.
 
     Returns:
 
     * Encoder prompt list
     * Decoder prompt list (reverse of encoder prompt list)
     '''
@@ -578,8 +582,7 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[Union[List[Image.Image],
-                               List[List[Image.Image]]]] = None,
+        images: Optional[PromptImageInput] = None,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
             assert len(prompts) == len(images)
@@ -623,10 +626,8 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[Union[List[Image.Image],
-                               List[List[Image.Image]]]] = None,
-        audios: Optional[Union[List[Tuple[np.ndarray, int]],
-                               List[List[Tuple[np.ndarray, int]]]]] = None
+        images: Optional[PromptImageInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None
 
@@ -676,10 +677,8 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-        images: Optional[Union[List[Image.Image],
-                               List[List[Image.Image]]]] = None,
-        audios: Optional[Union[List[Tuple[np.ndarray, int]],
-                               List[List[Tuple[np.ndarray, int]]]]] = None,
+        images: Optional[PromptImageInput] = None,
+        audios: Optional[PromptAudioInput] = None,
         stop_token_ids: Optional[List[int]] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
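The two aliases introduced above are what the runner methods now accept: either one image per prompt or a per-prompt list of images. A small illustration, using synthetic placeholder images rather than the test assets:

    from typing import List, Union
    from PIL import Image

    PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]

    img_a = Image.new("RGB", (64, 64))
    img_b = Image.new("RGB", (32, 32))

    # One image per prompt (the existing single-image tests):
    single_per_prompt: PromptImageInput = [img_a, img_b]

    # A list of images per prompt (the new multi-image tests):
    multi_per_prompt: PromptImageInput = [[img_a, img_b], [img_a]]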
@@ -6,24 +6,22 @@ from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                        _ImageAssets)
 from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
 
-_PREFACE = (
-    "A chat between a curious human and an artificial intelligence assistant. "
-    "The assistant gives helpful, detailed, and polite answers to the human's "
-    "questions.")
+_LIMIT_IMAGE_PER_PROMPT = 4
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
-    f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
+    "[INST] <image>\nWhat's the content of the image? [/INST]",
     "cherry_blossom":
-    f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
+    "[INST] <image>\nWhat is the season? [/INST]",
 })
 
-models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]
+models = ["llava-hf/llava-v1.6-mistral-7b-hf"]
 
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@@ -114,19 +112,43 @@ def run_test(
     else:
         raise ValueError("You must provide either `size_factors` or `sizes`")
 
+    _run_test(hf_runner,
+              vllm_runner,
+              inputs_per_image,
+              model,
+              dtype=dtype,
+              max_tokens=max_tokens,
+              num_logprobs=num_logprobs,
+              tensor_parallel_size=tensor_parallel_size,
+              distributed_executor_backend=distributed_executor_backend)
+
+
+def _run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    inputs: List[Tuple[List[str], PromptImageInput]],
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
-                     max_model_len=4096,
+                     max_model_len=10240,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
+                     enforce_eager=True,
+                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
+                                          }) as vllm_model:
         vllm_outputs_per_image = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
                                                 images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
        ]
 
     with hf_runner(model, dtype=dtype,
@@ -136,7 +158,7 @@ def run_test(
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]
 
     for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -177,7 +199,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
 
     All the image fixtures for the test is under tests/images.
     For huggingface runner, we provide the PIL images as input.
     For vllm runner, we provide MultiModalDataDict objects
     and corresponding MultiModalConfig as input.
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
@@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
     )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
+                                      model, dtype, max_tokens,
+                                      num_logprobs) -> None:
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+
+    inputs = [(
+        [
+            "[INST] <image><image>\nDescribe 2 images. [/INST]",
+            "[INST] <image><image>\nDescribe 2 images. [/INST]",
+            "[INST] <image><image><image><image>\nDescribe 4 images. [/INST]",
+            "[INST] <image>\nWhat is the season? [/INST]"
+        ],
+        [
+            [stop_sign, cherry_blossom],
+            # Images with different sizes and aspect-ratios
+            [
+                rescale_image_size(stop_sign, 0.1),
+                stop_sign,
+            ],
+            [
+                stop_sign,
+                rescale_image_size(stop_sign, 0.25),
+                cherry_blossom.resize((183, 488)),
+                cherry_blossom.resize((488, 183))
+            ],
+            cherry_blossom,
+        ])]
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
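The new `_run_test` helper takes prepared (prompts, images) pairs directly, so multi-image cases can be expressed without going through size_factors. A sketch of the pairing contract, with placeholder images standing in for the stop_sign and cherry_blossom assets:

    from PIL import Image

    stop_sign = Image.new("RGB", (640, 480))
    cherry_blossom = Image.new("RGB", (480, 640))

    prompts = [
        "[INST] <image><image>\nDescribe 2 images. [/INST]",
        "[INST] <image>\nWhat is the season? [/INST]",
    ]
    images = [
        [stop_sign, cherry_blossom],  # two images for the first prompt
        cherry_blossom,               # a single image for the second prompt
    ]

    # Each element of `inputs` pairs a prompt batch with a matching image batch;
    # prompt i receives images[i], and the number of "<image>" placeholders in
    # prompt i must match the number of images supplied for it.
    inputs = [(prompts, images)]

As in the new test above, an element of the image batch may be either a single image or a list of images.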
@@ -6,8 +6,10 @@ from typing import Dict, Tuple
 import numpy as np
 import pytest
 from PIL import Image
+from transformers import AutoConfig, AutoTokenizer
 
-from vllm.multimodal.utils import async_fetch_image, fetch_image
+from vllm.multimodal.utils import (async_fetch_image, fetch_image,
+                                   repeat_and_pad_placeholder_tokens)
 
 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
 TEST_IMAGE_URLS = [
@@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
 
     data_image_async = await async_fetch_image(data_url)
     assert _image_equals(data_image_sync, data_image_async)
+
+
+@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+def test_repeat_and_pad_placeholder_tokens(model):
+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
+    test_cases = [
+        ("<image>", 2, "<image><image>", [32000, 32000]),
+        ("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
+        ("<image><image>", [3, 2], "<image><image><image><image><image>",
+         [32000, 32000, 32000, 32000, 32000]),
+        ("Image:<image>Image:<image>!", [3, 2],
+         "Image:<image><image><image>Image:<image><image>!",
+         [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
+        ("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
+    ]
+
+    for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
+        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            prompt_token_ids=tokenizer.encode(prompt,
+                                              add_special_tokens=False),
+            placeholder_token_id=image_token_id,
+            repeat_count=repeat_count,
+        )
+        assert new_prompt == expected_prompt
+        assert new_token_ids == expected_token_ids
@@ -1,7 +1,7 @@
 """Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
 from array import array
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -84,7 +84,7 @@ def input_processor_for_clip(
     llm_inputs: LLMInputs,
     *,
     image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
+    image_feature_size_override: Optional[Union[int, List[int]]] = None,
 ):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
@@ -217,7 +217,7 @@ class CLIPEncoderLayer(nn.Module):
 
 class CLIPEncoder(nn.Module):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self
     attention layers. Each layer is a [`CLIPEncoderLayer`].
 
     Args:
@@ -19,6 +19,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput
+from vllm.utils import is_list_of
 
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_clip_image_feature_size,
@@ -223,6 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             input_height=height,
             input_width=width,
         )
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = [
+            get_llava_next_image_feature_size(hf_config,
+                                              input_height=img.height,
+                                              input_width=img.width)
+            for img in image_data
+        ]
     elif isinstance(image_data, torch.Tensor):
         image_feature_size = image_data.shape[0]
     else:
@@ -425,7 +433,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
                 self.config.image_grid_pinpoints,
                 self.config.vision_config.image_size,
             )
-            other_patch_embeds = other_patch_embeds \
+            num_patches = num_patch_height * num_patch_width
+
+            # Image patches might be padded for batch processing
+            other_patch_embeds = other_patch_embeds[:num_patches] \
                 .view(num_patch_height, num_patch_width, height, width, -1)
 
             if "unpad" in strategy:
@@ -496,7 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
         self,
         image_input: LlavaNextImageInputs,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
-
         if image_input["type"] == "image_embeds":
             return [image_input["data"]]
 
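The slice added in the patch-merging hunk above guards against padding introduced when images with different anyres grids are batched together. A small worked illustration with made-up grid sizes (the numbers are illustrative, not taken from the model config):

    # Two images batched together, one with a 2x2 anyres grid, one with 3x2.
    num_patches_a = 2 * 2   # real patches for image A
    num_patches_b = 3 * 2   # real patches for image B
    padded_len = max(num_patches_a, num_patches_b)  # both carry 6 patch embeddings

    # For image A, the last (padded_len - num_patches_a) = 2 entries are padding,
    # so other_patch_embeds[:num_patches_a] drops them before the .view(...) call.
    print(padded_len - num_patches_a)  # 2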
@@ -3,7 +3,7 @@ within a vision language model.
 
 import math
 from array import array
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 from PIL import Image
@@ -93,7 +93,7 @@ def input_processor_for_siglip(
     llm_inputs: LLMInputs,
     *,
     image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
+    image_feature_size_override: Optional[Union[int, List[int]]] = None,
 ):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
@@ -189,10 +189,13 @@ def repeat_and_pad_placeholder_tokens(
     prompt_token_ids: List[int],
     *,
     placeholder_token_id: int,
-    repeat_count: int = 1,
+    repeat_count: Union[int, List[int]],
     pad_token_left: Optional[int] = None,
     pad_token_right: Optional[int] = None,
 ) -> Tuple[Optional[str], List[int]]:
+    if isinstance(repeat_count, int):
+        repeat_count = [repeat_count]
+
     if prompt is None:
         new_prompt = None
     else:
@@ -201,13 +204,6 @@ def repeat_and_pad_placeholder_tokens(
                               tokenizer.decode(pad_token_left))
         pad_token_str_right = (None if pad_token_right is None else
                                tokenizer.decode(pad_token_right))
-        replacement_str = "".join(
-            repeat_and_pad_token(
-                placeholder_token_str,
-                repeat_count=repeat_count,
-                pad_token_left=pad_token_str_left,
-                pad_token_right=pad_token_str_right,
-            ))
 
         placeholder_token_count = prompt.count(placeholder_token_str)
         # This is an arbitrary number to distinguish between the two cases
@@ -216,28 +212,45 @@ def repeat_and_pad_placeholder_tokens(
                 "Please follow the prompt format that is "
                 "documented on HuggingFace which does not involve "
                 "repeating %s tokens.", placeholder_token_str)
-        elif placeholder_token_count > 1:
-            logger.warning("Multiple multi-modal input is not supported yet, "
-                           "so any extra placeholder tokens will be treated "
-                           "as plain text.")
+        if placeholder_token_count < len(repeat_count):
+            logger.warning(
+                "The number of multi-modal placeholder tokens in the prompt "
+                "is less than the number of multi-modal inputs. Extra "
+                "placeholder tokens will be treated as plain text")
+            repeat_count = repeat_count[:placeholder_token_count]
 
-        # The image tokens are removed to be consistent with HuggingFace
-        new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1)
+        prompt_parts = prompt.split(placeholder_token_str,
+                                    maxsplit=len(repeat_count))
+        new_prompt = ""
+        for i, repeat_count_item in enumerate(repeat_count):
+            replacement_str = "".join(
+                repeat_and_pad_token(
+                    placeholder_token_str,
+                    repeat_count=repeat_count_item,
+                    pad_token_left=pad_token_str_left,
+                    pad_token_right=pad_token_str_right,
+                ))
+            # The image tokens are removed to be consistent with HuggingFace
+            new_prompt += prompt_parts[i] + replacement_str
+        new_prompt += prompt_parts[-1]
 
     new_token_ids: List[int] = []
+    placeholder_token_idx = 0
     for i, token in enumerate(prompt_token_ids):
         if token == placeholder_token_id:
             replacement_ids = repeat_and_pad_token(
                 placeholder_token_id,
-                repeat_count=repeat_count,
+                repeat_count=repeat_count[placeholder_token_idx],
                 pad_token_left=pad_token_left,
                 pad_token_right=pad_token_right,
             )
             new_token_ids.extend(replacement_ids)
+            placeholder_token_idx += 1
 
-            # No need to further scan the list since we only replace once
-            new_token_ids.extend(prompt_token_ids[i + 1:])
-            break
+            # No need to further scan the list since we replaced all tokens
+            if placeholder_token_idx >= len(repeat_count):
+                new_token_ids.extend(prompt_token_ids[i + 1:])
+                break
         else:
             new_token_ids.append(token)
 
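Pulling the pieces together, repeat_and_pad_placeholder_tokens now accepts one repeat count per placeholder. A standalone sketch mirroring one of the new test cases above; it assumes the tokenizer can be downloaded, and 32000 is the model's image token id as asserted in the test:

    from transformers import AutoTokenizer

    from vllm.multimodal.utils import repeat_and_pad_placeholder_tokens

    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

    prompt = "Image:<image>Image:<image>!"
    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer=tokenizer,
        prompt=prompt,
        prompt_token_ids=tokenizer.encode(prompt, add_special_tokens=False),
        placeholder_token_id=32000,   # image_token_index for this model (see test)
        repeat_count=[3, 2],          # first <image> -> 3 tokens, second -> 2
    )

    assert new_prompt == "Image:<image><image><image>Image:<image><image>!"
    assert new_token_ids == [9833, 28747, 32000, 32000, 32000,
                             9833, 28747, 32000, 32000, 918]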