Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Synced 2025-12-19 21:05:01 +08:00

[2/N] handling placeholders in merged multi-modal processor (#10485)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Parent commit: 4634a89d18
This commit:  c8acd80548

New file: tests/multimodal/test_processing.py (370 lines)
tests/multimodal/test_processing.py (new file)
@@ -0,0 +1,370 @@
from typing import cast

import pytest
from transformers import BatchFeature

from vllm.multimodal.processing import (PromptReplacement, find_text_matches,
                                        find_token_matches, iter_token_matches,
                                        iter_token_runs, replace_text_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby


# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "expected"),
    [
        ([], []),
        (
            [32000, 32000, 32000],
            [{ "token_id": 32000, "start_idx": 0, "length": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [
                { "token_id": 9833, "start_idx": 0, "length": 1 },
                { "token_id": 28747, "start_idx": 1, "length": 1 },
                { "token_id": 32000, "start_idx": 2, "length": 3 },
                { "token_id": 9833, "start_idx": 5, "length": 1 },
                { "token_id": 28747, "start_idx": 6, "length": 1 },
                { "token_id": 32000, "start_idx": 7, "length": 2 },
                { "token_id": 918, "start_idx": 9, "length": 1 },
            ],
        ),
    ],
)
# yapf: enable
def test_iter_token_runs(token_ids, expected):
    result = list(iter_token_runs(token_ids))

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants
    assert sum(run_info.length for run_info in result) == len(token_ids)


# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], [{ "start_idx": 0, "end_idx": 0 }]),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                { "start_idx": 0, "end_idx": 1 },
                { "start_idx": 1, "end_idx": 2 },
                { "start_idx": 2, "end_idx": 3 },
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{ "start_idx": 0, "end_idx": 2 }],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{ "start_idx": 0, "end_idx": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                { "start_idx": 1, "end_idx": 3 },
                { "start_idx": 6, "end_idx": 8 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                { "start_idx": 1, "end_idx": 5 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
    result = list(iter_token_matches(token_ids, match_ids))

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants
    match_lens = [end - start for start, end in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 1 },
                    { "start_idx": 1, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 3 },
                    { "start_idx": 3, "end_idx": 4 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 4 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 3 },
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 1, "end_idx": 3 },
                    { "start_idx": 6, "end_idx": 8 },
                ],
                "pattern_2": [
                    { "start_idx": 1, "end_idx": 5 },
                ],
                "pattern_3": [],
            },
        ),
    ],
)
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
    # Should not be used since there is nothing to convert to token IDs
    mock_tokenizer = cast(AnyTokenizer, object())

    result = find_token_matches(
        prompt,
        [
            PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
            for key, target in target_by_key.items()
        ],
    )

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 7 },
                    { "start_idx": 7, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 21 },
                    { "start_idx": 21, "end_idx": 28 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 28 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 21 },
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 13 },
                    { "start_idx": 27, "end_idx": 40 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 27 },
                ],
                "pattern_3": [],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 9 },
                    { "start_idx": 16, "end_idx": 25 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 16 },
                    { "start_idx": 16, "end_idx": 32 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 25 },
                ],
            },
        ),
    ],
)
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    result = find_text_matches(
        prompt,
        [
            PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
            for key, target in target_by_key.items()
        ],
    )

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"),
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with repl_unit
                "pattern_1": ("<image><image>", 1),
                # Test empty repl_unit
                "pattern_2": ("", 1),
                # Test multiple repl_count
                "pattern_3": ("?", 2),
            },
            {
                # Test no replacement
                0: "Image:<image>Image:<image><image>!",
                # Test single replacement
                1: "<image><image>Image:<image><image>??",
                # Test repeated replacement
                2: "<image><image><image><image><image>??",
            },
        ),
    ]
)
# yapf: enable
def test_find_replace_text(
    prompt,
    target_by_key,
    repl_by_key,
    expected_by_mm_count,
):
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    matches = find_text_matches(
        prompt,
        [
            PromptReplacement(target, *repl_by_key[key]) \
                .bind(key, mock_tokenizer)
            for key, target in target_by_key.items()
        ],
    )
    result_by_mm_count = {
        mm_count: replace_text_matches(
            prompt,
            matches,
            {key: list(range(mm_count))
             for key in repl_by_key},
            BatchFeature(),
        )
        for mm_count in expected_by_mm_count
    }

    # Only displayed on error
    print("matches:", matches)
    print("result_by_mm_count:", result_by_mm_count)

    # Manually constructed results
    assert result_by_mm_count == expected_by_mm_count
@@ -139,7 +139,8 @@ def test_repeat_and_pad_placeholder_tokens(model):
             2,
             "<image><image><image>",
             [32000, 32000, 32000],
-            [{ "offset": 0, "length": 2 }]),
+            [{ "offset": 0, "length": 2 }],
+        ),
         (
             "<image><image>",
             [3, 2],
@@ -203,14 +203,7 @@ class MultiModalInputsV2(TypedDict):
     """The type of inputs."""

     prompt: str
-    """
-    The original, unprocessed prompt text.
-
-    Note:
-        Since prompt text is not required by vLLM internals, we leave this
-        unprocessed to save CPU computation. You can still call
-        :code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
-    """
+    """The processed prompt text."""

     prompt_token_ids: List[int]
     """The processed token IDs which includes placeholder tokens."""
@@ -1,34 +1,91 @@
+import re
+from abc import ABC, abstractmethod
+from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from functools import lru_cache, partial
-from typing import (Any, Callable, Collection, Generic, List, Mapping,
-                    Optional, TypedDict, TypeVar, final)
+from functools import lru_cache
+from itertools import groupby
+from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union

+import numpy as np
 from transformers import BatchFeature
-from typing_extensions import TypeAlias
+from typing_extensions import TypeAlias, TypedDict

 from vllm.inputs import InputProcessingContext
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import is_list_of
+from vllm.utils import flatten_2d_lists, full_groupby, is_list_of

 from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                      MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                      VideoItem)

-_T = TypeVar("_T")
-
-ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]]
-"""
-Given the original data item, HF-processed data, and index of the processed
-item, output the replacement token IDs to be allocated in vLLM.
-"""
+
+def bind_prompt_sequence(
+    seq: Union[str, list[int]],
+    tokenizer: AnyTokenizer,
+) -> "_BoundPromptSequence":
+    """
+    Bind a text or token sequence to a tokenizer so that it can be
+    lazily converted into the other format on demand.
+    """
+    return _BoundPromptSequence(
+        tokenizer=tokenizer,
+        _text=seq if isinstance(seq, str) else None,
+        _token_ids=seq if isinstance(seq, list) else None,
+    )
+
+
+_T = TypeVar("_T")
+_S = TypeVar("_S", str, list[int])
+
+
+@dataclass
+class PromptReplacement(Generic[_S, _T]):
+    target: _S
+    """The text or token sequence to find and replace."""
+
+    repl_unit: _S
+    """
+    The unit making up the replacement text or token sequence.
+
+    See :code:`repl_count` for more details.
+    """
+
+    repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
+    """
+    Given the original multi-modal items for this modality, HF-processed data,
+    and index of the processed item, output the number of repetitions of
+    :code:`repl_unit` to build up the replacement text or token sequence.
+
+    For convenience, you can pass in an integer if the number of repetitions is
+    a constant.
+    """
+
+    def __repr__(self) -> str:
+        return (f"{type(self).__name__}(target={self.target!r}, "
+                f"repl_unit={self.repl_unit!r})")
+
+    def bind(
+        self,
+        modality: str,
+        tokenizer: AnyTokenizer,
+    ) -> "_BoundPromptReplacement[_T]":
+        return _BoundPromptReplacement(
+            modality=modality,
+            target=bind_prompt_sequence(self.target, tokenizer),
+            repl_unit=bind_prompt_sequence(self.repl_unit, tokenizer),
+            repl_count=self.repl_count,
+        )


 @dataclass
 class ModalityProcessingMetadata(Generic[_T]):
-    placeholder_replacements: Mapping[str, ReplacementFunc]
+    prompt_repls: Sequence[Union[PromptReplacement[str, _T],
+                                 PromptReplacement[list[int], _T]]]
     """
-    A dictionary where each item represents the original placeholder in the
-    prompt text and the corresponding replacement.
+    Defines each text or token sequence to replace in the HF-processed prompt.
+
+    This is skipped if the HF-processed prompt is found to already contain
+    the replacement prompts.
     """
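For orientation, here is a minimal sketch (not part of the commit) of how the new PromptReplacement dataclass is meant to be used. The modality name, token IDs, counts, and the "image_sizes" key are made up for illustration:

from transformers import AutoTokenizer

from vllm.multimodal.processing import PromptReplacement

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

# Constant repetition count: every "<image>" target expands to 576 copies
# of the image placeholder token (both numbers are illustrative).
repl = PromptReplacement(target="<image>", repl_unit=[32000], repl_count=576)

# The count may instead be computed from the HF-processed outputs;
# the "image_sizes" field accessed here is hypothetical.
dynamic_repl = PromptReplacement(
    target="<image>",
    repl_unit=[32000],
    repl_count=lambda mm_items, hf_inputs, item_idx: len(hf_inputs["image_sizes"][item_idx]),
)

# bind() attaches a modality and tokenizer, producing a _BoundPromptReplacement
# whose target and repl_unit can be viewed as text or token IDs on demand.
bound = repl.bind("image", tokenizer)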
@@ -52,46 +109,138 @@ Note:
     Read more on that :ref:`here <adding_multimodal_plugin>`.
     """

-MultiModalMultiData: TypeAlias = List[_T]
-"""
-A list of data items, where the number of data items allowed
-per modality is restricted by :code:`--limit-mm-per-prompt`.
-"""
-
-
-@final
-class MultiModalMultiDataBuiltins(TypedDict, total=False):
-    """Type annotations for modality types predefined by vLLM."""
-
-    image: MultiModalMultiData[ImageItem]
-    """The input images."""
-
-    video: MultiModalMultiData[VideoItem]
-    """The input videos."""
-
-    audio: MultiModalMultiData[AudioItem]
-    """The input audios."""
-
-
-MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]]
-"""
-A dictionary containing an entry for each modality type to input.
-
-Note:
-    This dictionary also accepts modality keys defined outside
-    :class:`MultiModalMultiDataBuiltins` as long as a customized plugin
-    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
-    Read more on that :ref:`here <adding_multimodal_plugin>`.
-"""
-
-
-def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
-    """
-    Convert a :class:`MultiModalDataDict` containing single data items
-    to a :class:`MultiModalMultiDataDict` containing multiple data items
-    per entry.
-    """
-    multi_data: Mapping[str, MultiModalMultiData[Any]] = {}
-    for k, v in data.items():
-        # yapf: disable
+
+def _encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: bool = False,
+) -> list[int]:
+    """
+    Backend-agnostic equivalent of HF's
+    :code:`tokenizer.encode(text, add_special_tokens=...)`.
+    """
+    if isinstance(tokenizer, MistralTokenizer):
+        return tokenizer.tokenizer.encode(text,
+                                          bos=add_special_tokens,
+                                          eos=add_special_tokens)
+
+    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: bool = False,
+) -> list[int]:
+    return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
+
+
+def _decode(
+    tokenizer: AnyTokenizer,
+    token_ids: list[int],
+    *,
+    skip_special_tokens: bool = False,
+) -> str:
+    """
+    Backend-agnostic equivalent of HF's
+    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
+    """
+    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_decode(
+    tokenizer: AnyTokenizer,
+    token_ids: tuple[int, ...],
+    *,
+    skip_special_tokens: bool = False,
+) -> str:
+    return _decode(tokenizer,
+                   list(token_ids),
+                   skip_special_tokens=skip_special_tokens)
+
+
+class _HasModalityAttr(Protocol):
+    modality: str
+
+
+class _HasModalityProp(Protocol):
+
+    @property
+    def modality(self) -> str:
+        ...
+
+
+_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
+
+
+def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
+    """Convenience function to apply :func:`full_groupby` based on modality."""
+    return full_groupby(values, key=lambda x: x.modality)
+
+
+@dataclass
+class _BoundPromptSequence:
+    tokenizer: AnyTokenizer
+    _text: Optional[str]
+    _token_ids: Optional[list[int]]
+
+    def __post_init__(self) -> None:
+        if self._text is None and self._token_ids is None:
+            raise ValueError("At least one of 'text' and 'token_ids' must be "
+                             "specified")
+
+    @property
+    def text(self) -> str:
+        if self._text is None:
+            assert self._token_ids is not None
+            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
+
+        return self._text
+
+    @property
+    def token_ids(self) -> list[int]:
+        if self._token_ids is None:
+            assert self._text is not None
+            self._token_ids = _cached_encode(self.tokenizer, self._text)
+
+        return self._token_ids
+
+    def __repr__(self) -> str:
+        return (f"{type(self).__name__}(_text={self._text!r}, "
+                f"_token_ids={self._token_ids!r})")
+
+
+@dataclass
+class _BoundPromptReplacement(Generic[_T]):
+    modality: str
+    target: _BoundPromptSequence
+    repl_unit: _BoundPromptSequence
+    repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
+
+    def get_count(
+        self,
+        mm_items: list[_T],
+        hf_inputs: BatchFeature,
+        item_idx: int,
+    ) -> int:
+        repl_count = self.repl_count
+        if isinstance(repl_count, int):
+            return repl_count
+
+        return repl_count(mm_items, hf_inputs, item_idx)
+
+
+def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
+    """
+    Convert a :class:`MultiModalDataDict` containing single data items
+    to a :class:`MultiModalMultiDataDict` containing multiple data items
+    per entry.
+    """
+    multi_data = dict[str, list[Any]]()
+
+    for k, v in data.items():
+        # yapf: disable
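As a rough illustration (assuming the same LLaVA tokenizer used in the tests above), a bound sequence keeps whichever form it was given and converts to the other lazily, caching the result:

from transformers import AutoTokenizer

from vllm.multimodal.processing import bind_prompt_sequence

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

seq = bind_prompt_sequence("<image>", tokenizer)
seq.text       # "<image>" is stored directly
seq.token_ids  # encoded on first access via _cached_encode

seq2 = bind_prompt_sequence([32000, 32000], tokenizer)
seq2.text      # decoded on first access via _cached_decode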
@@ -107,86 +256,279 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
     return multi_data


-def encode_no_special_tokens(
-    tokenizer: AnyTokenizer,
-    text: str,
-) -> List[int]:
-    """
-    Backend-agnostic equivalent of HF's
-    :code:`tokenizer.encode(text, add_special_tokens=False)`.
-    """
-    if isinstance(tokenizer, MistralTokenizer):
-        return tokenizer.tokenizer.encode(text, bos=False, eos=False)
-
-    return tokenizer.encode(text, add_special_tokens=False)
-
-
-@lru_cache
-def candidate_placeholders(
-    tokenizer: AnyTokenizer,
-    placeholder_text: str,
-) -> Collection[List[int]]:
-    """Generate token ID sequences that may represent a placeholder text."""
-    # When the placeholder text is not mapped to a special token ID,
-    # it may be tokenized differently based on whether it is at the start/end
-    # of the string. So, we go through each combination of whether the text
-    # is at the start and end boundaries of the string
-
-    # Matches the placeholder when it is in the middle of the string
-    start_id, = encode_no_special_tokens(tokenizer, "a")
-    end_id, = encode_no_special_tokens(tokenizer, "b")
-
-    candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)
-
-    start_id_, *candidate_a = encode_no_special_tokens(
-        tokenizer,
-        f"a{placeholder_text}",
-    )
-    assert start_id == start_id_
-
-    start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
-        tokenizer,
-        f"a{placeholder_text}b",
-    )
-    assert start_id == start_id_ and end_id == end_id_
-
-    *candidate_b, end_id_ = encode_no_special_tokens(
-        tokenizer,
-        f"{placeholder_text}b",
-    )
-    assert end_id == end_id_
-
-    # Remove duplicates (need to convert to tuple to be hashable)
-    unique_candidates = {
-        tuple(c)
-        for c in [candidate_basic, candidate_a, candidate_ab, candidate_b]
-    }
-
-    # Convert back to list
-    return [list(c) for c in unique_candidates]
-
-
-def apply_placeholders(
-    token_ids: List[int],
-    placeholder_ids: List[int],
-    get_replacement_ids: Callable[[], List[int]],
-) -> Optional[PlaceholderRange]:
-    """
-    Find the first occurrence of :code:`placeholder_ids`,
-    and replace it with the output of :code:`get_replacement_ids`.
-
-    This function updates :code:`token_ids` in place.
-    """
-    placeholder_length = len(placeholder_ids)
-
-    for start_idx in range(len(token_ids) - placeholder_length + 1):
-        if token_ids[start_idx:placeholder_length] == placeholder_ids:
-            token_ids[start_idx:placeholder_length] = get_replacement_ids()
-
-            return PlaceholderRange(offset=start_idx,
-                                    length=placeholder_length)
-
-    return None
+class _TokenRun(NamedTuple):
+    token_id: int
+    start_idx: int
+    length: int
+
+
+def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
+    """
+    Yield the starting index and length of each run of tokens that are the same.
+    """
+    start_idx = 0
+
+    for token_id, it in groupby(token_ids):
+        length = sum(1 for _ in it)
+        yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length)
+
+        start_idx += length
+
+
+class _PlaceholderInfo(NamedTuple):
+    modality: str
+    offset: int
+    length: int
+
+    def to_range(self) -> PlaceholderRange:
+        return PlaceholderRange(offset=self.offset, length=self.length)
+
+
+def iter_placeholders(
+    prompt_repls: Sequence[_BoundPromptReplacement[Any]],
+    token_ids: list[int],
+    *,
+    min_placeholder_count: int,
+) -> Iterable[_PlaceholderInfo]:
+    """Yield each set of placeholder tokens found in :code:`token_ids`."""
+    placeholder_ids_by_modality = {
+        modality: {
+            token_id
+            for prompt_repl in repls
+            for token_id in prompt_repl.repl_unit.token_ids
+        }
+        for modality, repls in full_groupby_modality(prompt_repls)
+    }
+
+    for run_info in iter_token_runs(token_ids):
+        if run_info.length > min_placeholder_count:
+            for (modality,
+                 placeholder_ids) in placeholder_ids_by_modality.items():
+                if run_info.token_id in placeholder_ids:
+                    yield _PlaceholderInfo(
+                        modality=modality,
+                        offset=run_info.start_idx,
+                        length=run_info.length,
+                    )
+
+
+class _TokenMatch(NamedTuple):
+    start_idx: int
+    end_idx: int
+
+
+def iter_token_matches(
+    token_ids: list[int],
+    match_ids: list[int],
+) -> Iterable[_TokenMatch]:
+    """Yield each occurrence of :code:`match_ids` in :code:`token_ids`."""
+    match_len = len(match_ids)
+
+    last_end_idx = 0
+    for start_idx in range(len(token_ids) - match_len + 1):
+        if start_idx < last_end_idx:
+            continue  # Exclude overlapping matches
+
+        end_idx = start_idx + match_len
+        if token_ids[start_idx:end_idx] == match_ids:
+            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
+            last_end_idx = end_idx
+
+
+class _PromptReplacementMatch(ABC, Generic[_T, _S]):
+    prompt_repl: _BoundPromptReplacement[_T]
+
+    @property
+    def modality(self) -> str:
+        return self.prompt_repl.modality
+
+    @property
+    @abstractmethod
+    def start_idx(self) -> int:
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def end_idx(self) -> int:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_repl(
+        self,
+        mm_items: list[_T],
+        hf_inputs: BatchFeature,
+        item_idx: int,
+    ) -> _S:
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        return (f"{type(self).__name__}(modality={self.modality!r}, "
+                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
+
+
+@dataclass(repr=False)
+class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
+    prompt_repl: _BoundPromptReplacement[_T]
+    match: _TokenMatch
+
+    @property
+    def start_idx(self) -> int:
+        return self.match.start_idx
+
+    @property
+    def end_idx(self) -> int:
+        return self.match.end_idx
+
+    def get_repl(
+        self,
+        mm_items: list[_T],
+        hf_inputs: BatchFeature,
+        item_idx: int,
+    ) -> list[int]:
+        prompt_repl = self.prompt_repl
+        count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
+        return prompt_repl.repl_unit.token_ids * count
+
+
+@dataclass(repr=False)
+class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
+    prompt_repl: _BoundPromptReplacement[_T]
+    match: re.Match[str]
+
+    @property
+    def start_idx(self) -> int:
+        return self.match.start()
+
+    @property
+    def end_idx(self) -> int:
+        return self.match.end()
+
+    def get_repl(
+        self,
+        mm_items: list[_T],
+        hf_inputs: BatchFeature,
+        item_idx: int,
+    ) -> str:
+        prompt_repl = self.prompt_repl
+        count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
+        return prompt_repl.repl_unit.text * count
+
+
+def find_token_matches(
+    prompt: list[int],
+    prompt_repls: Sequence[_BoundPromptReplacement[_T]],
+) -> list[_PromptReplacementTokenMatch[_T]]:
+    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
+    return [
+        _PromptReplacementTokenMatch(prompt_repl, match)
+        for prompt_repl in prompt_repls
+        for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
+    ]
+
+
+def find_text_matches(
+    prompt: str,
+    prompt_repls: Sequence[_BoundPromptReplacement[_T]],
+) -> list[_PromptReplacementTextMatch[_T]]:
+    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
+    return [
+        _PromptReplacementTextMatch(prompt_repl, match)
+        for prompt_repl in prompt_repls
+        for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
+    ]
+
+
+def _resolve_matches(
+    prompt: _S,
+    matches: Sequence[_PromptReplacementMatch[_T, _S]],
+) -> list[_PromptReplacementMatch[_T, _S]]:
+    """
+    Resolve :code:`matches` to ensure that there are no overlapping matches,
+    and sort them such that earlier matches take priority over later ones.
+    """
+    num_matches_by_idx = np.zeros(len(prompt), dtype=int)
+    for match in matches:
+        num_matches_by_idx[match.start_idx:match.end_idx] += 1
+
+    duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1)
+    if len(duplicate_matches_idxs) > 0:
+        raise ValueError("Unable to find a unique replacement "
+                         f"at indices={duplicate_matches_idxs} "
+                         f"of prompt={prompt}")
+
+    return sorted(matches, key=lambda x: x.start_idx)
+
+
+def _replace_matches(
+    prompt: _S,
+    matches: Sequence[_PromptReplacementMatch[_T, _S]],
+    mm_items_by_modality: Mapping[str, list[_T]],
+    hf_inputs: BatchFeature,
+) -> list[_S]:
+    out_seqs = list[_S]()
+    prev_end_idx = 0
+    next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality}
+
+    for match in _resolve_matches(prompt, matches):
+        modality = match.modality
+        mm_items = mm_items_by_modality[modality]
+
+        item_idx = next_idx_by_modality[modality]
+        if item_idx >= len(mm_items):
+            continue
+
+        start_idx = match.start_idx
+        end_idx = match.end_idx
+        repl_ids = match.get_repl(mm_items, hf_inputs, item_idx)
+
+        out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids)
+        prev_end_idx = end_idx
+        next_idx_by_modality[modality] += 1
+
+    out_seqs.append(prompt[prev_end_idx:])
+
+    return out_seqs
+
+
+def replace_token_matches(
+    prompt: list[int],
+    matches: Sequence[_PromptReplacementMatch[_T, list[int]]],
+    mm_items_by_modality: Mapping[str, list[_T]],
+    hf_inputs: BatchFeature,
+) -> list[int]:
+    """Apply :code:`prompt_repls` to :code:`prompt`."""
+    if not matches:
+        return prompt
+
+    token_id_seqs = _replace_matches(
+        prompt,
+        matches,
+        mm_items_by_modality,
+        hf_inputs,
+    )
+
+    return flatten_2d_lists(token_id_seqs)
+
+
+def replace_text_matches(
+    prompt: str,
+    matches: Sequence[_PromptReplacementMatch[_T, str]],
+    mm_items_by_modality: Mapping[str, list[_T]],
+    hf_inputs: BatchFeature,
+) -> str:
+    """Apply :code:`prompt_repls` to :code:`prompt`."""
+    if not matches:
+        return prompt
+
+    texts = _replace_matches(
+        prompt,
+        matches,
+        mm_items_by_modality,
+        hf_inputs,
+    )
+
+    return "".join(texts)
+
+
 class MultiModalProcessor:
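The test cases at the top of this commit pin down the behaviour of these helpers; read as a quick reference, using token IDs taken from those tests:

from vllm.multimodal.processing import iter_token_matches, iter_token_runs

# Runs of identical tokens, with their start index and length:
list(iter_token_runs([9833, 28747, 32000, 32000, 32000]))
# -> [_TokenRun(token_id=9833, start_idx=0, length=1),
#     _TokenRun(token_id=28747, start_idx=1, length=1),
#     _TokenRun(token_id=32000, start_idx=2, length=3)]

# Matches never overlap: after matching at index 0, the candidate match
# starting at index 1 is skipped.
list(iter_token_matches([32000, 32000, 32000], [32000, 32000]))
# -> [_TokenMatch(start_idx=0, end_idx=2)]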
@@ -212,62 +554,166 @@ class MultiModalProcessor:
     ) -> MultiModalInputsV2:
         return self.apply(prompt, mm_data, mm_processor_kwargs)

-    def apply(
-        self,
-        prompt: str,
-        mm_data: MultiModalDataDict,
-        mm_processor_kwargs: Mapping[str, object],
-    ) -> MultiModalInputsV2:
-        tokenizer = self.ctx.tokenizer
-        hf_processor = self.ctx.get_hf_processor()
-
-        processed_inputs = hf_processor(
-            text=prompt,  # type: ignore
-            **mm_data,
-            **mm_processor_kwargs,
-        )
-        new_token_ids, = processed_inputs.pop("input_ids").tolist()
-        mm_kwargs = MultiModalKwargs(processed_inputs)
-
-        mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}
-
-        for modality, orig_inputs in to_multi_format(mm_data).items():
-            assert isinstance(orig_inputs, list)
-
-            metadata = self.metadata[modality]
-            placeholder_replacements = metadata.placeholder_replacements
-
-            modality_placeholders: List[PlaceholderRange] = []
-
-            for item_idx, orig_item in enumerate(orig_inputs):
-                for match_text, replace_fn in placeholder_replacements.items():
-                    candidates = candidate_placeholders(tokenizer, match_text)
-                    get_replacement_ids = partial(
-                        replace_fn,
-                        orig_item,
-                        processed_inputs,
-                        item_idx,
-                    )
-
-                    for match_ids in candidates:
-                        # TODO(youkaichao): Don't update new_token_ids
-                        placeholders = apply_placeholders(
-                            new_token_ids,
-                            match_ids,
-                            get_replacement_ids,
-                        )
-
-                        if placeholders is not None:
-                            modality_placeholders.append(placeholders)
-
-            # yapf: disable
-            mm_placeholders[modality] = modality_placeholders  # type: ignore[index]
-            # yapf: enable
-
-        return MultiModalInputsV2(
-            type="multimodal",
-            prompt=prompt,
-            prompt_token_ids=new_token_ids,
-            mm_kwargs=mm_kwargs,
-            mm_placeholders=mm_placeholders,
-        )
+    def _find_placeholders(
+        self,
+        all_prompt_repls: Sequence[_BoundPromptReplacement[Any]],
+        new_token_ids: list[int],
+        *,
+        # To avoid false positives from multi-input when detecting
+        # whether placeholder tokens have been inserted, in case
+        # the target sequence is a subset of the replacement tokens
+        min_placeholder_count: int = 16,
+    ) -> list[_PlaceholderInfo]:
+        return list(
+            iter_placeholders(
+                all_prompt_repls,
+                new_token_ids,
+                min_placeholder_count=min_placeholder_count,
+            ))
+
+    def _apply_hf_processor(
+        self,
+        prompt: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        hf_processor = self.ctx.get_hf_processor()
+
+        return hf_processor(
+            text=prompt,  # type: ignore
+            **mm_data,
+            **mm_processor_kwargs,
+        )
+
+    def _bind_prompt_replacements(
+        self,
+        mm_data: MultiModalDataDict,
+    ) -> list[_BoundPromptReplacement[Any]]:
+        tokenizer = self.ctx.tokenizer
+
+        return [
+            prompt_repl.bind(modality, tokenizer)
+            for modality, metadata in self.metadata.items()
+            if modality in mm_data for prompt_repl in metadata.prompt_repls
+        ]
+
+    def _apply_prompt_replacements(
+        self,
+        mm_data: MultiModalDataDict,
+        hf_inputs: BatchFeature,
+        token_ids: list[int],
+        prompt_repls: Sequence[_BoundPromptReplacement[Any]],
+    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
+        tokenizer = self.ctx.tokenizer
+
+        mm_items = to_multi_format(mm_data)
+        token_matches = find_token_matches(token_ids, prompt_repls)
+
+        # If the search text does not represent a special token,
+        # it may have different token IDs in the prompt, because
+        # the tokens may go across the boundaries of the search text.
+        # ----
+        # e.g. when searching for "foo" in "food", if "food" itself makes
+        # up a token, then the token ID of "foo" will not appear at all
+        # ----
+        # Since it is inefficient to search for all possible tokenizations
+        # of the search text in the prompt, we instead perform string
+        # replacement on the decoded token IDs, then encode them back.
+        if all(
+            len(matches) >= len(mm_data[modality])
+            for modality, matches in full_groupby_modality(token_matches)
+        ):  # yapf: disable
+            token_ids = replace_token_matches(
+                token_ids,
+                token_matches,
+                mm_items,
+                hf_inputs,
+            )
+
+            text = _decode(tokenizer, token_ids)
+            matched_repls = [match.prompt_repl for match in token_matches]
+        else:
+            text = _decode(tokenizer, token_ids)
+
+            text_matches = find_text_matches(text, prompt_repls)
+            text = replace_text_matches(
+                text,
+                text_matches,
+                mm_items,
+                hf_inputs,
+            )
+
+            token_ids = _encode(tokenizer, text)
+            matched_repls = [match.prompt_repl for match in text_matches]
+
+        placeholders = self._find_placeholders(matched_repls, token_ids)
+
+        # Sanity check
+        assert len(placeholders) == len(matched_repls), dict(
+            # Log this information for easier debugging
+            text=text,
+            token_ids=token_ids,
+            placeholders=placeholders,
+            matched_repls=matched_repls,
+        )
+
+        return token_ids, text, placeholders
+
+    def apply(
+        self,
+        prompt_text: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        """
+        Process multi-modal inputs to be used in vLLM.
+
+        The main steps are:
+
+        1. Apply HF Processor on prompt text and multi-modal data together,
+           outputting token IDs and processed tensors.
+        2. Find and replace sequences in the token IDs with placeholder tokens.
+           The number of placeholder tokens equals the feature size of the
+           multi-modal data outputted by the multi-modal encoder.
+        3. Extract information about the placeholder tokens from the
+           processed token IDs.
+        """
+        tokenizer = self.ctx.tokenizer
+
+        hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
+                                             mm_processor_kwargs)
+        prompt_ids, = hf_inputs.pop("input_ids").tolist()
+        mm_kwargs = MultiModalKwargs(hf_inputs)
+
+        all_prompt_repls = self._bind_prompt_replacements(mm_data)
+
+        # If HF processor already inserts placeholder tokens,
+        # there is no need for us to insert them
+        all_placeholders = self._find_placeholders(all_prompt_repls,
+                                                   prompt_ids)
+        if all_placeholders:
+            prompt_text = _decode(tokenizer, prompt_ids)
+        else:
+            (
+                prompt_ids,
+                prompt_text,
+                all_placeholders,
+            ) = self._apply_prompt_replacements(
+                mm_data,
+                hf_inputs,
+                prompt_ids,
+                all_prompt_repls,
+            )
+
+        mm_placeholders = {
+            modality: [item.to_range() for item in items]
+            for modality, items in full_groupby_modality(all_placeholders)
+        }
+
+        return MultiModalInputsV2(
+            type="multimodal",
+            prompt=prompt_text,
+            prompt_token_ids=prompt_ids,
+            mm_kwargs=mm_kwargs,
+            mm_placeholders=mm_placeholders,
+        )
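To make the control flow of the new apply() concrete, a hypothetical call would look like the sketch below. The processor construction, its registered "image" PromptReplacement metadata, the prompt text, and the image object are all assumed rather than taken from this commit:

# `processor` is a MultiModalProcessor whose metadata defines an "image"
# PromptReplacement; `image` is a PIL image (both assumed here).
inputs = processor.apply(
    "USER: <image>\nWhat is in this picture?",
    mm_data={"image": image},
    mm_processor_kwargs={},
)

inputs["prompt_token_ids"]  # token IDs with "<image>" expanded into placeholder tokens
inputs["mm_kwargs"]         # processed tensors from the HF processor
inputs["mm_placeholders"]   # {"image": [PlaceholderRange(offset=..., length=...)]}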
@@ -19,7 +19,8 @@ import uuid
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
-from collections.abc import Mapping
+from collections import defaultdict
+from collections.abc import Iterable, Mapping
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@@ -905,6 +906,23 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]


+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
+
+
+def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
+    """
+    Unlike :class:`itertools.groupby`, groups are not broken by
+    non-contiguous data.
+    """
+    groups = defaultdict[_K, list[_V]](list)
+
+    for value in values:
+        groups[key(value)].append(value)
+
+    return groups.items()
+
+
 # TODO: This function can be removed if transformer_modules classes are
 # serialized by value when communicating between processes
 def init_cached_hf_modules() -> None:
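A quick illustration of how this differs from itertools.groupby (the values are made up):

from vllm.utils import full_groupby

items = [("image", 0), ("audio", 1), ("image", 2)]

dict(full_groupby(items, key=lambda x: x[0]))
# -> {"image": [("image", 0), ("image", 2)], "audio": [("audio", 1)]}
# itertools.groupby would split the non-contiguous "image" entries
# into two separate groups.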