[VLM][Bugfix] Enable specifying prompt target via index (#14038)

Cyrus Leung 2025-02-28 23:35:55 +08:00 committed by GitHub
parent e0734387fb
commit f7bee5c815
5 changed files with 438 additions and 65 deletions
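In short, this change lets multimodal processors anchor a prompt update to a position in the prompt (its start, the point just after a given prefix, or its end) via the new PromptIndexTargets helpers, instead of only to a literal token or text sequence. A minimal usage sketch (the surrounding processor code and the feature size are illustrative, not part of this diff):

    from vllm.multimodal.processing import PromptIndexTargets, PromptInsertion

    insertion = PromptInsertion(
        modality="image",
        # New in this commit: target a prompt index instead of a literal
        # sequence; start() matches even when the prompt is empty.
        target=PromptIndexTargets.start(),
        insertion="<image>" * 576,  # hypothetical per-image feature size
    )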

View File

@@ -14,8 +14,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
-                                        PromptInsertion, PromptReplacement,
-                                        apply_text_matches,
+                                        PromptIndexTargets, PromptInsertion,
+                                        PromptReplacement, apply_text_matches,
                                         apply_token_matches,
                                         find_mm_placeholders,
                                         find_text_matches, find_token_matches,
@@ -98,10 +98,20 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{
"pattern_1": [],
"pattern_2": [32000],
"pattern_3": PromptIndexTargets.start(),
"pattern_4": PromptIndexTargets.prefix([32000]),
"pattern_5": PromptIndexTargets.end(),
},
{
"pattern_1": [],
"pattern_2": [],
"pattern_3": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_4": [],
"pattern_5": [
{ "start_idx": 0, "end_idx": 0 },
],
},
),
(
@@ -110,6 +120,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_1": [32000],
"pattern_2": [32000, 32000],
"pattern_3": [32000, 32000, 32000],
"pattern_4": PromptIndexTargets.start(),
"pattern_5": PromptIndexTargets.prefix([32000]),
"pattern_6": PromptIndexTargets.end(),
},
{
"pattern_1": [
@@ -125,6 +138,15 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_3": [
{ "start_idx": 0, "end_idx": 3 },
],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_5": [
{ "start_idx": 1, "end_idx": 1 },
],
"pattern_6": [
{ "start_idx": 4, "end_idx": 4 },
],
},
),
(
@@ -133,6 +155,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_1": [28747, 32000],
"pattern_2": [28747, 32000, 32000, 32000],
"pattern_3": [28747, 0, 32000],
"pattern_4": PromptIndexTargets.start(),
"pattern_5": PromptIndexTargets.prefix([28747, 32000]),
"pattern_6": PromptIndexTargets.end(),
},
{
"pattern_1": [
@@ -143,6 +168,13 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{ "start_idx": 1, "end_idx": 5 },
],
"pattern_3": [],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_5": [],
"pattern_6": [
{ "start_idx": 10, "end_idx": 10 },
],
},
),
],
@@ -189,10 +221,20 @@ def test_find_token_matches(
{
"pattern_1": "",
"pattern_2": "<image>",
"pattern_3": PromptIndexTargets.start(),
"pattern_4": PromptIndexTargets.prefix("<image>"),
"pattern_5": PromptIndexTargets.end(),
},
{
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_2": [],
"pattern_3": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_4": [],
"pattern_5": [
{ "start_idx": 0, "end_idx": 0 },
],
}
),
(
@@ -201,6 +243,9 @@ def test_find_token_matches(
"pattern_1": "<image>",
"pattern_2": "<image><image>",
"pattern_3": "<image><image><image>",
"pattern_4": PromptIndexTargets.start(),
"pattern_5": PromptIndexTargets.prefix("<image>"),
"pattern_6": PromptIndexTargets.end(),
},
{
"pattern_1": [
@@ -216,6 +261,15 @@ def test_find_token_matches(
"pattern_3": [
{ "start_idx": 0, "end_idx": 21 },
],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_5": [
{ "start_idx": 7, "end_idx": 7 },
],
"pattern_6": [
{ "start_idx": 28, "end_idx": 28 },
],
},
),
(
@@ -224,6 +278,9 @@ def test_find_token_matches(
"pattern_1": "Image:<image>",
"pattern_2": "Image:<image><image><image>",
"pattern_3": "Image:<unk><image>",
"pattern_4": PromptIndexTargets.start(),
"pattern_5": PromptIndexTargets.prefix("Image:<image>"),
"pattern_6": PromptIndexTargets.end(),
},
{
"pattern_1": [
@@ -234,6 +291,15 @@ def test_find_token_matches(
{ "start_idx": 0, "end_idx": 27 },
],
"pattern_3": [],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_5": [
{ "start_idx": 13, "end_idx": 13 },
],
"pattern_6": [
{ "start_idx": 48, "end_idx": 48 },
],
},
),
# Test regex escape
@@ -325,6 +391,100 @@ def test_find_text_matches(
},
},
),
# Test index targets
(
"",
{
"pattern_1": PromptIndexTargets.start(),
"pattern_2": PromptIndexTargets.prefix("<image>"),
"pattern_3": PromptIndexTargets.end(),
},
{
"pattern_1": "1",
"pattern_2": "2",
"pattern_3": "3",
},
{
PromptInsertion: {
0: "",
1: "13",
2: "1133",
},
PromptReplacement: {
0: "",
1: "13",
2: "1133",
},
},
),
(
"<image>",
{
"pattern_1": PromptIndexTargets.start(),
"pattern_2": PromptIndexTargets.prefix("<image>"),
"pattern_3": PromptIndexTargets.end(),
},
{
"pattern_1": "1",
"pattern_2": "2",
"pattern_3": "3",
},
{
PromptInsertion: {
0: "<image>",
1: "1<image>23",
2: "11<image>2233",
},
PromptReplacement: {
0: "<image>",
1: "1<image>23",
2: "11<image>2233",
},
},
),
# Test different replacement per item
(
"<image><image><image>",
{
"pattern_1": "<image>",
},
{
"pattern_1": lambda idx: str(idx + 1),
},
{
PromptInsertion: {
0: "<image><image><image>",
1: "<image>1<image><image>",
2: "<image>12<image><image>",
},
PromptReplacement: {
0: "<image><image><image>",
1: "1<image><image>",
2: "12<image>",
},
},
),
(
"<image><image><image>",
{
"pattern_1": PromptIndexTargets.prefix("<image>"),
},
{
"pattern_1": lambda idx: str(idx + 1),
},
{
PromptInsertion: {
0: "<image><image><image>",
1: "<image>1<image><image>",
2: "<image>12<image><image>",
},
PromptReplacement: {
0: "<image><image><image>",
1: "<image>1<image><image>",
2: "<image>12<image><image>",
},
},
),
]
)
# yapf: enable
@@ -405,6 +565,100 @@ def test_find_update_text(
},
},
),
# Test index targets
(
[],
{
"pattern_1": PromptIndexTargets.start(),
"pattern_2": PromptIndexTargets.prefix([32000]),
"pattern_3": PromptIndexTargets.end(),
},
{
"pattern_1": [-1],
"pattern_2": [-2],
"pattern_3": [-3],
},
{
PromptInsertion: {
0: [],
1: [-1, -3],
2: [-1, -1, -3, -3],
},
PromptReplacement: {
0: [],
1: [-1, -3],
2: [-1, -1, -3, -3],
},
},
),
(
[32000],
{
"pattern_1": PromptIndexTargets.start(),
"pattern_2": PromptIndexTargets.prefix([32000]),
"pattern_3": PromptIndexTargets.end(),
},
{
"pattern_1": [-1],
"pattern_2": [-2],
"pattern_3": [-3],
},
{
PromptInsertion: {
0: [32000],
1: [-1, 32000, -2, -3],
2: [-1, -1, 32000, -2, -2, -3, -3],
},
PromptReplacement: {
0: [32000],
1: [-1, 32000, -2, -3],
2: [-1, -1, 32000, -2, -2, -3, -3],
},
},
),
# Test different replacement per item
(
[32000, 32000, 32000],
{
"pattern_1": [32000],
},
{
"pattern_1": lambda idx: [-(idx + 1)],
},
{
PromptInsertion: {
0: [32000, 32000, 32000],
1: [32000, -1, 32000, 32000],
2: [32000, -1, -2, 32000, 32000],
},
PromptReplacement: {
0: [32000, 32000, 32000],
1: [-1, 32000, 32000],
2: [-1, -2, 32000],
},
},
),
(
[32000, 32000, 32000],
{
"pattern_1": PromptIndexTargets.prefix([32000]),
},
{
"pattern_1": lambda idx: [-(idx + 1)],
},
{
PromptInsertion: {
0: [32000, 32000, 32000],
1: [32000, -1, 32000, 32000],
2: [32000, -1, -2, 32000, 32000],
},
PromptReplacement: {
0: [32000, 32000, 32000],
1: [32000, -1, 32000, 32000],
2: [32000, -1, -2, 32000, 32000],
},
},
),
]
)
# yapf: enable
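The expected outputs above pin down the semantics of index targets: each one resolves to at most one zero-width match, and the update content for every remaining multimodal item is emitted at that position, in item order. A dependency-free model of the token-based cases (plain Python, not vLLM code; inserts are applied right-to-left so that indices resolved against the original prompt stay valid, and ties at the same index are applied in reverse pattern order so earlier patterns land first):

    def insert_at(prompt: list[int], idx: int,
                  per_item: list[list[int]]) -> list[int]:
        flat = [tok for item in per_item for tok in item]
        return prompt[:idx] + flat + prompt[idx:]

    # Mirrors the [32000] case above with 2 items per pattern:
    out = [32000]
    out = insert_at(out, 1, [[-3], [-3]])  # PromptIndexTargets.end() -> index 1
    out = insert_at(out, 1, [[-2], [-2]])  # prefix([32000]) -> index 1
    out = insert_at(out, 0, [[-1], [-1]])  # start() -> index 0
    assert out == [-1, -1, 32000, -2, -2, -3, -3]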

View File

@@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptInsertion,
-                                        PromptUpdate)
+                                        BaseProcessingInfo, PromptIndexTargets,
+                                        PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -490,7 +490,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
        return [
            PromptInsertion(
                modality="image",
-                target="",
+                target=PromptIndexTargets.start(),
                insertion=image_tokens,
            )
        ]
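Note on the fix: this processor previously used target="" to insert at the start of the prompt. For text prompts, targets are located with re.finditer (see find_text_matches below), and an empty pattern yields a zero-width match at every index, so the insertion point was ambiguous. PromptIndexTargets.start() instead resolves to exactly one match at index 0, even for an empty prompt. A plain-Python illustration of the underlying regex behavior (not vLLM code):

    import re

    # An empty pattern matches (zero-width) at every position:
    positions = [m.start() for m in re.finditer(re.escape(""), "Hi!")]
    assert positions == [0, 1, 2, 3]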

View File

@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
                                         EncDecMultiModalProcessor,
-                                        PromptInsertion, PromptUpdate)
+                                        PromptIndexTargets, PromptInsertion,
+                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -864,7 +865,7 @@ class Florence2MultiModalProcessor(
        return [
            PromptInsertion(
                modality="image",
-                target="",
+                target=PromptIndexTargets.start(),
                insertion=image_tokens,
            )
        ]

View File

@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptInsertion,
-                                        PromptUpdate)
+                                        BaseProcessingInfo, PromptIndexTargets,
+                                        PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves
@@ -1371,7 +1371,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
        return [
            PromptInsertion(
                modality="image",
-                target="<|endoftext|>",
+                target=PromptIndexTargets.prefix("<|endoftext|>"),
                insertion=get_insertion_molmo,
            )
        ]
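Here the semantics are deliberately narrowed: a literal sequence target matches every occurrence of <|endoftext|> in the prompt, whereas PromptIndexTargets.prefix yields a single zero-width match just after a leading occurrence, and none at all if the prompt does not start with it. A plain-Python contrast (not vLLM code):

    prompt = "<|endoftext|>caption<|endoftext|>"

    # A sequence target would match each occurrence:
    occurrences = [i for i in range(len(prompt))
                   if prompt.startswith("<|endoftext|>", i)]
    assert occurrences == [0, 20]

    # A prefix index target resolves once, right after the prefix:
    prefix = "<|endoftext|>"
    match_idx = len(prefix) if prompt.startswith(prefix) else None
    assert match_idx == 13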

View File

@@ -8,7 +8,6 @@ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
-from itertools import groupby
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast)
@@ -40,6 +39,65 @@ PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
@dataclass
class PromptIndex:
    """Resolves to an index in the prompt."""
    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]


class PromptIndexTargets:

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: 0)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
        ) -> Optional[int]:
            prefix = seq

            if isinstance(prompt, str):
                if not isinstance(prefix, str):
                    # Make both `str`
                    prefix = decode_tokens(tokenizer, prefix)
            else:
                if isinstance(prefix, str):
                    # Make both `list[int]`
                    prefix = encode_tokens(tokenizer, prefix)

            match_idx = len(prefix)
            return match_idx if prompt[:match_idx] == prefix else None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: len(prompt))


PromptTarget = Union[PromptSeq, PromptIndex]
"""
The token sequence or text to update.
"""


@dataclass
class PromptUpdateDetails:
    """Details about the token sequence or text that are part of the update."""
@@ -84,7 +142,7 @@ class UpdateMode(str, Enum):
@dataclass
-class PromptUpdate:
+class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.
    """
@@ -92,7 +150,7 @@ class PromptUpdate:
    modality: str
    """The modality for which the update is made."""

-    target: PromptSeq
+    target: PromptTarget
    """The token sequence (or text) to update."""
@property
@@ -122,18 +180,7 @@ class PromptInsertion(PromptUpdate):
    Example:

    For each image, insert a number of ``<image>`` feature placeholders
-    equal to the feature size of the vision encoder at the start of the
-    prompt:
-
-    .. code-block:: python
-
-        PromptInsertion(
-            modality="image",
-            target="",
-            insertion="<image>" * image_feature_size,
-        )
-
-    As above, but insert after the ``<s>`` token:
+    equal to the feature size of the vision encoder after the ``<s>`` token:
    .. code-block:: python

@@ -142,6 +189,36 @@ class PromptInsertion(PromptUpdate):
            target="<s>",
            insertion="<image>" * image_feature_size,
        )
    Insert these tokens at the start of the prompt:

    .. code-block:: python

        PromptInsertion(
            modality="image",
            target=PromptIndexTargets.start(),
            insertion="<image>" * image_feature_size,
        )

    Insert these tokens after a prefix ``Images:``:

    .. code-block:: python

        PromptInsertion(
            modality="image",
            target=PromptIndexTargets.prefix("Images:"),
            insertion="<image>" * image_feature_size,
        )

    Insert these tokens at the end of the prompt:

    .. code-block:: python

        PromptInsertion(
            modality="image",
            target=PromptIndexTargets.end(),
            insertion="<image>" * image_feature_size,
        )
    """

    insertion: PromptUpdateContent = field(repr=False)
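The test expectations earlier show identical outputs for PromptInsertion and PromptReplacement whenever the target is an index: a zero-width match removes nothing, so replacement degenerates to insertion there. A small illustration of that equivalence (plain Python, not vLLM code):

    prompt = "<image>"
    start = end = len(prompt)  # PromptIndexTargets.end() on "<image>"

    replaced = prompt[:start] + "3" + prompt[end:]
    inserted = prompt[:end] + "3" + prompt[end:]
    assert replaced == inserted == "<image>3"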
@@ -345,10 +422,14 @@ class BoundPromptUpdate:
return self._origin.modality
@property
-    def target(self) -> _BoundPromptSequence:
+    def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
        """The token sequence (or text) to update."""
-        return _BoundPromptSequence.from_seq(self.tokenizer,
-                                             self._origin.target)
+        target = self._origin.target
+        if isinstance(target, PromptIndex):
+            return target
+
+        return _BoundPromptSequence.from_seq(self.tokenizer, target)
@property
def content(self) -> PromptUpdateContent:
@@ -447,6 +528,19 @@ class _PromptTargetMatch(ABC):
f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
@dataclass(repr=False)
class _PromptTargetIndexMatch(_PromptTargetMatch):
    match_idx: int

    @property
    def start_idx(self) -> int:
        return self.match_idx

    @property
    def end_idx(self) -> int:
        return self.match_idx


@dataclass(repr=False)
class _PromptTargetTokenMatch(_PromptTargetMatch):
    match: _TokenMatch
@@ -496,9 +590,24 @@ def find_token_matches(
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
+    def get_matches(update: BoundPromptUpdate):
+        target = update.target
+
+        if isinstance(target, PromptIndex):
+            match_idx = target.get_match_index(update.tokenizer, prompt)
+            if match_idx is None:
+                return []
+
+            return [_PromptTargetIndexMatch(update, match_idx)]
+
+        return [
+            _PromptTargetTokenMatch(update, match)
+            for match in iter_token_matches(prompt, target.token_ids)
+        ]
+
    return [
-        _PromptTargetTokenMatch(update, match) for update in prompt_updates
-        for match in iter_token_matches(prompt, update.target.token_ids)
+        match for update in prompt_updates for match in get_matches(update)
    ]
@@ -507,9 +616,24 @@ def find_text_matches(
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
+    def get_matches(update: BoundPromptUpdate):
+        target = update.target
+
+        if isinstance(target, PromptIndex):
+            match_idx = target.get_match_index(update.tokenizer, prompt)
+            if match_idx is None:
+                return []
+
+            return [_PromptTargetIndexMatch(update, match_idx)]
+
+        return [
+            _PromptTargetTextMatch(update, match)
+            for match in re.finditer(re.escape(target.text), prompt)
+        ]
+
    return [
-        _PromptTargetTextMatch(update, match) for update in prompt_updates
-        for match in re.finditer(re.escape(update.target.text), prompt)
+        match for update in prompt_updates for match in get_matches(update)
    ]
@@ -547,45 +671,39 @@ def _apply_matches(
    prev_end_idx = 0
    next_idx_by_modality = defaultdict[str, int](lambda: 0)

-    for (start_idx, end_idx), group in groupby(
-            _resolve_matches(prompt, mm_matches),
-            key=lambda x: (x.start_idx, x.end_idx),
-    ):
-        matches = tuple(group)
-        assert len(matches) == 1
+    for match in _resolve_matches(prompt, mm_matches):
+        modality = match.modality

-        for match in matches:
-            modality = match.modality
+        item_start_idx = next_idx_by_modality[modality]
+        max_item_count = mm_item_counts.get(modality, 0)
+        if item_start_idx >= max_item_count:
+            continue

-            item_idx = next_idx_by_modality[modality]
-            if item_idx >= mm_item_counts.get(modality, 0):
-                continue
+        start_idx = match.start_idx
+        end_idx = match.end_idx
+        origin = match._origin
+        mode = origin.mode

-            origin = match._origin
+        if mode == UpdateMode.INSERT:
+            out_seqs.append(prompt[prev_end_idx:end_idx])
+            num_inserts = max_item_count
+        elif mode == UpdateMode.REPLACE:
+            out_seqs.append(prompt[prev_end_idx:start_idx])
+            num_inserts = max_item_count if start_idx == end_idx else 1
+        else:
+            assert_never(mode)

+        item_end_idx = min(item_start_idx + num_inserts, max_item_count)
+        for item_idx in range(item_start_idx, item_end_idx):
+            content = origin.get_content(item_idx)
-            mode = origin.mode
+            insert_seq = (content.full.text if isinstance(prompt, str) else
+                          content.full.token_ids)

-            if mode == UpdateMode.INSERT:
-                out_seqs.append(prompt[prev_end_idx:end_idx])
-                num_inserts = mm_item_counts.get(modality, 0)
-            elif mode == UpdateMode.REPLACE:
-                out_seqs.append(prompt[prev_end_idx:start_idx])
-                num_inserts = 1
-            else:
-                assert_never(mode)
+            out_seqs.append(insert_seq)

-            for _ in range(num_inserts):
-                if item_idx >= mm_item_counts.get(modality, 0):
-                    continue
-
-                if isinstance(prompt, str):
-                    out_seqs.append(content.full.text)
-                else:
-                    out_seqs.append(content.full.token_ids)
-
-                next_idx_by_modality[modality] += 1
-
-            prev_end_idx = end_idx
+        prev_end_idx = end_idx
+        next_idx_by_modality[modality] += item_end_idx - item_start_idx

    out_seqs.append(prompt[prev_end_idx:])
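The rewritten loop drops the groupby pass and changes the counting rule: a single zero-width match now consumes every remaining item of its modality, while a replacement over a non-empty span still consumes one item per match. A dependency-free model of just that rule (not vLLM code; the real loop also clamps to the remaining item count):

    def num_inserts(mode: str, start_idx: int, end_idx: int,
                    remaining: int) -> int:
        if mode == "insert":
            return remaining
        # replace: a zero-width match behaves like an insertion
        return remaining if start_idx == end_idx else 1

    assert num_inserts("insert", 0, 0, 3) == 3   # index target, 3 images
    assert num_inserts("replace", 0, 7, 3) == 1  # literal "<image>" match
    assert num_inserts("replace", 7, 7, 3) == 3  # zero-width replace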