[2/N] handling placeholders in merged multi-modal processor (#10485)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-11-23 13:25:09 +08:00 committed by GitHub
parent 4634a89d18
commit c8acd80548
5 changed files with 979 additions and 151 deletions


@@ -0,0 +1,370 @@
from typing import cast
import pytest
from transformers import BatchFeature
from vllm.multimodal.processing import (PromptReplacement, find_text_matches,
find_token_matches, iter_token_matches,
iter_token_runs, replace_text_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby
# yapf: disable
@pytest.mark.parametrize(
("token_ids", "expected"),
[
([], []),
(
[32000, 32000, 32000],
[{ "token_id": 32000, "start_idx": 0, "length": 3 }],
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[
{ "token_id": 9833, "start_idx": 0, "length": 1 },
{ "token_id": 28747, "start_idx": 1, "length": 1 },
{ "token_id": 32000, "start_idx": 2, "length": 3 },
{ "token_id": 9833, "start_idx": 5, "length": 1 },
{ "token_id": 28747, "start_idx": 6, "length": 1 },
{ "token_id": 32000, "start_idx": 7, "length": 2 },
{ "token_id": 918, "start_idx": 9, "length": 1 },
],
),
],
)
# yapf: enable
def test_iter_token_runs(token_ids, expected):
result = list(iter_token_runs(token_ids))
# Only displayed on error
print("result:", result)
# Manually constructed results
assert [item._asdict() for item in result] == expected
# Invariants
assert sum(run_info.length for run_info in result) == len(token_ids)
# yapf: disable
@pytest.mark.parametrize(
("token_ids", "match_ids", "expected"),
[
([], [], [{ "start_idx": 0, "end_idx": 0 }]),
([], [32000], []),
(
[32000, 32000, 32000],
[32000],
[
{ "start_idx": 0, "end_idx": 1 },
{ "start_idx": 1, "end_idx": 2 },
{ "start_idx": 2, "end_idx": 3 },
],
),
(
[32000, 32000, 32000],
[32000, 32000],
[{ "start_idx": 0, "end_idx": 2 }],
),
(
[32000, 32000, 32000],
[32000, 32000, 32000],
[{ "start_idx": 0, "end_idx": 3 }],
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[28747, 32000],
[
{ "start_idx": 1, "end_idx": 3 },
{ "start_idx": 6, "end_idx": 8 },
],
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[28747, 32000, 32000, 32000],
[
{ "start_idx": 1, "end_idx": 5 },
],
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[28747, 0, 32000],
[],
),
],
)
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
result = list(iter_token_matches(token_ids, match_ids))
# Manually constructed results
assert [item._asdict() for item in result] == expected
# Invariants
match_lens = [end - start for start, end in result]
print("match_lens:", match_lens) # Only displayed on error
assert all(match_len == len(match_ids) for match_len in match_lens)
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "expected_by_key"),
[
(
[],
{
"pattern_1": [],
"pattern_2": [32000],
},
{
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_2": [],
}
),
(
[32000, 32000, 32000, 32000],
{
"pattern_1": [32000],
"pattern_2": [32000, 32000],
"pattern_3": [32000, 32000, 32000],
},
{
"pattern_1": [
{ "start_idx": 0, "end_idx": 1 },
{ "start_idx": 1, "end_idx": 2 },
{ "start_idx": 2, "end_idx": 3 },
{ "start_idx": 3, "end_idx": 4 },
],
"pattern_2": [
{ "start_idx": 0, "end_idx": 2 },
{ "start_idx": 2, "end_idx": 4 },
],
"pattern_3": [
{ "start_idx": 0, "end_idx": 3 },
],
},
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
{
"pattern_1": [28747, 32000],
"pattern_2": [28747, 32000, 32000, 32000],
"pattern_3": [28747, 0, 32000],
},
{
"pattern_1": [
{ "start_idx": 1, "end_idx": 3 },
{ "start_idx": 6, "end_idx": 8 },
],
"pattern_2": [
{ "start_idx": 1, "end_idx": 5 },
],
"pattern_3": [],
},
),
],
)
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())
result = find_token_matches(
prompt,
[
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
for key, target in target_by_key.items()
],
)
# Only displayed on error
print("result:", result)
# Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert {
key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, [])
]
for key in expected_by_key
} == expected_by_key
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "expected_by_key"),
[
# Detokenized test cases of `test_find_token_matches`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
(
"",
{
"pattern_1": "",
"pattern_2": "<image>",
},
{
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_2": [],
}
),
(
"<image><image><image><image>",
{
"pattern_1": "<image>",
"pattern_2": "<image><image>",
"pattern_3": "<image><image><image>",
},
{
"pattern_1": [
{ "start_idx": 0, "end_idx": 7 },
{ "start_idx": 7, "end_idx": 14 },
{ "start_idx": 14, "end_idx": 21 },
{ "start_idx": 21, "end_idx": 28 },
],
"pattern_2": [
{ "start_idx": 0, "end_idx": 14 },
{ "start_idx": 14, "end_idx": 28 },
],
"pattern_3": [
{ "start_idx": 0, "end_idx": 21 },
],
},
),
(
"Image:<image><image><image>Image:<image><image>!",
{
"pattern_1": "Image:<image>",
"pattern_2": "Image:<image><image><image>",
"pattern_3": "Image:<unk><image>",
},
{
"pattern_1": [
{ "start_idx": 0, "end_idx": 13 },
{ "start_idx": 27, "end_idx": 40 },
],
"pattern_2": [
{ "start_idx": 0, "end_idx": 27 },
],
"pattern_3": [],
},
),
# Test regex escape
(
"<|image|><image><|image|><image>",
{
"pattern_1": "<|image|>",
"pattern_2": "<|image|><image>",
"pattern_3": "<|image|><image><|image|>",
},
{
"pattern_1": [
{ "start_idx": 0, "end_idx": 9 },
{ "start_idx": 16, "end_idx": 25 },
],
"pattern_2": [
{ "start_idx": 0, "end_idx": 16 },
{ "start_idx": 16, "end_idx": 32 },
],
"pattern_3": [
{ "start_idx": 0, "end_idx": 25 },
],
},
),
],
)
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
result = find_text_matches(
prompt,
[
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
for key, target in target_by_key.items()
],
)
# Only displayed on error
print("result:", result)
# Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert {
key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, [])
]
for key in expected_by_key
} == expected_by_key
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"),
[
(
"Image:<image>Image:<image><image>!",
{
# We use `<image>` before `Image:` to test matches that
# occur out of order
"pattern_1": "<image>",
"pattern_2": "Image:",
"pattern_3": "!",
},
{
# Test whether target is confused with repl_unit
"pattern_1": ("<image><image>", 1),
# Test empty repl_unit
"pattern_2": ("", 1),
# Test multiple repl_count
"pattern_3": ("?", 2),
},
{
# Test no replacement
0: "Image:<image>Image:<image><image>!",
# Test single replacement
1: "<image><image>Image:<image><image>??",
# Test repeated replacement
2: "<image><image><image><image><image>??",
},
),
]
)
# yapf: enable
def test_find_replace_text(
prompt,
target_by_key,
repl_by_key,
expected_by_mm_count,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
matches = find_text_matches(
prompt,
[
PromptReplacement(target, *repl_by_key[key]) \
.bind(key, mock_tokenizer)
for key, target in target_by_key.items()
],
)
result_by_mm_count = {
mm_count: replace_text_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
)
for mm_count in expected_by_mm_count
}
# Only displayed on error
print("matches:", matches)
print("result_by_mm_count:", result_by_mm_count)
# Manually constructed results
assert result_by_mm_count == expected_by_mm_count


@@ -139,7 +139,8 @@ def test_repeat_and_pad_placeholder_tokens(model):
2,
"<image><image><image>",
[32000, 32000, 32000],
[{ "offset": 0, "length": 2 }]),
[{ "offset": 0, "length": 2 }],
),
(
"<image><image>",
[3, 2],


@@ -203,14 +203,7 @@ class MultiModalInputsV2(TypedDict):
"""The type of inputs."""
prompt: str
"""
The original, unprocessed prompt text.
Note:
Since prompt text is not required by vLLM internals, we leave this
unprocessed to save CPU computation. You can still call
:code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
"""
"""The processed prompt text."""
prompt_token_ids: List[int]
"""The processed token IDs which includes placeholder tokens."""


@@ -1,34 +1,91 @@
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import (Any, Callable, Collection, Generic, List, Mapping,
Optional, TypedDict, TypeVar, final)
from functools import lru_cache
from itertools import groupby
from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
import numpy as np
from transformers import BatchFeature
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, TypedDict
from vllm.inputs import InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem)
_T = TypeVar("_T")
ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]]
"""
Given the original data item, HF-processed data, and index of the processed
item, output the replacement token IDs to be allocated in vLLM.
"""
def bind_prompt_sequence(
seq: Union[str, list[int]],
tokenizer: AnyTokenizer,
) -> "_BoundPromptSequence":
"""
Bind a text or token sequence to a tokenizer so that it can be
lazily converted into the other format on demand.
"""
return _BoundPromptSequence(
tokenizer=tokenizer,
_text=seq if isinstance(seq, str) else None,
_token_ids=seq if isinstance(seq, list) else None,
)
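# A minimal sketch of the lazy conversion (mirroring the new unit tests):
# the tokenizer is only consulted when converting to the *other* format,
# so a dummy object suffices when we only read back the format we passed in.
from typing import cast

mock_tokenizer = cast(AnyTokenizer, object())

bound_text = bind_prompt_sequence("<image>", mock_tokenizer)
assert bound_text.text == "<image>"  # no tokenizer call needed

bound_ids = bind_prompt_sequence([32000, 32000], mock_tokenizer)
assert bound_ids.token_ids == [32000, 32000]  # no tokenizer call needed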
_T = TypeVar("_T")
_S = TypeVar("_S", str, list[int])
@dataclass
class PromptReplacement(Generic[_S, _T]):
target: _S
"""The text or token sequence to find and replace."""
repl_unit: _S
"""
The unit making up the replacement text or token sequence.
See :code:`repl_count` for more details.
"""
repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
"""
Given the original multi-modal items for this modality, HF-processed data,
and index of the processed item, output the number of repetitions of
:code:`repl_unit` to build up the replacement text or token sequence.
For convenience, you can pass in an integer if the number of repetitions is
a constant.
"""
def __repr__(self) -> str:
return (f"{type(self).__name__}(target={self.target!r}, "
f"repl_unit={self.repl_unit!r})")
def bind(
self,
modality: str,
tokenizer: AnyTokenizer,
) -> "_BoundPromptReplacement[_T]":
return _BoundPromptReplacement(
modality=modality,
target=bind_prompt_sequence(self.target, tokenizer),
repl_unit=bind_prompt_sequence(self.repl_unit, tokenizer),
repl_count=self.repl_count,
)
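# An illustrative sketch of defining replacements for an image modality.
# The constant 576 and the "image_embeds" key below are made-up values for
# illustration; real models derive the count from their own processor outputs.
image_repl_static = PromptReplacement(
    target="<image>",  # the placeholder to find in the prompt
    repl_unit="<image>",  # the unit repeated to build the replacement
    repl_count=576,  # constant number of feature tokens per image
)

# repl_count may instead be a callable that inspects the HF-processed data:
image_repl_dynamic = PromptReplacement(
    target="<image>",
    repl_unit="<image>",
    # hypothetical: read a per-item feature length from the HF outputs
    repl_count=lambda items, hf_inputs, item_idx: len(
        hf_inputs["image_embeds"][item_idx]),
)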
@dataclass
class ModalityProcessingMetadata(Generic[_T]):
placeholder_replacements: Mapping[str, ReplacementFunc]
prompt_repls: Sequence[Union[PromptReplacement[str, _T],
PromptReplacement[list[int], _T]]]
"""
A dictionary where each item represents the original placeholder in the
prompt text and the corresponding replacement.
Defines each text or token sequence to replace in the HF-processed prompt.
This is skipped if the HF-processed prompt is found to already contain
the replacement prompts.
"""
@@ -52,46 +109,138 @@ Note:
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
MultiModalMultiData: TypeAlias = List[_T]
"""
A list of data items, where the number of data items allowed
per modality is restricted by :code:`--limit-mm-per-prompt`.
"""
def _encode(
tokenizer: AnyTokenizer,
text: str,
*,
add_special_tokens: bool = False,
) -> list[int]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
"""
if isinstance(tokenizer, MistralTokenizer):
return tokenizer.tokenizer.encode(text,
bos=add_special_tokens,
eos=add_special_tokens)
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
@final
class MultiModalMultiDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: MultiModalMultiData[ImageItem]
"""The input images."""
video: MultiModalMultiData[VideoItem]
"""The input videos."""
audio: MultiModalMultiData[AudioItem]
"""The input audios."""
@lru_cache(maxsize=2048)
def _cached_encode(
tokenizer: AnyTokenizer,
text: str,
*,
add_special_tokens: bool = False,
) -> list[int]:
return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]]
"""
A dictionary containing an entry for each modality type to input.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalMultiDataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
def _decode(
tokenizer: AnyTokenizer,
token_ids: list[int],
*,
skip_special_tokens: bool = False,
) -> str:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
"""
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
@lru_cache(maxsize=2048)
def _cached_decode(
tokenizer: AnyTokenizer,
token_ids: tuple[int, ...],
*,
skip_special_tokens: bool = False,
) -> str:
return _decode(tokenizer,
list(token_ids),
skip_special_tokens=skip_special_tokens)
class _HasModalityAttr(Protocol):
modality: str
class _HasModalityProp(Protocol):
@property
def modality(self) -> str:
...
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
"""Convenience function to apply :func:`full_groupby` based on modality."""
return full_groupby(values, key=lambda x: x.modality)
@dataclass
class _BoundPromptSequence:
tokenizer: AnyTokenizer
_text: Optional[str]
_token_ids: Optional[list[int]]
def __post_init__(self) -> None:
if self._text is None and self._token_ids is None:
raise ValueError("At least one of 'text' and 'token_ids' must be "
"specified")
@property
def text(self) -> str:
if self._text is None:
assert self._token_ids is not None
self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
return self._text
@property
def token_ids(self) -> list[int]:
if self._token_ids is None:
assert self._text is not None
self._token_ids = _cached_encode(self.tokenizer, self._text)
return self._token_ids
def __repr__(self) -> str:
return (f"{type(self).__name__}(_text={self._text!r}, "
f"_token_ids={self._token_ids!r})")
@dataclass
class _BoundPromptReplacement(Generic[_T]):
modality: str
target: _BoundPromptSequence
repl_unit: _BoundPromptSequence
repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
def get_count(
self,
mm_items: list[_T],
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
repl_count = self.repl_count
if isinstance(repl_count, int):
return repl_count
return repl_count(mm_items, hf_inputs, item_idx)
def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
"""
Convert a :class:`MultiModalDataDict` containing single data items
to a :class:`MultiModalMultiDataDict` containing multiple data items
per entry.
"""
multi_data: Mapping[str, MultiModalMultiData[Any]] = {}
multi_data = dict[str, list[Any]]()
for k, v in data.items():
# yapf: disable
@@ -107,86 +256,279 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
return multi_data
def encode_no_special_tokens(
tokenizer: AnyTokenizer,
text: str,
) -> List[int]:
class _TokenRun(NamedTuple):
token_id: int
start_idx: int
length: int
def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=False)`.
Yield the starting index and length of each run of tokens that are the same.
"""
if isinstance(tokenizer, MistralTokenizer):
return tokenizer.tokenizer.encode(text, bos=False, eos=False)
start_idx = 0
return tokenizer.encode(text, add_special_tokens=False)
for token_id, it in groupby(token_ids):
length = sum(1 for _ in it)
yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length)
start_idx += length
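# A quick illustration of the run-length grouping, using token IDs from the
# new tests:
assert [run._asdict() for run in iter_token_runs([9833, 32000, 32000, 32000])] == [
    {"token_id": 9833, "start_idx": 0, "length": 1},
    {"token_id": 32000, "start_idx": 1, "length": 3},
]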
@lru_cache
def candidate_placeholders(
tokenizer: AnyTokenizer,
placeholder_text: str,
) -> Collection[List[int]]:
"""Generate token ID sequences that may represent a placeholder text."""
# When the placeholder text is not mapped to a special token ID,
# it may be tokenized differently based on whether it is at the start/end
# of the string. So, we go through each combination of whether the text
# is at the start and end boundaries of the string
class _PlaceholderInfo(NamedTuple):
modality: str
offset: int
length: int
# Matches the placeholder when it is in the middle of the string
start_id, = encode_no_special_tokens(tokenizer, "a")
end_id, = encode_no_special_tokens(tokenizer, "b")
def to_range(self) -> PlaceholderRange:
return PlaceholderRange(offset=self.offset, length=self.length)
candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)
start_id_, *candidate_a = encode_no_special_tokens(
tokenizer,
f"a{placeholder_text}",
)
assert start_id == start_id_
start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
tokenizer,
f"a{placeholder_text}b",
)
assert start_id == start_id_ and end_id == end_id_
*candidate_b, end_id_ = encode_no_special_tokens(
tokenizer,
f"{placeholder_text}b",
)
assert end_id == end_id_
# Remove duplicates (need to convert to tuple to be hashable)
unique_candidates = {
tuple(c)
for c in [candidate_basic, candidate_a, candidate_ab, candidate_b]
def iter_placeholders(
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
token_ids: list[int],
*,
min_placeholder_count: int,
) -> Iterable[_PlaceholderInfo]:
"""Yield each set of placeholder tokens found in :code:`token_ids`."""
placeholder_ids_by_modality = {
modality: {
token_id
for prompt_repl in repls
for token_id in prompt_repl.repl_unit.token_ids
}
for modality, repls in full_groupby_modality(prompt_repls)
}
# Convert back to list
return [list(c) for c in unique_candidates]
for run_info in iter_token_runs(token_ids):
if run_info.length > min_placeholder_count:
for (modality,
placeholder_ids) in placeholder_ids_by_modality.items():
if run_info.token_id in placeholder_ids:
yield _PlaceholderInfo(
modality=modality,
offset=run_info.start_idx,
length=run_info.length,
)
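# A small sketch of placeholder detection: with token-only replacements, no
# real tokenizer is needed (as in the new unit tests); 32000 is the image
# placeholder ID used throughout those tests.
from typing import cast

mock_tokenizer = cast(AnyTokenizer, object())
image_repl = PromptReplacement([32000], [32000], 3).bind("image", mock_tokenizer)

assert list(
    iter_placeholders(
        [image_repl],
        [9833, 28747, 32000, 32000, 32000, 918],
        min_placeholder_count=1,
    )) == [_PlaceholderInfo(modality="image", offset=2, length=3)]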
def apply_placeholders(
token_ids: List[int],
placeholder_ids: List[int],
get_replacement_ids: Callable[[], List[int]],
) -> Optional[PlaceholderRange]:
class _TokenMatch(NamedTuple):
start_idx: int
end_idx: int
def iter_token_matches(
token_ids: list[int],
match_ids: list[int],
) -> Iterable[_TokenMatch]:
"""Yield each occurrence of :code:`match_ids` in :code:`token_ids`."""
match_len = len(match_ids)
last_end_idx = 0
for start_idx in range(len(token_ids) - match_len + 1):
if start_idx < last_end_idx:
continue # Exclude overlapping matches
end_idx = start_idx + match_len
if token_ids[start_idx:end_idx] == match_ids:
yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
last_end_idx = end_idx
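# For example, iter_token_matches([32000, 32000, 32000], [32000, 32000])
# yields only _TokenMatch(start_idx=0, end_idx=2); the candidate match at
# index 1 is skipped because it would overlap with the first match.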
class _PromptReplacementMatch(ABC, Generic[_T, _S]):
prompt_repl: _BoundPromptReplacement[_T]
@property
def modality(self) -> str:
return self.prompt_repl.modality
@property
@abstractmethod
def start_idx(self) -> int:
raise NotImplementedError
@property
@abstractmethod
def end_idx(self) -> int:
raise NotImplementedError
@abstractmethod
def get_repl(
self,
mm_items: list[_T],
hf_inputs: BatchFeature,
item_idx: int,
) -> _S:
raise NotImplementedError
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r}, "
f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
prompt_repl: _BoundPromptReplacement[_T]
match: _TokenMatch
@property
def start_idx(self) -> int:
return self.match.start_idx
@property
def end_idx(self) -> int:
return self.match.end_idx
def get_repl(
self,
mm_items: list[_T],
hf_inputs: BatchFeature,
item_idx: int,
) -> list[int]:
prompt_repl = self.prompt_repl
count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
return prompt_repl.repl_unit.token_ids * count
@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
prompt_repl: _BoundPromptReplacement[_T]
match: re.Match[str]
@property
def start_idx(self) -> int:
return self.match.start()
@property
def end_idx(self) -> int:
return self.match.end()
def get_repl(
self,
mm_items: list[_T],
hf_inputs: BatchFeature,
item_idx: int,
) -> str:
prompt_repl = self.prompt_repl
count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
return prompt_repl.repl_unit.text * count
def find_token_matches(
prompt: list[int],
prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTokenMatch[_T]]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
_PromptReplacementTokenMatch(prompt_repl, match)
for prompt_repl in prompt_repls
for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
]
def find_text_matches(
prompt: str,
prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTextMatch[_T]]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
_PromptReplacementTextMatch(prompt_repl, match)
for prompt_repl in prompt_repls
for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
]
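# Note that the target text is passed through re.escape, so placeholder
# strings containing regex metacharacters (e.g. "<|image|>") are matched
# literally.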
def _resolve_matches(
prompt: _S,
matches: Sequence[_PromptReplacementMatch[_T, _S]],
) -> list[_PromptReplacementMatch[_T, _S]]:
"""
Find the first occurrence of :code:`placeholder_ids`,
and replace it with the output of :code:`get_replacement_ids`.
This function updates :code:`token_ids` in place.
Resolve :code:`matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
placeholder_length = len(placeholder_ids)
num_matches_by_idx = np.zeros(len(prompt), dtype=int)
for match in matches:
num_matches_by_idx[match.start_idx:match.end_idx] += 1
for start_idx in range(len(token_ids) - placeholder_length + 1):
if token_ids[start_idx:placeholder_length] == placeholder_ids:
token_ids[start_idx:placeholder_length] = get_replacement_ids()
duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1)
if len(duplicate_matches_idxs) > 0:
raise ValueError("Unable to find a unique replacement "
f"at indices={duplicate_matches_idxs} "
f"of prompt={prompt}")
return PlaceholderRange(offset=start_idx,
length=placeholder_length)
return sorted(matches, key=lambda x: x.start_idx)
return None
def _replace_matches(
prompt: _S,
matches: Sequence[_PromptReplacementMatch[_T, _S]],
mm_items_by_modality: Mapping[str, list[_T]],
hf_inputs: BatchFeature,
) -> list[_S]:
out_seqs = list[_S]()
prev_end_idx = 0
next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality}
for match in _resolve_matches(prompt, matches):
modality = match.modality
mm_items = mm_items_by_modality[modality]
item_idx = next_idx_by_modality[modality]
if item_idx >= len(mm_items):
continue
start_idx = match.start_idx
end_idx = match.end_idx
repl_ids = match.get_repl(mm_items, hf_inputs, item_idx)
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids)
prev_end_idx = end_idx
next_idx_by_modality[modality] += 1
out_seqs.append(prompt[prev_end_idx:])
return out_seqs
def replace_token_matches(
prompt: list[int],
matches: Sequence[_PromptReplacementMatch[_T, list[int]]],
mm_items_by_modality: Mapping[str, list[_T]],
hf_inputs: BatchFeature,
) -> list[int]:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches:
return prompt
token_id_seqs = _replace_matches(
prompt,
matches,
mm_items_by_modality,
hf_inputs,
)
return flatten_2d_lists(token_id_seqs)
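# A minimal sketch of the token-level replacement path (the text-level
# analogue is exercised by test_find_replace_text in the new tests):
from typing import cast

from transformers import BatchFeature

mock_tokenizer = cast(AnyTokenizer, object())
image_repl = PromptReplacement([32000], [32000], 3).bind("image", mock_tokenizer)

prompt = [9833, 28747, 32000, 918]
matches = find_token_matches(prompt, [image_repl])
new_prompt = replace_token_matches(
    prompt,
    matches,
    {"image": [object()]},  # one dummy image item
    BatchFeature(),
)
assert new_prompt == [9833, 28747, 32000, 32000, 32000, 918]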
def replace_text_matches(
prompt: str,
matches: Sequence[_PromptReplacementMatch[_T, str]],
mm_items_by_modality: Mapping[str, list[_T]],
hf_inputs: BatchFeature,
) -> str:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches:
return prompt
texts = _replace_matches(
prompt,
matches,
mm_items_by_modality,
hf_inputs,
)
return "".join(texts)
class MultiModalProcessor:
@@ -212,62 +554,166 @@ class MultiModalProcessor:
) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, mm_processor_kwargs)
def apply(
def _find_placeholders(
self,
all_prompt_repls: Sequence[_BoundPromptReplacement[Any]],
new_token_ids: list[int],
*,
# To avoid false positives from multi-input when detecting
# whether placeholder tokens have been inserted, in case
# the target sequence is a subset of the replacement tokens
min_placeholder_count: int = 16,
) -> list[_PlaceholderInfo]:
return list(
iter_placeholders(
all_prompt_repls,
new_token_ids,
min_placeholder_count=min_placeholder_count,
))
def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
tokenizer = self.ctx.tokenizer
) -> BatchFeature:
hf_processor = self.ctx.get_hf_processor()
processed_inputs = hf_processor(
return hf_processor(
text=prompt, # type: ignore
**mm_data,
**mm_processor_kwargs,
)
new_token_ids, = processed_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(processed_inputs)
mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}
def _bind_prompt_replacements(
self,
mm_data: MultiModalDataDict,
) -> list[_BoundPromptReplacement[Any]]:
tokenizer = self.ctx.tokenizer
for modality, orig_inputs in to_multi_format(mm_data).items():
assert isinstance(orig_inputs, list)
return [
prompt_repl.bind(modality, tokenizer)
for modality, metadata in self.metadata.items()
if modality in mm_data for prompt_repl in metadata.prompt_repls
]
metadata = self.metadata[modality]
placeholder_replacements = metadata.placeholder_replacements
def _apply_prompt_replacements(
self,
mm_data: MultiModalDataDict,
hf_inputs: BatchFeature,
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
tokenizer = self.ctx.tokenizer
modality_placeholders: List[PlaceholderRange] = []
mm_items = to_multi_format(mm_data)
token_matches = find_token_matches(token_ids, prompt_repls)
for item_idx, orig_item in enumerate(orig_inputs):
for match_text, replace_fn in placeholder_replacements.items():
candidates = candidate_placeholders(tokenizer, match_text)
get_replacement_ids = partial(
replace_fn,
orig_item,
processed_inputs,
item_idx,
)
# If the search text does not represent a special token,
# it may have different token IDs in the prompt, because
# the tokens may go across the boundaries of the search text.
# ----
# e.g. when searching for "foo" in "food", if "food" itself makes
# up a token, then the token ID of "foo" will not appear at all
# ----
# Since it is inefficient to search for all possible tokenizations
# of the search text in the prompt, we instead perform string
# replacement on the decoded token IDs, then encode them back.
if all(
len(matches) >= len(mm_data[modality])
for modality, matches in full_groupby_modality(token_matches)
): # yapf: disable
token_ids = replace_token_matches(
token_ids,
token_matches,
mm_items,
hf_inputs,
)
for match_ids in candidates:
# TODO(youkaichao): Don't update new_token_ids
placeholders = apply_placeholders(
new_token_ids,
match_ids,
get_replacement_ids,
)
text = _decode(tokenizer, token_ids)
matched_repls = [match.prompt_repl for match in token_matches]
else:
text = _decode(tokenizer, token_ids)
if placeholders is not None:
modality_placeholders.append(placeholders)
text_matches = find_text_matches(text, prompt_repls)
text = replace_text_matches(
text,
text_matches,
mm_items,
hf_inputs,
)
# yapf: disable
mm_placeholders[modality] = modality_placeholders # type: ignore[index]
# yapf: enable
token_ids = _encode(tokenizer, text)
matched_repls = [match.prompt_repl for match in text_matches]
placeholders = self._find_placeholders(matched_repls, token_ids)
# Sanity check
assert len(placeholders) == len(matched_repls), dict(
# Log this information for easier debugging
text=text,
token_ids=token_ids,
placeholders=placeholders,
matched_repls=matched_repls,
)
return token_ids, text, placeholders
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
"""
Process multi-modal inputs to be used in vLLM.
The main steps are:
1. Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
2. Find and replace sequences in the token IDs with placeholder tokens.
The number of placeholder tokens equals the feature size of the
multi-modal data outputted by the multi-modal encoder.
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
tokenizer = self.ctx.tokenizer
hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
mm_processor_kwargs)
prompt_ids, = hf_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(hf_inputs)
all_prompt_repls = self._bind_prompt_replacements(mm_data)
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
all_placeholders = self._find_placeholders(all_prompt_repls,
prompt_ids)
if all_placeholders:
prompt_text = _decode(tokenizer, prompt_ids)
else:
(
prompt_ids,
prompt_text,
all_placeholders,
) = self._apply_prompt_replacements(
mm_data,
hf_inputs,
prompt_ids,
all_prompt_repls,
)
mm_placeholders = {
modality: [item.to_range() for item in items]
for modality, items in full_groupby_modality(all_placeholders)
}
return MultiModalInputsV2(
type="multimodal",
prompt=prompt,
prompt_token_ids=new_token_ids,
prompt=prompt_text,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders,
)
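# A rough sketch of how a model could hook into this processor. The exact
# registration/constructor API is not part of this change, so the names
# below are assumptions for illustration only (576 is a made-up feature
# size):
#
#   metadata = {
#       "image": ModalityProcessingMetadata(prompt_repls=[
#           PromptReplacement("<image>", "<image>", 576),
#       ]),
#   }
#   processor = MultiModalProcessor(ctx, metadata)  # ctx: InputProcessingContext
#   inputs = processor.apply(
#       "USER: <image>\nWhat is in the image? ASSISTANT:",
#       {"image": image},
#       mm_processor_kwargs={},
#   )
#   inputs["mm_placeholders"]["image"]  # -> [{"offset": ..., "length": 576}]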


@@ -19,7 +19,8 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import Mapping
from collections import defaultdict
from collections.abc import Iterable, Mapping
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@@ -905,6 +906,23 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return [item for sublist in lists for item in sublist]
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike :class:`itertools.groupby`, groups are not broken by
non-contiguous data.
"""
groups = defaultdict[_K, list[_V]](list)
for value in values:
groups[key(value)].append(value)
return groups.items()
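# A quick illustration of the difference from itertools.groupby, which only
# groups *adjacent* items:
fruits = ["apple", "avocado", "banana", "apricot"]
assert dict(full_groupby(fruits, key=lambda s: s[0])) == {
    "a": ["apple", "avocado", "apricot"],
    "b": ["banana"],
}
# itertools.groupby on the same input would instead yield two separate "a"
# groups, because "apricot" is not adjacent to the other "a" items.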
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None: