[2/N] handling placeholders in merged multi-modal processor (#10485)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

This commit is contained in:
parent 4634a89d18
commit c8acd80548

tests/multimodal/test_processing.py (new file, 370 lines)

@@ -0,0 +1,370 @@
from typing import cast

import pytest
from transformers import BatchFeature

from vllm.multimodal.processing import (PromptReplacement, find_text_matches,
                                        find_token_matches, iter_token_matches,
                                        iter_token_runs, replace_text_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby


# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "expected"),
    [
        ([], []),
        (
            [32000, 32000, 32000],
            [{ "token_id": 32000, "start_idx": 0, "length": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [
                { "token_id": 9833, "start_idx": 0, "length": 1 },
                { "token_id": 28747, "start_idx": 1, "length": 1 },
                { "token_id": 32000, "start_idx": 2, "length": 3 },
                { "token_id": 9833, "start_idx": 5, "length": 1 },
                { "token_id": 28747, "start_idx": 6, "length": 1 },
                { "token_id": 32000, "start_idx": 7, "length": 2 },
                { "token_id": 918, "start_idx": 9, "length": 1 },
            ],
        ),
    ],
)
# yapf: enable
def test_iter_token_runs(token_ids, expected):
    result = list(iter_token_runs(token_ids))

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants
    assert sum(run_info.length for run_info in result) == len(token_ids)


# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], [{ "start_idx": 0, "end_idx": 0 }]),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                { "start_idx": 0, "end_idx": 1 },
                { "start_idx": 1, "end_idx": 2 },
                { "start_idx": 2, "end_idx": 3 },
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{ "start_idx": 0, "end_idx": 2 }],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{ "start_idx": 0, "end_idx": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                { "start_idx": 1, "end_idx": 3 },
                { "start_idx": 6, "end_idx": 8 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                { "start_idx": 1, "end_idx": 5 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
    result = list(iter_token_matches(token_ids, match_ids))

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants
    match_lens = [end - start for start, end in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 1 },
                    { "start_idx": 1, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 3 },
                    { "start_idx": 3, "end_idx": 4 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 4 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 3 },
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 1, "end_idx": 3 },
                    { "start_idx": 6, "end_idx": 8 },
                ],
                "pattern_2": [
                    { "start_idx": 1, "end_idx": 5 },
                ],
                "pattern_3": [],
            },
        ),
    ],
)
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
    # Should not be used since there is nothing to convert to token IDs
    mock_tokenizer = cast(AnyTokenizer, object())

    result = find_token_matches(
        prompt,
        [
            PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
            for key, target in target_by_key.items()
        ],
    )

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 7 },
                    { "start_idx": 7, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 21 },
                    { "start_idx": 21, "end_idx": 28 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 28 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 21 },
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 13 },
                    { "start_idx": 27, "end_idx": 40 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 27 },
                ],
                "pattern_3": [],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 9 },
                    { "start_idx": 16, "end_idx": 25 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 16 },
                    { "start_idx": 16, "end_idx": 32 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 25 },
                ],
            },
        ),
    ],
)
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    result = find_text_matches(
        prompt,
        [
            PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
            for key, target in target_by_key.items()
        ],
    )

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"),
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with repl_unit
                "pattern_1": ("<image><image>", 1),
                # Test empty repl_unit
                "pattern_2": ("", 1),
                # Test multiple repl_count
                "pattern_3": ("?", 2),
            },
            {
                # Test no replacement
                0: "Image:<image>Image:<image><image>!",
                # Test single replacement
                1: "<image><image>Image:<image><image>??",
                # Test repeated replacement
                2: "<image><image><image><image><image>??",
            },
        ),
    ]
)
# yapf: enable
def test_find_replace_text(
    prompt,
    target_by_key,
    repl_by_key,
    expected_by_mm_count,
):
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    matches = find_text_matches(
        prompt,
        [
            PromptReplacement(target, *repl_by_key[key]) \
                .bind(key, mock_tokenizer)
            for key, target in target_by_key.items()
        ],
    )
    result_by_mm_count = {
        mm_count: replace_text_matches(
            prompt,
            matches,
            {key: list(range(mm_count))
             for key in repl_by_key},
            BatchFeature(),
        )
        for mm_count in expected_by_mm_count
    }

    # Only displayed on error
    print("matches:", matches)
    print("result_by_mm_count:", result_by_mm_count)

    # Manually constructed results
    assert result_by_mm_count == expected_by_mm_count

@@ -139,7 +139,8 @@ def test_repeat_and_pad_placeholder_tokens(model):
            2,
            "<image><image><image>",
            [32000, 32000, 32000],
            [{ "offset": 0, "length": 2 }]),
            [{ "offset": 0, "length": 2 }],
        ),
        (
            "<image><image>",
            [3, 2],

@@ -203,14 +203,7 @@ class MultiModalInputsV2(TypedDict):
    """The type of inputs."""

    prompt: str
    """
    The original, unprocessed prompt text.

    Note:
        Since prompt text is not required by vLLM internals, we leave this
        unprocessed to save CPU computation. You can still call
        :code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
    """
    """The processed prompt text."""

    prompt_token_ids: List[int]
    """The processed token IDs which includes placeholder tokens."""

@@ -1,34 +1,91 @@
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import (Any, Callable, Collection, Generic, List, Mapping,
                    Optional, TypedDict, TypeVar, final)
from functools import lru_cache
from itertools import groupby
from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union

import numpy as np
from transformers import BatchFeature
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, TypedDict

from vllm.inputs import InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of

from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

_T = TypeVar("_T")

ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]]
"""
Given the original data item, HF-processed data, and index of the processed
item, output the replacement token IDs to be allocated in vLLM.
"""
def bind_prompt_sequence(
    seq: Union[str, list[int]],
    tokenizer: AnyTokenizer,
) -> "_BoundPromptSequence":
    """
    Bind a text or token sequence to a tokenizer so that it can be
    lazily converted into the other format on demand.
    """
    return _BoundPromptSequence(
        tokenizer=tokenizer,
        _text=seq if isinstance(seq, str) else None,
        _token_ids=seq if isinstance(seq, list) else None,
    )


_T = TypeVar("_T")
_S = TypeVar("_S", str, list[int])


@dataclass
class PromptReplacement(Generic[_S, _T]):
    target: _S
    """The text or token sequence to find and replace."""

    repl_unit: _S
    """
    The unit making up the replacement text or token sequence.

    See :code:`repl_count` for more details.
    """

    repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
    """
    Given the original multi-modal items for this modality, HF-processed data,
    and index of the processed item, output the number of repetitions of
    :code:`repl_unit` to build up the replacement text or token sequence.

    For convenience, you can pass in an integer if the number of repetitions is
    a constant.
    """

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(target={self.target!r}, "
                f"repl_unit={self.repl_unit!r})")

    def bind(
        self,
        modality: str,
        tokenizer: AnyTokenizer,
    ) -> "_BoundPromptReplacement[_T]":
        return _BoundPromptReplacement(
            modality=modality,
            target=bind_prompt_sequence(self.target, tokenizer),
            repl_unit=bind_prompt_sequence(self.repl_unit, tokenizer),
            repl_count=self.repl_count,
        )


@dataclass
class ModalityProcessingMetadata(Generic[_T]):
    placeholder_replacements: Mapping[str, ReplacementFunc]
    prompt_repls: Sequence[Union[PromptReplacement[str, _T],
                                 PromptReplacement[list[int], _T]]]
    """
    A dictionary where each item represents the original placeholder in the
    prompt text and the corresponding replacement.
    Defines each text or token sequence to replace in the HF-processed prompt.

    This is skipped if the HF-processed prompt is found to already contain
    the replacement prompts.
    """

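As a rough usage sketch (not part of this commit; the "<image>" strings, the constant 576, and the "image_feature_sizes" key are illustrative assumptions), a model's metadata could declare its replacement either with a constant repetition count or with a callable that derives the count from the HF-processed outputs:

# Illustrative only: a constant number of placeholder tokens per image.
image_repl = PromptReplacement(
    target="<image>",
    repl_unit="<image>",
    repl_count=576,
)

# Illustrative only: derive the count per item from the HF-processed data;
# "image_feature_sizes" is a hypothetical key in the BatchFeature.
image_repl_dynamic = PromptReplacement(
    target="<image>",
    repl_unit="<image>",
    repl_count=lambda items, hf_inputs, item_idx: int(
        hf_inputs["image_feature_sizes"][item_idx]),
)

Either form is later bound to the model's tokenizer via .bind(modality, tokenizer) before matching.
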
@@ -52,46 +109,138 @@ Note:
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

MultiModalMultiData: TypeAlias = List[_T]
"""
A list of data items, where the number of data items allowed
per modality is restricted by :code:`--limit-mm-per-prompt`.
"""

def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.
    """
    if isinstance(tokenizer, MistralTokenizer):
        return tokenizer.tokenizer.encode(text,
                                          bos=add_special_tokens,
                                          eos=add_special_tokens)

    return tokenizer.encode(text, add_special_tokens=add_special_tokens)


@final
class MultiModalMultiDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: MultiModalMultiData[ImageItem]
    """The input images."""

    video: MultiModalMultiData[VideoItem]
    """The input videos."""

    audio: MultiModalMultiData[AudioItem]
    """The input audios."""
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    return _encode(tokenizer, text, add_special_tokens=add_special_tokens)


MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]]
"""
A dictionary containing an entry for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalMultiDataBuiltins` as long as a customized plugin
    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
    """
    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)


def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    return _decode(tokenizer,
                   list(token_ids),
                   skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    modality: str


class _HasModalityProp(Protocol):

    @property
    def modality(self) -> str:
        ...


_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Convenience function to apply :func:`full_groupby` based on modality."""
    return full_groupby(values, key=lambda x: x.modality)


@dataclass
class _BoundPromptSequence:
    tokenizer: AnyTokenizer
    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        if self._text is None:
            assert self._token_ids is not None
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer, self._text)

        return self._token_ids

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(_text={self._text!r}, "
                f"_token_ids={self._token_ids!r})")


@dataclass
class _BoundPromptReplacement(Generic[_T]):
    modality: str
    target: _BoundPromptSequence
    repl_unit: _BoundPromptSequence
    repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]

    def get_count(
        self,
        mm_items: list[_T],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> int:
        repl_count = self.repl_count
        if isinstance(repl_count, int):
            return repl_count

        return repl_count(mm_items, hf_inputs, item_idx)


def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
    """
    Convert a :class:`MultiModalDataDict` containing single data items
    to a :class:`MultiModalMultiDataDict` containing multiple data items
    per entry.
    """
    multi_data: Mapping[str, MultiModalMultiData[Any]] = {}
    multi_data = dict[str, list[Any]]()

    for k, v in data.items():
        # yapf: disable

@@ -107,86 +256,279 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
    return multi_data


def encode_no_special_tokens(
    tokenizer: AnyTokenizer,
    text: str,
) -> List[int]:
class _TokenRun(NamedTuple):
    token_id: int

    start_idx: int
    length: int


def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=False)`.
    Yield the starting index and length of each run of tokens that are the same.
    """
    if isinstance(tokenizer, MistralTokenizer):
        return tokenizer.tokenizer.encode(text, bos=False, eos=False)
    start_idx = 0

    return tokenizer.encode(text, add_special_tokens=False)
    for token_id, it in groupby(token_ids):
        length = sum(1 for _ in it)
        yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length)

        start_idx += length


@lru_cache
def candidate_placeholders(
    tokenizer: AnyTokenizer,
    placeholder_text: str,
) -> Collection[List[int]]:
    """Generate token ID sequences that may represent a placeholder text."""
    # When the placeholder text is not mapped to a special token ID,
    # it may be tokenized differently based on whether it is at the start/end
    # of the string. So, we go through each combination of whether the text
    # is at the start and end boundaries of the string
class _PlaceholderInfo(NamedTuple):
    modality: str
    offset: int
    length: int

    # Matches the placeholder when it is in the middle of the string
    start_id, = encode_no_special_tokens(tokenizer, "a")
    end_id, = encode_no_special_tokens(tokenizer, "b")
    def to_range(self) -> PlaceholderRange:
        return PlaceholderRange(offset=self.offset, length=self.length)

    candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)

    start_id_, *candidate_a = encode_no_special_tokens(
        tokenizer,
        f"a{placeholder_text}",
    )
    assert start_id == start_id_

    start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
        tokenizer,
        f"a{placeholder_text}b",
    )
    assert start_id == start_id_ and end_id == end_id_

    *candidate_b, end_id_ = encode_no_special_tokens(
        tokenizer,
        f"{placeholder_text}b",
    )
    assert end_id == end_id_

    # Remove duplicates (need to convert to tuple to be hashable)
    unique_candidates = {
        tuple(c)
        for c in [candidate_basic, candidate_a, candidate_ab, candidate_b]
def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    token_ids: list[int],
    *,
    min_placeholder_count: int,
) -> Iterable[_PlaceholderInfo]:
    """Yield each set of placeholder tokens found in :code:`token_ids`."""
    placeholder_ids_by_modality = {
        modality: {
            token_id
            for prompt_repl in repls
            for token_id in prompt_repl.repl_unit.token_ids
        }
        for modality, repls in full_groupby_modality(prompt_repls)
    }

    # Convert back to list
    return [list(c) for c in unique_candidates]
    for run_info in iter_token_runs(token_ids):
        if run_info.length > min_placeholder_count:
            for (modality,
                 placeholder_ids) in placeholder_ids_by_modality.items():
                if run_info.token_id in placeholder_ids:
                    yield _PlaceholderInfo(
                        modality=modality,
                        offset=run_info.start_idx,
                        length=run_info.length,
                    )


def apply_placeholders(
    token_ids: List[int],
    placeholder_ids: List[int],
    get_replacement_ids: Callable[[], List[int]],
) -> Optional[PlaceholderRange]:
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """Yield each occurrence of :code:`match_ids` in :code:`token_ids`."""
    match_len = len(match_ids)

    last_end_idx = 0
    for start_idx in range(len(token_ids) - match_len + 1):
        if start_idx < last_end_idx:
            continue  # Exclude overlapping matches

        end_idx = start_idx + match_len
        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
            last_end_idx = end_idx


class _PromptReplacementMatch(ABC, Generic[_T, _S]):
    prompt_repl: _BoundPromptReplacement[_T]

    @property
    def modality(self) -> str:
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_repl(
        self,
        mm_items: list[_T],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> _S:
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
    prompt_repl: _BoundPromptReplacement[_T]
    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx

    def get_repl(
        self,
        mm_items: list[_T],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> list[int]:
        prompt_repl = self.prompt_repl
        count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
        return prompt_repl.repl_unit.token_ids * count


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
    prompt_repl: _BoundPromptReplacement[_T]
    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

    def get_repl(
        self,
        mm_items: list[_T],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> str:
        prompt_repl = self.prompt_repl
        count = prompt_repl.get_count(mm_items, hf_inputs, item_idx)
        return prompt_repl.repl_unit.text * count


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTokenMatch[_T]]:
    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
    return [
        _PromptReplacementTokenMatch(prompt_repl, match)
        for prompt_repl in prompt_repls
        for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
    ]


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTextMatch[_T]]:
    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
    return [
        _PromptReplacementTextMatch(prompt_repl, match)
        for prompt_repl in prompt_repls
        for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
    ]


def _resolve_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch[_T, _S]],
) -> list[_PromptReplacementMatch[_T, _S]]:
    """
    Find the first occurrence of :code:`placeholder_ids`,
    and replace it with the output of :code:`get_replacement_ids`.

    This function updates :code:`token_ids` in place.
    Resolve :code:`matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.
    """
    placeholder_length = len(placeholder_ids)
    num_matches_by_idx = np.zeros(len(prompt), dtype=int)
    for match in matches:
        num_matches_by_idx[match.start_idx:match.end_idx] += 1

    for start_idx in range(len(token_ids) - placeholder_length + 1):
        if token_ids[start_idx:placeholder_length] == placeholder_ids:
            token_ids[start_idx:placeholder_length] = get_replacement_ids()
    duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1)
    if len(duplicate_matches_idxs) > 0:
        raise ValueError("Unable to find a unique replacement "
                         f"at indices={duplicate_matches_idxs} "
                         f"of prompt={prompt}")

            return PlaceholderRange(offset=start_idx,
                                    length=placeholder_length)
    return sorted(matches, key=lambda x: x.start_idx)

    return None

def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch[_T, _S]],
    mm_items_by_modality: Mapping[str, list[_T]],
    hf_inputs: BatchFeature,
) -> list[_S]:
    out_seqs = list[_S]()
    prev_end_idx = 0
    next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality}

    for match in _resolve_matches(prompt, matches):
        modality = match.modality
        mm_items = mm_items_by_modality[modality]

        item_idx = next_idx_by_modality[modality]
        if item_idx >= len(mm_items):
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx
        repl_ids = match.get_repl(mm_items, hf_inputs, item_idx)

        out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids)
        prev_end_idx = end_idx
        next_idx_by_modality[modality] += 1

    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementMatch[_T, list[int]]],
    mm_items_by_modality: Mapping[str, list[_T]],
    hf_inputs: BatchFeature,
) -> list[int]:
    """Apply :code:`prompt_repls` to :code:`prompt`."""
    if not matches:
        return prompt

    token_id_seqs = _replace_matches(
        prompt,
        matches,
        mm_items_by_modality,
        hf_inputs,
    )

    return flatten_2d_lists(token_id_seqs)


def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementMatch[_T, str]],
    mm_items_by_modality: Mapping[str, list[_T]],
    hf_inputs: BatchFeature,
) -> str:
    """Apply :code:`prompt_repls` to :code:`prompt`."""
    if not matches:
        return prompt

    texts = _replace_matches(
        prompt,
        matches,
        mm_items_by_modality,
        hf_inputs,
    )

    return "".join(texts)


class MultiModalProcessor:

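As a small illustration of the placeholder detection added above (not part of this commit), the dummy-tokenizer trick from the unit tests also works here, since targets and replacement units given directly as token IDs never need to be tokenized:

from typing import cast

mock_tokenizer = cast(AnyTokenizer, object())
bound = PromptReplacement([32000], [32000], 1).bind("image", mock_tokenizer)

# A run of 20 identical placeholder tokens, longer than min_placeholder_count,
# is reported as a single placeholder region for the "image" modality.
placeholders = list(
    iter_placeholders([bound], [32000] * 20, min_placeholder_count=16))
# -> [_PlaceholderInfo(modality='image', offset=0, length=20)]
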
@@ -212,62 +554,166 @@ class MultiModalProcessor:
    ) -> MultiModalInputsV2:
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def apply(
    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement[Any]],
        new_token_ids: list[int],
        *,
        # To avoid false positives from multi-input when detecting
        # whether placeholder tokens have been inserted, in case
        # the target sequence is a subset of the replacement tokens
        min_placeholder_count: int = 16,
    ) -> list[_PlaceholderInfo]:
        return list(
            iter_placeholders(
                all_prompt_repls,
                new_token_ids,
                min_placeholder_count=min_placeholder_count,
            ))

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        tokenizer = self.ctx.tokenizer
    ) -> BatchFeature:
        hf_processor = self.ctx.get_hf_processor()

        processed_inputs = hf_processor(
        return hf_processor(
            text=prompt,  # type: ignore
            **mm_data,
            **mm_processor_kwargs,
        )
        new_token_ids, = processed_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(processed_inputs)

        mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}
    def _bind_prompt_replacements(
        self,
        mm_data: MultiModalDataDict,
    ) -> list[_BoundPromptReplacement[Any]]:
        tokenizer = self.ctx.tokenizer

        for modality, orig_inputs in to_multi_format(mm_data).items():
            assert isinstance(orig_inputs, list)
        return [
            prompt_repl.bind(modality, tokenizer)
            for modality, metadata in self.metadata.items()
            if modality in mm_data for prompt_repl in metadata.prompt_repls
        ]

            metadata = self.metadata[modality]
            placeholder_replacements = metadata.placeholder_replacements
    def _apply_prompt_replacements(
        self,
        mm_data: MultiModalDataDict,
        hf_inputs: BatchFeature,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        tokenizer = self.ctx.tokenizer

            modality_placeholders: List[PlaceholderRange] = []
        mm_items = to_multi_format(mm_data)
        token_matches = find_token_matches(token_ids, prompt_repls)

            for item_idx, orig_item in enumerate(orig_inputs):
                for match_text, replace_fn in placeholder_replacements.items():
                    candidates = candidate_placeholders(tokenizer, match_text)
                    get_replacement_ids = partial(
                        replace_fn,
                        orig_item,
                        processed_inputs,
                        item_idx,
                    )
        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        if all(
            len(matches) >= len(mm_data[modality])
            for modality, matches in full_groupby_modality(token_matches)
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
                mm_items,
                hf_inputs,
            )

                    for match_ids in candidates:
                        # TODO(youkaichao): Don't update new_token_ids
                        placeholders = apply_placeholders(
                            new_token_ids,
                            match_ids,
                            get_replacement_ids,
                        )
            text = _decode(tokenizer, token_ids)
            matched_repls = [match.prompt_repl for match in token_matches]
        else:
            text = _decode(tokenizer, token_ids)

                        if placeholders is not None:
                            modality_placeholders.append(placeholders)
            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(
                text,
                text_matches,
                mm_items,
                hf_inputs,
            )

            # yapf: disable
            mm_placeholders[modality] = modality_placeholders  # type: ignore[index]
            # yapf: enable
            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids)

        # Sanity check
        assert len(placeholders) == len(matched_repls), dict(
            # Log this information for easier debugging
            text=text,
            token_ids=token_ids,
            placeholders=placeholders,
            matched_repls=matched_repls,
        )

        return token_ids, text, placeholders

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        tokenizer = self.ctx.tokenizer

        hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
                                             mm_processor_kwargs)
        prompt_ids, = hf_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(hf_inputs)

        all_prompt_repls = self._bind_prompt_replacements(mm_data)

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        all_placeholders = self._find_placeholders(all_prompt_repls,
                                                   prompt_ids)
        if all_placeholders:
            prompt_text = _decode(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt_text,
                all_placeholders,
            ) = self._apply_prompt_replacements(
                mm_data,
                hf_inputs,
                prompt_ids,
                all_prompt_repls,
            )

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=new_token_ids,
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )

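As a hedged illustration of what apply() returns (the token IDs and the feature size of 3 below are made up; real values depend on the model's tokenizer and HF processor), the resulting MultiModalInputsV2 for a single-image prompt has roughly this shape:

# Illustrative only; mm_kwargs would hold the processed tensors
# (e.g. pixel_values) produced by the HF processor.
example_output = dict(
    type="multimodal",
    prompt="<image><image><image>What is this?",
    prompt_token_ids=[32000, 32000, 32000, 1724, 338, 445, 29973],
    mm_kwargs={},
    mm_placeholders={"image": [dict(offset=0, length=3)]},
)
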
@@ -19,7 +19,8 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import Mapping
from collections import defaultdict
from collections.abc import Iterable, Mapping
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,

@@ -905,6 +906,23 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    return [item for sublist in lists for item in sublist]


_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")


def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
    """
    Unlike :class:`itertools.groupby`, groups are not broken by
    non-contiguous data.
    """
    groups = defaultdict[_K, list[_V]](list)

    for value in values:
        groups[key(value)].append(value)

    return groups.items()


# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
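
The new full_groupby helper differs from itertools.groupby in that the input does not need to be pre-sorted by key; a quick sketch (not part of this commit) of the difference:

from itertools import groupby

values = ["a1", "b1", "a2"]

# itertools.groupby splits the non-contiguous "a" items into two groups
assert [k for k, _ in groupby(values, key=lambda x: x[0])] == ["a", "b", "a"]

# full_groupby collects them into a single group per key
assert dict(full_groupby(values, key=lambda x: x[0])) == {
    "a": ["a1", "a2"],
    "b": ["b1"],
}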