diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 878b15925006..ba3df86f715a 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -14,8 +14,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (PlaceholderFeaturesInfo, - PromptInsertion, PromptReplacement, - apply_text_matches, + PromptIndexTargets, PromptInsertion, + PromptReplacement, apply_text_matches, apply_token_matches, find_mm_placeholders, find_text_matches, find_token_matches, @@ -98,10 +98,20 @@ def test_iter_token_matches(token_ids, match_ids, expected): { "pattern_1": [], "pattern_2": [32000], + "pattern_3": PromptIndexTargets.start(), + "pattern_4": PromptIndexTargets.prefix([32000]), + "pattern_5": PromptIndexTargets.end(), }, { "pattern_1": [], "pattern_2": [], + "pattern_3": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_4": [], + "pattern_5": [ + { "start_idx": 0, "end_idx": 0 }, + ], }, ), ( @@ -110,6 +120,9 @@ def test_iter_token_matches(token_ids, match_ids, expected): "pattern_1": [32000], "pattern_2": [32000, 32000], "pattern_3": [32000, 32000, 32000], + "pattern_4": PromptIndexTargets.start(), + "pattern_5": PromptIndexTargets.prefix([32000]), + "pattern_6": PromptIndexTargets.end(), }, { "pattern_1": [ @@ -125,6 +138,15 @@ def test_iter_token_matches(token_ids, match_ids, expected): "pattern_3": [ { "start_idx": 0, "end_idx": 3 }, ], + "pattern_4": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_5": [ + { "start_idx": 1, "end_idx": 1 }, + ], + "pattern_6": [ + { "start_idx": 4, "end_idx": 4 }, + ], }, ), ( @@ -133,6 +155,9 @@ def test_iter_token_matches(token_ids, match_ids, expected): "pattern_1": [28747, 32000], "pattern_2": [28747, 32000, 32000, 32000], "pattern_3": [28747, 0, 32000], + "pattern_4": PromptIndexTargets.start(), + "pattern_5": PromptIndexTargets.prefix([28747, 32000]), + 
"pattern_6": PromptIndexTargets.end(), }, { "pattern_1": [ @@ -143,6 +168,13 @@ def test_iter_token_matches(token_ids, match_ids, expected): { "start_idx": 1, "end_idx": 5 }, ], "pattern_3": [], + "pattern_4": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_5": [], + "pattern_6": [ + { "start_idx": 10, "end_idx": 10 }, + ], }, ), ], @@ -189,10 +221,20 @@ def test_find_token_matches( { "pattern_1": "", "pattern_2": "", + "pattern_3": PromptIndexTargets.start(), + "pattern_4": PromptIndexTargets.prefix(""), + "pattern_5": PromptIndexTargets.end(), }, { "pattern_1": [{ "start_idx": 0, "end_idx": 0 }], "pattern_2": [], + "pattern_3": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_4": [], + "pattern_5": [ + { "start_idx": 0, "end_idx": 0 }, + ], } ), ( @@ -201,6 +243,9 @@ def test_find_token_matches( "pattern_1": "", "pattern_2": "", "pattern_3": "", + "pattern_4": PromptIndexTargets.start(), + "pattern_5": PromptIndexTargets.prefix(""), + "pattern_6": PromptIndexTargets.end(), }, { "pattern_1": [ @@ -216,6 +261,15 @@ def test_find_token_matches( "pattern_3": [ { "start_idx": 0, "end_idx": 21 }, ], + "pattern_4": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_5": [ + { "start_idx": 7, "end_idx": 7 }, + ], + "pattern_6": [ + { "start_idx": 28, "end_idx": 28 }, + ], }, ), ( @@ -224,6 +278,9 @@ def test_find_token_matches( "pattern_1": "Image:", "pattern_2": "Image:", "pattern_3": "Image:", + "pattern_4": PromptIndexTargets.start(), + "pattern_5": PromptIndexTargets.prefix("Image:"), + "pattern_6": PromptIndexTargets.end(), }, { "pattern_1": [ @@ -234,6 +291,15 @@ def test_find_token_matches( { "start_idx": 0, "end_idx": 27 }, ], "pattern_3": [], + "pattern_4": [ + { "start_idx": 0, "end_idx": 0 }, + ], + "pattern_5": [ + { "start_idx": 13, "end_idx": 13 }, + ], + "pattern_6": [ + { "start_idx": 48, "end_idx": 48 }, + ], }, ), # Test regex escape @@ -325,6 +391,100 @@ def test_find_text_matches( }, }, ), + # Test index targets + ( + "", + { + 
"pattern_1": PromptIndexTargets.start(), + "pattern_2": PromptIndexTargets.prefix("<image>"), + "pattern_3": PromptIndexTargets.end(), + }, + { + "pattern_1": "1", + "pattern_2": "2", + "pattern_3": "3", + }, + { + PromptInsertion: { + 0: "", + 1: "13", + 2: "1133", + }, + PromptReplacement: { + 0: "", + 1: "13", + 2: "1133", + }, + }, + ), + ( + "<image>", + { + "pattern_1": PromptIndexTargets.start(), + "pattern_2": PromptIndexTargets.prefix("<image>"), + "pattern_3": PromptIndexTargets.end(), + }, + { + "pattern_1": "1", + "pattern_2": "2", + "pattern_3": "3", + }, + { + PromptInsertion: { + 0: "<image>", + 1: "1<image>23", + 2: "11<image>2233", + }, + PromptReplacement: { + 0: "<image>", + 1: "1<image>23", + 2: "11<image>2233", + }, + }, + ), + # Test different replacement per item + ( + "<image><image><image>", + { + "pattern_1": "<image>", + }, + { + "pattern_1": lambda idx: str(idx + 1), + }, + { + PromptInsertion: { + 0: "<image><image><image>", + 1: "<image>1<image><image>", + 2: "<image>12<image><image>", + }, + PromptReplacement: { + 0: "<image><image><image>", + 1: "1<image><image>", + 2: "12<image>", + }, + }, + ), + ( + "<image><image><image>", + { + "pattern_1": PromptIndexTargets.prefix("<image>"), + }, + { + "pattern_1": lambda idx: str(idx + 1), + }, + { + PromptInsertion: { + 0: "<image><image><image>", + 1: "<image>1<image><image>", + 2: "<image>12<image><image>", + }, + PromptReplacement: { + 0: "<image><image><image>", + 1: "<image>1<image><image>", + 2: "<image>12<image><image>", + }, + }, + ), ] ) # yapf: enable @@ -405,6 +565,100 @@ def test_find_update_text( }, }, ), + # Test index targets + ( + [], + { + "pattern_1": PromptIndexTargets.start(), + "pattern_2": PromptIndexTargets.prefix([32000]), + "pattern_3": PromptIndexTargets.end(), + }, + { + "pattern_1": [-1], + "pattern_2": [-2], + "pattern_3": [-3], + }, + { + PromptInsertion: { + 0: [], + 1: [-1, -3], + 2: [-1, -1, -3, -3], + }, + PromptReplacement: { + 0: [], + 1: [-1, -3], + 2: [-1, -1, -3, -3], + }, + }, + ), + ( + [32000], + { + "pattern_1": PromptIndexTargets.start(), + "pattern_2": PromptIndexTargets.prefix([32000]), + "pattern_3": PromptIndexTargets.end(), + }, + { + "pattern_1": [-1], + "pattern_2": [-2], + "pattern_3": [-3], + }, + { + PromptInsertion: { + 0: [32000], + 1: [-1, 32000, -2, -3], + 2: [-1, -1, 32000, 
-2, -2, -3, -3], + }, + PromptReplacement: { + 0: [32000], + 1: [-1, 32000, -2, -3], + 2: [-1, -1, 32000, -2, -2, -3, -3], + }, + }, + ), + # Test different replacement per item + ( + [32000, 32000, 32000], + { + "pattern_1": [32000], + }, + { + "pattern_1": lambda idx: [-(idx + 1)], + }, + { + PromptInsertion: { + 0: [32000, 32000, 32000], + 1: [32000, -1, 32000, 32000], + 2: [32000, -1, -2, 32000, 32000], + }, + PromptReplacement: { + 0: [32000, 32000, 32000], + 1: [-1, 32000, 32000], + 2: [-1, -2, 32000], + }, + }, + ), + ( + [32000, 32000, 32000], + { + "pattern_1": PromptIndexTargets.prefix([32000]), + }, + { + "pattern_1": lambda idx: [-(idx + 1)], + }, + { + PromptInsertion: { + 0: [32000, 32000, 32000], + 1: [32000, -1, 32000, 32000], + 2: [32000, -1, -2, 32000, 32000], + }, + PromptReplacement: { + 0: [32000, 32000, 32000], + 1: [32000, -1, 32000, 32000], + 2: [32000, -1, -2, 32000, 32000], + }, + }, + ), ] ) # yapf: enable diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 61f2f8974d91..8457f6294460 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptInsertion, - PromptUpdate) + BaseProcessingInfo, PromptIndexTargets, + PromptInsertion, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -490,7 +490,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): return [ PromptInsertion( modality="image", - target="", + target=PromptIndexTargets.start(), insertion=image_tokens, ) ] diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 6fa1bb80995d..7a8510379455 100644 --- 
a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -25,7 +25,8 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, - PromptInsertion, PromptUpdate) + PromptIndexTargets, PromptInsertion, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -864,7 +865,7 @@ class Florence2MultiModalProcessor( return [ PromptInsertion( modality="image", - target="", + target=PromptIndexTargets.start(), insertion=image_tokens, ) ] diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 60af103189f8..21158f7e5802 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptInsertion, - PromptUpdate) + BaseProcessingInfo, PromptIndexTargets, + PromptInsertion, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import JSONTree, json_map_leaves @@ -1371,7 +1371,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): return [ PromptInsertion( modality="image", - target="<|endoftext|>", + target=PromptIndexTargets.prefix("<|endoftext|>"), insertion=get_insertion_molmo, ) ] diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index ac33af7c10c7..7232df074f84 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -8,7 +8,6 @@ from collections.abc import (Callable, 
Generator, ItemsView, Iterable, Mapping, from dataclasses import dataclass, field from enum import Enum from functools import lru_cache -from itertools import groupby from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, TypeVar, Union, cast) @@ -40,6 +39,65 @@ PromptSeq = Union[str, list[int]] """A token sequence (list of token IDs) or text.""" +@dataclass +class PromptIndex: + """Resolves to an index in the prompt.""" + get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] + + +class PromptIndexTargets: + + @staticmethod + def start() -> PromptIndex: + """ + Resolves to the start of the prompt (before the first token). + + This results in a match even if the prompt is empty. + """ + return PromptIndex(lambda tok, prompt: 0) + + @staticmethod + def prefix(seq: PromptSeq) -> PromptIndex: + """ + Resolves to a location in the prompt after the given prefix. + """ + + def get_match_index( + tokenizer: AnyTokenizer, + prompt: PromptSeq, + ) -> Optional[int]: + prefix = seq + + if isinstance(prompt, str): + if not isinstance(prefix, str): + # Make both `str` + prefix = decode_tokens(tokenizer, prefix) + else: + if isinstance(prefix, str): + # Make both `list[int]` (no BOS/EOS added, so only the prefix text itself is compared) + prefix = encode_tokens(tokenizer, prefix, add_special_tokens=False) + + match_idx = len(prefix) + return match_idx if prompt[:match_idx] == prefix else None + + return PromptIndex(get_match_index) + + @staticmethod + def end() -> PromptIndex: + """ + Resolves to the end of the prompt (after the last token). + + This results in a match even if the prompt is empty. + """ + return PromptIndex(lambda tok, prompt: len(prompt)) + + +PromptTarget = Union[PromptSeq, PromptIndex] +""" +The token sequence or text to update, or an index in the prompt to update at. +""" + + @dataclass class PromptUpdateDetails: """Details about the token sequence or text that are part of the update.""" @@ -84,7 +142,7 @@ class UpdateMode(str, Enum): @dataclass -class PromptUpdate: +class PromptUpdate(ABC): """ Defines how to update a prompt with placeholder tokens. 
""" @@ -92,7 +150,7 @@ class PromptUpdate: modality: str """The modality for which the update is made.""" - target: PromptSeq + target: PromptTarget """The token sequence (or text) to update.""" @property @@ -122,18 +180,7 @@ class PromptInsertion(PromptUpdate): Example: For each image, insert a number of ``<image>`` feature placeholders - equal to the feature size of the vision encoder at the start of the - prompt: - - .. code-block:: python - - PromptInsertion( - modality="image", - target="", - insertion="<image>" * image_feature_size, - ) - - As above, but insert after the ``<s>`` token: + equal to the feature size of the vision encoder after the ``<s>`` token: .. code-block:: python @@ -142,6 +189,36 @@ class PromptInsertion(PromptUpdate): target="<s>", insertion="<image>" * image_feature_size, ) + + Insert these tokens at the start of the prompt: + + .. code-block:: python + + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion="<image>" * image_feature_size, + ) + + Insert these tokens after a prefix ``Images:``: + + .. code-block:: python + + PromptInsertion( + modality="image", + target=PromptIndexTargets.prefix("Images:"), + insertion="<image>" * image_feature_size, + ) + + Insert these tokens at the end of the prompt: + + .. 
code-block:: python + + PromptInsertion( + modality="image", + target=PromptIndexTargets.end(), + insertion="<image>" * image_feature_size, + ) """ insertion: PromptUpdateContent = field(repr=False) @@ -345,10 +422,14 @@ class BoundPromptUpdate: return self._origin.modality @property - def target(self) -> _BoundPromptSequence: + def target(self) -> Union[_BoundPromptSequence, PromptIndex]: """The token sequence (or text) to update.""" - return _BoundPromptSequence.from_seq(self.tokenizer, - self._origin.target) + target = self._origin.target + + if isinstance(target, PromptIndex): + return target + + return _BoundPromptSequence.from_seq(self.tokenizer, target) @property def content(self) -> PromptUpdateContent: @@ -447,6 +528,19 @@ class _PromptTargetMatch(ABC): f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") +@dataclass(repr=False) +class _PromptTargetIndexMatch(_PromptTargetMatch): + match_idx: int + + @property + def start_idx(self) -> int: + return self.match_idx + + @property + def end_idx(self) -> int: + return self.match_idx + + @dataclass(repr=False) class _PromptTargetTokenMatch(_PromptTargetMatch): match: _TokenMatch @@ -496,9 +590,24 @@ def find_token_matches( prompt_updates: Sequence[BoundPromptUpdate], ) -> Sequence[_PromptTargetMatch]: """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" + + def get_matches(update: BoundPromptUpdate): + target = update.target + + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(update.tokenizer, prompt) + if match_idx is None: + return [] + + return [_PromptTargetIndexMatch(update, match_idx)] + + return [ + _PromptTargetTokenMatch(update, match) + for match in iter_token_matches(prompt, target.token_ids) + ] + return [ - _PromptTargetTokenMatch(update, match) for update in prompt_updates - for match in iter_token_matches(prompt, update.target.token_ids) + match for update in prompt_updates for match in get_matches(update) ] @@ -507,9 +616,24 @@ def 
find_text_matches( prompt_updates: Sequence[BoundPromptUpdate], ) -> Sequence[_PromptTargetMatch]: """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" + + def get_matches(update: BoundPromptUpdate): + target = update.target + + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(update.tokenizer, prompt) + if match_idx is None: + return [] + + return [_PromptTargetIndexMatch(update, match_idx)] + + return [ + _PromptTargetTextMatch(update, match) + for match in re.finditer(re.escape(target.text), prompt) + ] + return [ - _PromptTargetTextMatch(update, match) for update in prompt_updates - for match in re.finditer(re.escape(update.target.text), prompt) + match for update in prompt_updates for match in get_matches(update) ] @@ -547,45 +671,39 @@ def _apply_matches( prev_end_idx = 0 next_idx_by_modality = defaultdict[str, int](lambda: 0) - for (start_idx, end_idx), group in groupby( - _resolve_matches(prompt, mm_matches), - key=lambda x: (x.start_idx, x.end_idx), - ): - matches = tuple(group) - assert len(matches) == 1 + for match in _resolve_matches(prompt, mm_matches): + modality = match.modality - for match in matches: - modality = match.modality + item_start_idx = next_idx_by_modality[modality] + max_item_count = mm_item_counts.get(modality, 0) + if item_start_idx >= max_item_count: + continue - item_idx = next_idx_by_modality[modality] - if item_idx >= mm_item_counts.get(modality, 0): - continue + start_idx = match.start_idx + end_idx = match.end_idx + origin = match._origin + mode = origin.mode - origin = match._origin + if mode == UpdateMode.INSERT: + out_seqs.append(prompt[prev_end_idx:end_idx]) + num_inserts = max_item_count + elif mode == UpdateMode.REPLACE: + out_seqs.append(prompt[prev_end_idx:start_idx]) + num_inserts = max_item_count if start_idx == end_idx else 1 + else: + assert_never(mode) + + item_end_idx = min(item_start_idx + num_inserts, max_item_count) + + for item_idx in range(item_start_idx, 
item_end_idx): content = origin.get_content(item_idx) - mode = origin.mode + insert_seq = (content.full.text if isinstance(prompt, str) else + content.full.token_ids) - if mode == UpdateMode.INSERT: - out_seqs.append(prompt[prev_end_idx:end_idx]) - num_inserts = mm_item_counts.get(modality, 0) - elif mode == UpdateMode.REPLACE: - out_seqs.append(prompt[prev_end_idx:start_idx]) - num_inserts = 1 - else: - assert_never(mode) + out_seqs.append(insert_seq) - for _ in range(num_inserts): - if item_idx >= mm_item_counts.get(modality, 0): - continue - - if isinstance(prompt, str): - out_seqs.append(content.full.text) - else: - out_seqs.append(content.full.token_ids) - - next_idx_by_modality[modality] += 1 - - prev_end_idx = end_idx + prev_end_idx = end_idx + next_idx_by_modality[modality] += item_end_idx - item_start_idx out_seqs.append(prompt[prev_end_idx:])