mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:54:56 +08:00
[VLM][Bugfix] Enable specifying prompt target via index (#14038)
This commit is contained in:
parent
e0734387fb
commit
f7bee5c815
@ -14,8 +14,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
|
||||
PromptInsertion, PromptReplacement,
|
||||
apply_text_matches,
|
||||
PromptIndexTargets, PromptInsertion,
|
||||
PromptReplacement, apply_text_matches,
|
||||
apply_token_matches,
|
||||
find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
@ -98,10 +98,20 @@ def test_iter_token_matches(token_ids, match_ids, expected):
|
||||
{
|
||||
"pattern_1": [],
|
||||
"pattern_2": [32000],
|
||||
"pattern_3": PromptIndexTargets.start(),
|
||||
"pattern_4": PromptIndexTargets.prefix([32000]),
|
||||
"pattern_5": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": [],
|
||||
"pattern_2": [],
|
||||
"pattern_3": [
|
||||
{ "start_idx": 0, "end_idx": 0 },
|
||||
],
|
||||
"pattern_4": [],
|
||||
"pattern_5": [
|
||||
{ "start_idx": 0, "end_idx": 0 },
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
@ -110,6 +120,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
|
||||
"pattern_1": [32000],
|
||||
"pattern_2": [32000, 32000],
|
||||
"pattern_3": [32000, 32000, 32000],
|
||||
"pattern_4": PromptIndexTargets.start(),
|
||||
"pattern_5": PromptIndexTargets.prefix([32000]),
|
||||
"pattern_6": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": [
|
||||
@ -125,6 +138,15 @@ def test_iter_token_matches(token_ids, match_ids, expected):
|
||||
"pattern_3": [
|
||||
{ "start_idx": 0, "end_idx": 3 },
|
||||
],
|
||||
"pattern_4": [
|
||||
{ "start_idx": 0, "end_idx": 0 },
|
||||
],
|
||||
"pattern_5": [
|
||||
{ "start_idx": 1, "end_idx": 1 },
|
||||
],
|
||||
"pattern_6": [
|
||||
{ "start_idx": 4, "end_idx": 4 },
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
@ -133,6 +155,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
|
||||
"pattern_1": [28747, 32000],
|
||||
"pattern_2": [28747, 32000, 32000, 32000],
|
||||
"pattern_3": [28747, 0, 32000],
|
||||
"pattern_4": PromptIndexTargets.start(),
|
||||
"pattern_5": PromptIndexTargets.prefix([28747, 32000]),
|
||||
"pattern_6": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": [
|
||||
@ -143,6 +168,13 @@ def test_iter_token_matches(token_ids, match_ids, expected):
|
||||
{ "start_idx": 1, "end_idx": 5 },
|
||||
],
|
||||
"pattern_3": [],
|
||||
"pattern_4": [
|
||||
{ "start_idx": 0, "end_idx": 0 },
|
||||
],
|
||||
"pattern_5": [],
|
||||
"pattern_6": [
|
||||
{ "start_idx": 10, "end_idx": 10 },
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
@ -189,10 +221,20 @@ def test_find_token_matches(
|
||||
{
|
||||
"pattern_1": "",
|
||||
"pattern_2": "<image>",
|
||||
"pattern_3": PromptIndexTargets.start(),
|
||||
"pattern_4": PromptIndexTargets.prefix("<image>"),
|
||||
"pattern_5": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
|
||||
"pattern_2": [],
|
||||
"pattern_3": [
|
||||
{ "start_idx": 0, "end_idx": 0 },
|
||||
],
|
||||
"pattern_4": [],
|
||||
"pattern_5": [
|
||||
{ "start_idx": 0, "end_idx": 0 },
|
||||
],
|
||||
}
|
||||
),
|
||||
(
|
||||
@ -201,6 +243,9 @@ def test_find_token_matches(
|
||||
"pattern_1": "<image>",
|
||||
"pattern_2": "<image><image>",
|
||||
"pattern_3": "<image><image><image>",
|
||||
"pattern_4": PromptIndexTargets.start(),
|
||||
"pattern_5": PromptIndexTargets.prefix("<image>"),
|
||||
"pattern_6": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": [
|
||||
@ -216,6 +261,15 @@ def test_find_token_matches(
|
||||
"pattern_3": [
|
||||
{ "start_idx": 0, "end_idx": 21 },
|
||||
],
|
||||
"pattern_4": [
|
||||
{ "start_idx": 0, "end_idx": 0 },
|
||||
],
|
||||
"pattern_5": [
|
||||
{ "start_idx": 7, "end_idx": 7 },
|
||||
],
|
||||
"pattern_6": [
|
||||
{ "start_idx": 28, "end_idx": 28 },
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
@ -224,6 +278,9 @@ def test_find_token_matches(
|
||||
"pattern_1": "Image:<image>",
|
||||
"pattern_2": "Image:<image><image><image>",
|
||||
"pattern_3": "Image:<unk><image>",
|
||||
"pattern_4": PromptIndexTargets.start(),
|
||||
"pattern_5": PromptIndexTargets.prefix("Image:<image>"),
|
||||
"pattern_6": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": [
|
||||
@ -234,6 +291,15 @@ def test_find_token_matches(
|
||||
{ "start_idx": 0, "end_idx": 27 },
|
||||
],
|
||||
"pattern_3": [],
|
||||
"pattern_4": [
|
||||
{ "start_idx": 0, "end_idx": 0 },
|
||||
],
|
||||
"pattern_5": [
|
||||
{ "start_idx": 13, "end_idx": 13 },
|
||||
],
|
||||
"pattern_6": [
|
||||
{ "start_idx": 48, "end_idx": 48 },
|
||||
],
|
||||
},
|
||||
),
|
||||
# Test regex escape
|
||||
@ -325,6 +391,100 @@ def test_find_text_matches(
|
||||
},
|
||||
},
|
||||
),
|
||||
# Test index targets
|
||||
(
|
||||
"",
|
||||
{
|
||||
"pattern_1": PromptIndexTargets.start(),
|
||||
"pattern_2": PromptIndexTargets.prefix("<image>"),
|
||||
"pattern_3": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": "1",
|
||||
"pattern_2": "2",
|
||||
"pattern_3": "3",
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: "",
|
||||
1: "13",
|
||||
2: "1133",
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: "",
|
||||
1: "13",
|
||||
2: "1133",
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
"<image>",
|
||||
{
|
||||
"pattern_1": PromptIndexTargets.start(),
|
||||
"pattern_2": PromptIndexTargets.prefix("<image>"),
|
||||
"pattern_3": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": "1",
|
||||
"pattern_2": "2",
|
||||
"pattern_3": "3",
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: "<image>",
|
||||
1: "1<image>23",
|
||||
2: "11<image>2233",
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: "<image>",
|
||||
1: "1<image>23",
|
||||
2: "11<image>2233",
|
||||
},
|
||||
},
|
||||
),
|
||||
# Test different replacement per item
|
||||
(
|
||||
"<image><image><image>",
|
||||
{
|
||||
"pattern_1": "<image>",
|
||||
},
|
||||
{
|
||||
"pattern_1": lambda idx: str(idx + 1),
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: "<image><image><image>",
|
||||
1: "<image>1<image><image>",
|
||||
2: "<image>12<image><image>",
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: "<image><image><image>",
|
||||
1: "1<image><image>",
|
||||
2: "12<image>",
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
"<image><image><image>",
|
||||
{
|
||||
"pattern_1": PromptIndexTargets.prefix("<image>"),
|
||||
},
|
||||
{
|
||||
"pattern_1": lambda idx: str(idx + 1),
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: "<image><image><image>",
|
||||
1: "<image>1<image><image>",
|
||||
2: "<image>12<image><image>",
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: "<image><image><image>",
|
||||
1: "<image>1<image><image>",
|
||||
2: "<image>12<image><image>",
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
@ -405,6 +565,100 @@ def test_find_update_text(
|
||||
},
|
||||
},
|
||||
),
|
||||
# Test index targets
|
||||
(
|
||||
[],
|
||||
{
|
||||
"pattern_1": PromptIndexTargets.start(),
|
||||
"pattern_2": PromptIndexTargets.prefix([32000]),
|
||||
"pattern_3": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": [-1],
|
||||
"pattern_2": [-2],
|
||||
"pattern_3": [-3],
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: [],
|
||||
1: [-1, -3],
|
||||
2: [-1, -1, -3, -3],
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: [],
|
||||
1: [-1, -3],
|
||||
2: [-1, -1, -3, -3],
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
[32000],
|
||||
{
|
||||
"pattern_1": PromptIndexTargets.start(),
|
||||
"pattern_2": PromptIndexTargets.prefix([32000]),
|
||||
"pattern_3": PromptIndexTargets.end(),
|
||||
},
|
||||
{
|
||||
"pattern_1": [-1],
|
||||
"pattern_2": [-2],
|
||||
"pattern_3": [-3],
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: [32000],
|
||||
1: [-1, 32000, -2, -3],
|
||||
2: [-1, -1, 32000, -2, -2, -3, -3],
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: [32000],
|
||||
1: [-1, 32000, -2, -3],
|
||||
2: [-1, -1, 32000, -2, -2, -3, -3],
|
||||
},
|
||||
},
|
||||
),
|
||||
# Test different replacement per item
|
||||
(
|
||||
[32000, 32000, 32000],
|
||||
{
|
||||
"pattern_1": [32000],
|
||||
},
|
||||
{
|
||||
"pattern_1": lambda idx: [-(idx + 1)],
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: [32000, 32000, 32000],
|
||||
1: [32000, -1, 32000, 32000],
|
||||
2: [32000, -1, -2, 32000, 32000],
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: [32000, 32000, 32000],
|
||||
1: [-1, 32000, 32000],
|
||||
2: [-1, -2, 32000],
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
[32000, 32000, 32000],
|
||||
{
|
||||
"pattern_1": PromptIndexTargets.prefix([32000]),
|
||||
},
|
||||
{
|
||||
"pattern_1": lambda idx: [-(idx + 1)],
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: [32000, 32000, 32000],
|
||||
1: [32000, -1, 32000, 32000],
|
||||
2: [32000, -1, -2, 32000, 32000],
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: [32000, 32000, 32000],
|
||||
1: [32000, -1, 32000, 32000],
|
||||
2: [32000, -1, -2, 32000, 32000],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
|
||||
@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptInsertion,
|
||||
PromptUpdate)
|
||||
BaseProcessingInfo, PromptIndexTargets,
|
||||
PromptInsertion, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@ -490,7 +490,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
|
||||
return [
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target="",
|
||||
target=PromptIndexTargets.start(),
|
||||
insertion=image_tokens,
|
||||
)
|
||||
]
|
||||
|
||||
@ -25,7 +25,8 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
|
||||
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseProcessingInfo,
|
||||
EncDecMultiModalProcessor,
|
||||
PromptInsertion, PromptUpdate)
|
||||
PromptIndexTargets, PromptInsertion,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@ -864,7 +865,7 @@ class Florence2MultiModalProcessor(
|
||||
return [
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target="",
|
||||
target=PromptIndexTargets.start(),
|
||||
insertion=image_tokens,
|
||||
)
|
||||
]
|
||||
|
||||
@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptInsertion,
|
||||
PromptUpdate)
|
||||
BaseProcessingInfo, PromptIndexTargets,
|
||||
PromptInsertion, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import JSONTree, json_map_leaves
|
||||
@ -1371,7 +1371,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
return [
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target="<|endoftext|>",
|
||||
target=PromptIndexTargets.prefix("<|endoftext|>"),
|
||||
insertion=get_insertion_molmo,
|
||||
)
|
||||
]
|
||||
|
||||
@ -8,7 +8,6 @@ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from itertools import groupby
|
||||
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
|
||||
TypeVar, Union, cast)
|
||||
|
||||
@ -40,6 +39,65 @@ PromptSeq = Union[str, list[int]]
|
||||
"""A token sequence (list of token IDs) or text."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptIndex:
|
||||
"""Resolves to an index in the prompt."""
|
||||
get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]
|
||||
|
||||
|
||||
class PromptIndexTargets:
|
||||
|
||||
@staticmethod
|
||||
def start() -> PromptIndex:
|
||||
"""
|
||||
Resolves to the start of the prompt (before the first token).
|
||||
|
||||
This results in a match even if the prompt is empty.
|
||||
"""
|
||||
return PromptIndex(lambda tok, prompt: 0)
|
||||
|
||||
@staticmethod
|
||||
def prefix(seq: PromptSeq) -> PromptIndex:
|
||||
"""
|
||||
Resolves to a location in the prompt after the given prefix.
|
||||
"""
|
||||
|
||||
def get_match_index(
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt: PromptSeq,
|
||||
) -> Optional[int]:
|
||||
prefix = seq
|
||||
|
||||
if isinstance(prompt, str):
|
||||
if not isinstance(prefix, str):
|
||||
# Make both `str`
|
||||
prefix = decode_tokens(tokenizer, prefix)
|
||||
else:
|
||||
if isinstance(prefix, str):
|
||||
# Make both `list[int]`
|
||||
prefix = encode_tokens(tokenizer, prefix)
|
||||
|
||||
match_idx = len(prefix)
|
||||
return match_idx if prompt[:match_idx] == prefix else None
|
||||
|
||||
return PromptIndex(get_match_index)
|
||||
|
||||
@staticmethod
|
||||
def end() -> PromptIndex:
|
||||
"""
|
||||
Resolves to the end of the prompt (after the last token).
|
||||
|
||||
This results in a match even if the prompt is empty.
|
||||
"""
|
||||
return PromptIndex(lambda tok, prompt: len(prompt))
|
||||
|
||||
|
||||
PromptTarget = Union[PromptSeq, PromptIndex]
|
||||
"""
|
||||
The token sequence or text to update.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptUpdateDetails:
|
||||
"""Details about the token sequence or text that are part of the update."""
|
||||
@ -84,7 +142,7 @@ class UpdateMode(str, Enum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptUpdate:
|
||||
class PromptUpdate(ABC):
|
||||
"""
|
||||
Defines how to update a prompt with placeholder tokens.
|
||||
"""
|
||||
@ -92,7 +150,7 @@ class PromptUpdate:
|
||||
modality: str
|
||||
"""The modality for which the update is made."""
|
||||
|
||||
target: PromptSeq
|
||||
target: PromptTarget
|
||||
"""The token sequence (or text) to update."""
|
||||
|
||||
@property
|
||||
@ -122,18 +180,7 @@ class PromptInsertion(PromptUpdate):
|
||||
Example:
|
||||
|
||||
For each image, insert a number of ``<image>`` feature placeholders
|
||||
equal to the feature size of the vision encoder at the start of the
|
||||
prompt:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target="",
|
||||
insertion="<image>" * image_feature_size,
|
||||
)
|
||||
|
||||
As above, but insert after the ``<s>`` token:
|
||||
equal to the feature size of the vision encoder after the ``<s>`` token:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -142,6 +189,36 @@ class PromptInsertion(PromptUpdate):
|
||||
target="<s>",
|
||||
insertion="<image>" * image_feature_size,
|
||||
)
|
||||
|
||||
Insert these tokens at the start of the prompt:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target=PromptIndexTargets.start(),
|
||||
insertion="<image>" * image_feature_size,
|
||||
)
|
||||
|
||||
Insert these tokens after a prefix ``Images:``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target=PromptIndexTargets.prefix("Images:"),
|
||||
insertion="<image>" * image_feature_size,
|
||||
)
|
||||
|
||||
Insert these tokens at the end of the prompt:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target=PromptIndexTargets.end(),
|
||||
insertion="<image>" * image_feature_size,
|
||||
)
|
||||
"""
|
||||
|
||||
insertion: PromptUpdateContent = field(repr=False)
|
||||
@ -345,10 +422,14 @@ class BoundPromptUpdate:
|
||||
return self._origin.modality
|
||||
|
||||
@property
|
||||
def target(self) -> _BoundPromptSequence:
|
||||
def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
|
||||
"""The token sequence (or text) to update."""
|
||||
return _BoundPromptSequence.from_seq(self.tokenizer,
|
||||
self._origin.target)
|
||||
target = self._origin.target
|
||||
|
||||
if isinstance(target, PromptIndex):
|
||||
return target
|
||||
|
||||
return _BoundPromptSequence.from_seq(self.tokenizer, target)
|
||||
|
||||
@property
|
||||
def content(self) -> PromptUpdateContent:
|
||||
@ -447,6 +528,19 @@ class _PromptTargetMatch(ABC):
|
||||
f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class _PromptTargetIndexMatch(_PromptTargetMatch):
|
||||
match_idx: int
|
||||
|
||||
@property
|
||||
def start_idx(self) -> int:
|
||||
return self.match_idx
|
||||
|
||||
@property
|
||||
def end_idx(self) -> int:
|
||||
return self.match_idx
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class _PromptTargetTokenMatch(_PromptTargetMatch):
|
||||
match: _TokenMatch
|
||||
@ -496,9 +590,24 @@ def find_token_matches(
|
||||
prompt_updates: Sequence[BoundPromptUpdate],
|
||||
) -> Sequence[_PromptTargetMatch]:
|
||||
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
|
||||
|
||||
def get_matches(update: BoundPromptUpdate):
|
||||
target = update.target
|
||||
|
||||
if isinstance(target, PromptIndex):
|
||||
match_idx = target.get_match_index(update.tokenizer, prompt)
|
||||
if match_idx is None:
|
||||
return []
|
||||
|
||||
return [_PromptTargetIndexMatch(update, match_idx)]
|
||||
|
||||
return [
|
||||
_PromptTargetTokenMatch(update, match)
|
||||
for match in iter_token_matches(prompt, target.token_ids)
|
||||
]
|
||||
|
||||
return [
|
||||
_PromptTargetTokenMatch(update, match) for update in prompt_updates
|
||||
for match in iter_token_matches(prompt, update.target.token_ids)
|
||||
match for update in prompt_updates for match in get_matches(update)
|
||||
]
|
||||
|
||||
|
||||
@ -507,9 +616,24 @@ def find_text_matches(
|
||||
prompt_updates: Sequence[BoundPromptUpdate],
|
||||
) -> Sequence[_PromptTargetMatch]:
|
||||
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
|
||||
|
||||
def get_matches(update: BoundPromptUpdate):
|
||||
target = update.target
|
||||
|
||||
if isinstance(target, PromptIndex):
|
||||
match_idx = target.get_match_index(update.tokenizer, prompt)
|
||||
if match_idx is None:
|
||||
return []
|
||||
|
||||
return [_PromptTargetIndexMatch(update, match_idx)]
|
||||
|
||||
return [
|
||||
_PromptTargetTextMatch(update, match)
|
||||
for match in re.finditer(re.escape(target.text), prompt)
|
||||
]
|
||||
|
||||
return [
|
||||
_PromptTargetTextMatch(update, match) for update in prompt_updates
|
||||
for match in re.finditer(re.escape(update.target.text), prompt)
|
||||
match for update in prompt_updates for match in get_matches(update)
|
||||
]
|
||||
|
||||
|
||||
@ -547,45 +671,39 @@ def _apply_matches(
|
||||
prev_end_idx = 0
|
||||
next_idx_by_modality = defaultdict[str, int](lambda: 0)
|
||||
|
||||
for (start_idx, end_idx), group in groupby(
|
||||
_resolve_matches(prompt, mm_matches),
|
||||
key=lambda x: (x.start_idx, x.end_idx),
|
||||
):
|
||||
matches = tuple(group)
|
||||
assert len(matches) == 1
|
||||
for match in _resolve_matches(prompt, mm_matches):
|
||||
modality = match.modality
|
||||
|
||||
for match in matches:
|
||||
modality = match.modality
|
||||
item_start_idx = next_idx_by_modality[modality]
|
||||
max_item_count = mm_item_counts.get(modality, 0)
|
||||
if item_start_idx >= max_item_count:
|
||||
continue
|
||||
|
||||
item_idx = next_idx_by_modality[modality]
|
||||
if item_idx >= mm_item_counts.get(modality, 0):
|
||||
continue
|
||||
start_idx = match.start_idx
|
||||
end_idx = match.end_idx
|
||||
origin = match._origin
|
||||
mode = origin.mode
|
||||
|
||||
origin = match._origin
|
||||
if mode == UpdateMode.INSERT:
|
||||
out_seqs.append(prompt[prev_end_idx:end_idx])
|
||||
num_inserts = max_item_count
|
||||
elif mode == UpdateMode.REPLACE:
|
||||
out_seqs.append(prompt[prev_end_idx:start_idx])
|
||||
num_inserts = max_item_count if start_idx == end_idx else 1
|
||||
else:
|
||||
assert_never(mode)
|
||||
|
||||
item_end_idx = min(item_start_idx + num_inserts, max_item_count)
|
||||
|
||||
for item_idx in range(item_start_idx, item_end_idx):
|
||||
content = origin.get_content(item_idx)
|
||||
mode = origin.mode
|
||||
insert_seq = (content.full.text if isinstance(prompt, str) else
|
||||
content.full.token_ids)
|
||||
|
||||
if mode == UpdateMode.INSERT:
|
||||
out_seqs.append(prompt[prev_end_idx:end_idx])
|
||||
num_inserts = mm_item_counts.get(modality, 0)
|
||||
elif mode == UpdateMode.REPLACE:
|
||||
out_seqs.append(prompt[prev_end_idx:start_idx])
|
||||
num_inserts = 1
|
||||
else:
|
||||
assert_never(mode)
|
||||
out_seqs.append(insert_seq)
|
||||
|
||||
for _ in range(num_inserts):
|
||||
if item_idx >= mm_item_counts.get(modality, 0):
|
||||
continue
|
||||
|
||||
if isinstance(prompt, str):
|
||||
out_seqs.append(content.full.text)
|
||||
else:
|
||||
out_seqs.append(content.full.token_ids)
|
||||
|
||||
next_idx_by_modality[modality] += 1
|
||||
|
||||
prev_end_idx = end_idx
|
||||
prev_end_idx = end_idx
|
||||
next_idx_by_modality[modality] += item_end_idx - item_start_idx
|
||||
|
||||
out_seqs.append(prompt[prev_end_idx:])
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user