[Refactor] Dynamic target and content for prompt updates (#23411)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-08-25 14:39:58 +08:00 committed by GitHub
parent 49ab23b3cc
commit 712d0f88d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 463 additions and 456 deletions

View File

@ -17,13 +17,11 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptReplacement, apply_text_matches, PromptReplacement, apply_text_matches,
apply_token_matches, apply_token_matches,
find_mm_placeholders, find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches, iter_token_matches,
replace_token_matches) replace_token_matches)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby
from .utils import random_image from .utils import random_image
@ -75,12 +73,15 @@ from .utils import random_image
), ),
], ],
) )
@pytest.mark.parametrize("start_idx", [0, 4, 8])
# yapf: enable # yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected): def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
result = list(iter_token_matches(token_ids, match_ids)) result = list(iter_token_matches(token_ids, match_ids,
start_idx=start_idx))
# Manually constructed results # Manually constructed results
assert [item._asdict() for item in result] == expected assert [item._asdict() for item in result
] == [item for item in expected if item["start_idx"] >= start_idx]
# Invariants # Invariants
match_lens = [end - start for start, end in result] match_lens = [end - start for start, end in result]
@ -241,21 +242,23 @@ def test_find_token_matches(
# Should not be used since there is nothing to convert to token IDs # Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = [ prompt_updates = {
update_type(key, target, []).bind(mock_tokenizer) key: update_type(key, target, []).resolve(mock_tokenizer, 0)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] }
result = find_token_matches(prompt, prompt_updates) result = {
key: list(update.iter_token_matches(prompt, mock_tokenizer))
for key, update in prompt_updates.items()
}
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert { assert {
key: [ key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx) dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, []) for item in result.get(key, [])
] ]
for key in expected_by_key for key in expected_by_key
} == expected_by_key } == expected_by_key
@ -388,21 +391,23 @@ def test_find_text_matches(
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = [ prompt_updates = {
update_type(key, target, []).bind(mock_tokenizer) key: update_type(key, target, []).resolve(mock_tokenizer, 0)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] }
result = find_text_matches(prompt, prompt_updates) result = {
key: list(update.iter_text_matches(prompt, mock_tokenizer))
for key, update in prompt_updates.items()
}
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert { assert {
key: [ key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx) dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, []) for item in result.get(key, [])
] ]
for key in expected_by_key for key in expected_by_key
} == expected_by_key } == expected_by_key
@ -552,39 +557,37 @@ def test_find_update_text(
update_type, update_type,
expected_by_mm_count, expected_by_mm_count,
) in expected_by_update_type_mm_count.items(): ) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
for mm_count, expected in expected_by_mm_count.items(): for mm_count, expected in expected_by_mm_count.items():
result = apply_text_matches( mm_prompt_updates = {
key: [[
update_type(key, target,
repl_by_key[key]).resolve(mock_tokenizer, i)
] for i in range(mm_count)]
for key, target in target_by_key.items()
}
new_prompt, result = apply_text_matches(
prompt, prompt,
mm_matches, mm_prompt_updates,
{key: mm_count mock_tokenizer,
for key in repl_by_key},
) )
# Only displayed on error # Only displayed on error
print("update_type:", update_type) print("update_type:", update_type)
print("mm_count:", mm_count) print("mm_count:", mm_count)
print("mm_matches:", mm_matches) print("mm_prompt_updates:", mm_prompt_updates)
print("new_prompt:", new_prompt)
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
assert result == expected assert new_prompt == expected
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[ [
# Tokenized test cases of `test_find_replace_text` # Tokenized test cases of `test_find_update_text`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
( (
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
@ -726,32 +729,30 @@ def test_find_update_tokens(
update_type, update_type,
expected_by_mm_count, expected_by_mm_count,
) in expected_by_update_type_mm_count.items(): ) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
for mm_count, expected in expected_by_mm_count.items(): for mm_count, expected in expected_by_mm_count.items():
result = apply_token_matches( mm_prompt_updates = {
key: [[
update_type(key, target,
repl_by_key[key]).resolve(mock_tokenizer, i)
] for i in range(mm_count)]
for key, target in target_by_key.items()
}
new_prompt, result = apply_token_matches(
prompt, prompt,
mm_matches, mm_prompt_updates,
{key: mm_count mock_tokenizer,
for key in repl_by_key},
) )
# Only displayed on error # Only displayed on error
print("update_type:", update_type) print("update_type:", update_type)
print("mm_count:", mm_count) print("mm_count:", mm_count)
print("mm_matches:", mm_matches) print("mm_prompt_updates:", mm_prompt_updates)
print("new_prompt:", new_prompt)
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
assert result == expected assert new_prompt == expected
# yapf: disable # yapf: disable
@ -878,17 +879,12 @@ def test_find_mm_placeholders(
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_updates = { mm_prompt_updates = {
key: [update_type(key, [], repl).bind(mock_tokenizer)] key: [[update_type(key, [], repl).resolve(mock_tokenizer, i)]
for i in range(3)]
for key, repl in repl_by_key.items() for key, repl in repl_by_key.items()
} }
result = find_mm_placeholders( result = find_mm_placeholders(prompt, mm_prompt_updates)
mm_prompt_updates,
prompt,
# Effectively match all occurrences in the prompt
{key: 3
for key in repl_by_key},
)
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)

View File

@ -22,10 +22,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate, BaseProcessingInfo,
MultiModalPromptUpdates,
MultiModalPromptUpdatesApplyResult,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch, PromptReplacement, PromptUpdate,
PromptUpdate, PromptUpdateDetails, PromptUpdateDetails,
find_mm_placeholders, find_mm_placeholders,
replace_token_matches) replace_token_matches)
# yapf: enable # yapf: enable
@ -337,14 +339,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
def _apply_token_matches( def _apply_token_matches(
self, self,
prompt: list[int], prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_prompt_updates: MultiModalPromptUpdates,
mm_item_counts: Mapping[str, int], ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
) -> list[int]: token_ids, res = super()._apply_token_matches(prompt,
token_ids = super()._apply_token_matches( mm_prompt_updates)
prompt,
mm_matches,
mm_item_counts,
)
# "\n\n\n" and "\n\n\n\n" are single tokens # "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n" # Since our replacement can insert "\n\n" next to "\n"
@ -373,13 +371,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
[newline_4], [newline_4],
) )
return token_ids return token_ids, res
def _find_mm_placeholders( def _find_mm_placeholders(
self, self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
@ -404,8 +401,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
repl_token_ids.extend(repl_toks) repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
mm_item_counts)
return { return {
modality: [ modality: [

View File

@ -29,10 +29,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate, BaseProcessingInfo,
MultiModalPromptUpdates,
MultiModalPromptUpdatesApplyResult,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch, PromptReplacement, PromptUpdate,
PromptUpdate, PromptUpdateDetails, PromptUpdateDetails,
find_mm_placeholders, find_mm_placeholders,
replace_token_matches) replace_token_matches)
# yapf: enable # yapf: enable
@ -254,14 +256,10 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
def _apply_token_matches( def _apply_token_matches(
self, self,
prompt: list[int], prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_prompt_updates: MultiModalPromptUpdates,
mm_item_counts: Mapping[str, int], ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
) -> list[int]: token_ids, res = super()._apply_token_matches(prompt,
token_ids = super()._apply_token_matches( mm_prompt_updates)
prompt,
mm_matches,
mm_item_counts,
)
# "\n\n\n" and "\n\n\n\n" are single tokens # "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n" # Since our replacement can insert "\n\n" next to "\n"
@ -290,13 +288,12 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
[newline_4], [newline_4],
) )
return token_ids return token_ids, res
def _find_mm_placeholders( def _find_mm_placeholders(
self, self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
@ -321,8 +318,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
repl_token_ids.extend(repl_toks) repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
mm_item_counts)
return { return {
modality: [ modality: [

View File

@ -828,26 +828,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
target=[image_token_id] * num_image_tokens, target=[image_token_id] * num_image_tokens,
replacement=get_replacement_mantis, replacement=get_replacement_mantis,
) )
]) ], mm_item_counts)
prompt_ids, prompt, _ = self._apply_prompt_updates( prompt_ids, prompt, _ = self._apply_prompt_updates(
result["prompt_token_ids"], result["prompt_token_ids"],
mantis_mm_repls, mantis_mm_repls,
mm_item_counts,
) )
unbound_orig_repls = self._get_prompt_updates( orig_repls = self._get_mm_prompt_updates(
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
orig_repls = self._bind_and_group_updates(unbound_orig_repls) mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
mm_placeholders = self._find_mm_placeholders(
orig_repls,
prompt_ids,
mm_item_counts,
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts) self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_placeholder_ranges = { mm_placeholder_ranges = {

View File

@ -38,7 +38,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate, BaseProcessingInfo,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate)
# yapf: enable # yapf: enable
@ -431,24 +432,21 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
return [_IMAGE_TOKEN_ID] * num_image_tokens return [_IMAGE_TOKEN_ID] * num_image_tokens
num_images = mm_items.get_count("image", strict=False)
return [ return [
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=image_token, target=image_tokens.__getitem__,
replacement=get_replacement_phi3v, replacement=get_replacement_phi3v,
) for image_token in image_tokens[:num_images] )
] ]
def _apply_prompt_updates( def _apply_prompt_updates(
self, self,
token_ids: list[int], token_ids: list[int],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], mm_prompt_updates: MultiModalPromptUpdates,
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
# align to hf behavior when there are images # align to hf behavior when there are images
if len(mm_item_counts): if len(mm_prompt_updates):
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
# to decode token_ids to the original text, we need to # to decode token_ids to the original text, we need to
# 1. remove the first bos token # 1. remove the first bos token
@ -484,7 +482,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
token_ids, text, placeholders = super()._apply_prompt_updates( token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids, token_ids=token_ids,
mm_prompt_updates=mm_prompt_updates, mm_prompt_updates=mm_prompt_updates,
mm_item_counts=mm_item_counts,
) )
# Keep the behavior in line with HF processor # Keep the behavior in line with HF processor

View File

@ -1032,8 +1032,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
out_mm_kwargs: MultiModalKwargsItems, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
image_token_id = tokenizer.vocab[tokenizer.image_token] image_token_id: int = tokenizer.vocab[tokenizer.image_token]
audio_token_id = tokenizer.vocab[tokenizer.audio_token] audio_token_id: int = tokenizer.vocab[tokenizer.audio_token]
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
audio_processor = self.info.get_feature_extractor( audio_processor = self.info.get_feature_extractor(
@ -1053,9 +1053,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
processor=hf_processor, processor=hf_processor,
) )
image_tokens = [image_token_id] * num_image_tokens return [image_token_id] * num_image_tokens
return image_tokens
def get_audio_replacement_phi4mm(item_idx: int): def get_audio_replacement_phi4mm(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems) audios = mm_items.get_items("audio", AudioProcessorItems)
@ -1066,9 +1064,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
audio_embed_size = self.info._compute_audio_embed_size( audio_embed_size = self.info._compute_audio_embed_size(
audio_frames) audio_frames)
audio_tokens = [audio_token_id] * audio_embed_size return [audio_token_id] * audio_embed_size
return audio_tokens
return [ return [
PromptReplacement( PromptReplacement(

View File

@ -824,9 +824,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
processor=hf_processor, processor=hf_processor,
) )
image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens
return image_tokens
def get_audio_replacement_phi4mm(item_idx: int): def get_audio_replacement_phi4mm(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems) audios = mm_items.get_items("audio", AudioProcessorItems)
@ -837,28 +835,20 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
audio_embed_size = self.info._compute_audio_embed_size( audio_embed_size = self.info._compute_audio_embed_size(
audio_frames) audio_frames)
audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
return audio_tokens return [
num_images = mm_items.get_count("image", strict=False)
num_audios = mm_items.get_count("audio", strict=False)
image_repl = [
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=image_token, target=image_tokens.__getitem__,
replacement=get_image_replacement_phi4mm, replacement=get_image_replacement_phi4mm,
) for image_token in image_tokens[:num_images] ),
]
audio_repl = [
PromptReplacement( PromptReplacement(
modality="audio", modality="audio",
target=audio_token, target=audio_tokens.__getitem__,
replacement=get_audio_replacement_phi4mm, replacement=get_audio_replacement_phi4mm,
) for audio_token in audio_tokens[:num_audios] ),
] ]
return image_repl + audio_repl
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(

View File

@ -309,9 +309,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
if is_update_applied: if is_update_applied:
mm_placeholders = self._find_mm_placeholders( mm_placeholders = self._find_mm_placeholders(
mm_prompt_updates,
prompt_ids, prompt_ids,
mm_item_counts, mm_prompt_updates,
) )
self._validate_mm_placeholders( self._validate_mm_placeholders(
mm_placeholders, mm_placeholders,
@ -328,7 +327,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
) = self._apply_prompt_updates( ) = self._apply_prompt_updates(
prompt_ids, prompt_ids,
mm_prompt_updates, mm_prompt_updates,
mm_item_counts,
) )
self._validate_mm_placeholders( self._validate_mm_placeholders(
mm_placeholders, mm_placeholders,

View File

@ -44,10 +44,21 @@ PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text.""" """A token sequence (list of token IDs) or text."""
class _GetMatchIndex(Protocol):
def __call__(
self,
tokenizer: AnyTokenizer,
prompt: PromptSeq,
start_idx: int = 0,
) -> Optional[int]:
...
@dataclass @dataclass
class PromptIndex: class PromptIndex:
"""Resolves to an index in the prompt.""" """Resolves to an index in the prompt."""
get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] get_match_index: _GetMatchIndex
class PromptIndexTargets: class PromptIndexTargets:
@ -59,7 +70,7 @@ class PromptIndexTargets:
This results in a match even if the prompt is empty. This results in a match even if the prompt is empty.
""" """
return PromptIndex(lambda tok, prompt: 0) return PromptIndex(lambda tokenizer, prompt, start_idx=0: 0)
@staticmethod @staticmethod
def prefix(seq: PromptSeq) -> PromptIndex: def prefix(seq: PromptSeq) -> PromptIndex:
@ -70,7 +81,11 @@ class PromptIndexTargets:
def get_match_index( def get_match_index(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
prompt: PromptSeq, prompt: PromptSeq,
start_idx: int = 0,
) -> Optional[int]: ) -> Optional[int]:
if start_idx != 0:
return None
prefix = seq prefix = seq
if isinstance(prompt, str): if isinstance(prompt, str):
@ -96,14 +111,24 @@ class PromptIndexTargets:
This results in a match even if the prompt is empty. This results in a match even if the prompt is empty.
""" """
return PromptIndex(lambda tok, prompt: len(prompt)) return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt))
PromptTarget = Union[PromptSeq, PromptIndex] UpdateTarget = Union[PromptSeq, PromptIndex]
""" """
The token sequence or text to update. The token sequence or text to update.
""" """
PromptUpdateTarget = Union[Callable[[int], UpdateTarget], UpdateTarget]
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
@dataclass @dataclass
class PromptUpdateDetails(Generic[_S]): class PromptUpdateDetails(Generic[_S]):
@ -190,7 +215,7 @@ class PromptUpdate(ABC):
modality: str modality: str
"""The modality for which the update is made.""" """The modality for which the update is made."""
target: PromptTarget target: PromptUpdateTarget
"""The token sequence (or text) to update.""" """The token sequence (or text) to update."""
@property @property
@ -205,10 +230,54 @@ class PromptUpdate(ABC):
"""Defines how to update the prompt.""" """Defines how to update the prompt."""
raise NotImplementedError raise NotImplementedError
def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate": def _resolve_target(
return BoundPromptUpdate( self,
_origin=self, tokenizer: AnyTokenizer,
tokenizer=tokenizer, item_idx: int,
) -> Union["_BoundPromptSequence", PromptIndex]:
target = self.target
if callable(target):
target = target(item_idx)
if isinstance(target, PromptIndex):
return target
return _BoundPromptSequence.from_seq(tokenizer, target)
def _resolve_content(
self,
tokenizer: AnyTokenizer,
item_idx: int,
) -> "_BoundPromptContent":
content = self.content
if callable(content):
content = content(item_idx)
if not isinstance(content, PromptUpdateDetails):
content = PromptUpdateDetails.from_seq(content)
bound_full = _BoundPromptSequence.from_seq(tokenizer, content.full)
bound_content = _BoundPromptContent(full=bound_full,
is_embed=content.is_embed)
return bound_content
def resolve(
self,
tokenizer: AnyTokenizer,
item_idx: int,
) -> "ResolvedPromptUpdate":
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output a copy of this object with its lazy attributes resolved.
"""
return ResolvedPromptUpdate(
modality=self.modality,
item_idx=item_idx,
mode=self.mode,
target=self._resolve_target(tokenizer, item_idx),
content=self._resolve_content(tokenizer, item_idx),
) )
@ -452,73 +521,90 @@ class _BoundPromptContent:
is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
@dataclass class PromptTargetMatch(NamedTuple):
class BoundPromptUpdate: start_idx: int
end_idx: int
@dataclass(frozen=True)
class ResolvedPromptUpdate:
""" """
A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] with its
to a tokenizer to automatically convert lazy attributes resolved, apart from those related to tokenization.
[`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of
[`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content]
between token sequence and text representations.
""" """
_origin: PromptUpdate
tokenizer: AnyTokenizer = field(repr=False)
def __post_init__(self) -> None: modality: str
self._content_cache = dict[int, _BoundPromptContent]() """The modality for which the update is made."""
@property item_idx: int
def modality(self) -> str: """The index within `modality` of the item this update pertains to."""
return self._origin.modality
@property mode: UpdateMode
def target(self) -> Union[_BoundPromptSequence, PromptIndex]: """Defines how to update the prompt."""
"""The token sequence (or text) to update."""
target = self._origin.target target: Union[_BoundPromptSequence, PromptIndex]
"""The token sequence (or text) to update."""
content: _BoundPromptContent = field(repr=False)
"""The placeholder tokens that are part of the update."""
def iter_token_matches(
self,
prompt: list[int],
tokenizer: AnyTokenizer,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
"""Yield each instance of `self.target` found in `prompt`."""
target = self.target
if isinstance(target, PromptIndex): if isinstance(target, PromptIndex):
return target match_idx = target.get_match_index(tokenizer, prompt, start_idx)
if match_idx is not None:
yield PromptTargetMatch(match_idx, match_idx)
return _BoundPromptSequence.from_seq(self.tokenizer, target) return
@property for match in iter_token_matches(prompt,
def content(self) -> PromptUpdateContent: target.token_ids,
"""The placeholder tokens that are part of the update.""" start_idx=start_idx):
return self._origin.content yield PromptTargetMatch(match.start_idx, match.end_idx)
@property def iter_text_matches(
def mode(self) -> UpdateMode: self,
"""Defines how to update the prompt.""" prompt: str,
return self._origin.mode tokenizer: AnyTokenizer,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
"""Yield each instance of `self.target` found in `prompt`."""
target = self.target
def get_content(self, item_idx: int) -> _BoundPromptContent: if isinstance(target, PromptIndex):
""" match_idx = target.get_match_index(tokenizer, prompt, start_idx)
Given the index of the processed item within if match_idx is not None:
[`modality`][vllm.multimodal.processing.PromptUpdate.modality], yield PromptTargetMatch(match_idx, match_idx)
output the token sequence (or text) to update.
"""
content = self.content
if callable(content):
cache_key = item_idx
if cache_key in self._content_cache:
return self._content_cache[cache_key]
content = content(item_idx) return
else:
cache_key = None
if not isinstance(content, PromptUpdateDetails): for match in re.finditer(re.escape(target.text), prompt,
content = PromptUpdateDetails.from_seq(content) pos=start_idx):
yield PromptTargetMatch(match.start(), match.end())
bound_full = _BoundPromptSequence.from_seq(self.tokenizer, def iter_matches(
content.full) self,
bound_content = _BoundPromptContent(full=bound_full, prompt: Union[list[int], str],
is_embed=content.is_embed) tokenizer: AnyTokenizer,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
"""Yield each instance of `self.target` found in `prompt`."""
if isinstance(prompt, str):
return self.iter_text_matches(prompt,
tokenizer,
start_idx=start_idx)
if cache_key is not None: return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx)
self._content_cache[cache_key] = bound_content
return bound_content
class _TokenMatch(NamedTuple): class _TokenMatch(NamedTuple):
@ -529,6 +615,8 @@ class _TokenMatch(NamedTuple):
def iter_token_matches( def iter_token_matches(
token_ids: list[int], token_ids: list[int],
match_ids: list[int], match_ids: list[int],
*,
start_idx: int = 0,
) -> Generator[_TokenMatch]: ) -> Generator[_TokenMatch]:
""" """
Yield each occurrence of `match_ids` in `token_ids`. Yield each occurrence of `match_ids` in `token_ids`.
@ -541,7 +629,6 @@ def iter_token_matches(
if match_len == 0: if match_len == 0:
return return
start_idx = 0
while start_idx < prompt_len - match_len + 1: while start_idx < prompt_len - match_len + 1:
end_idx = start_idx + match_len end_idx = start_idx + match_len
@ -581,68 +668,6 @@ def replace_token_matches(
return flatten_2d_lists(out_seqs) return flatten_2d_lists(out_seqs)
@dataclass(repr=False)
class PromptTargetMatch(ABC):
_origin: BoundPromptUpdate
@property
def modality(self) -> str:
return self._origin.modality
@property
@abstractmethod
def start_idx(self) -> int:
raise NotImplementedError
@property
@abstractmethod
def end_idx(self) -> int:
raise NotImplementedError
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r}, "
f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
@dataclass(repr=False)
class _PromptTargetIndexMatch(PromptTargetMatch):
match_idx: int
@property
def start_idx(self) -> int:
return self.match_idx
@property
def end_idx(self) -> int:
return self.match_idx
@dataclass(repr=False)
class _PromptTargetTokenMatch(PromptTargetMatch):
match: _TokenMatch
@property
def start_idx(self) -> int:
return self.match.start_idx
@property
def end_idx(self) -> int:
return self.match.end_idx
@dataclass(repr=False)
class _PromptTargetTextMatch(PromptTargetMatch):
match: re.Match[str]
@property
def start_idx(self) -> int:
return self.match.start()
@property
def end_idx(self) -> int:
return self.match.end()
@dataclass @dataclass
class PlaceholderFeaturesInfo: class PlaceholderFeaturesInfo:
modality: str modality: str
@ -665,163 +690,158 @@ class PlaceholderFeaturesInfo:
) )
def find_token_matches( _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
prompt: list[int],
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
"""Return each target of `prompt_updates` found in `prompt`."""
def get_matches(update: BoundPromptUpdate):
target = update.target
if isinstance(target, PromptIndex):
match_idx = target.get_match_index(update.tokenizer, prompt)
if match_idx is None:
return []
return [_PromptTargetIndexMatch(update, match_idx)]
return [
_PromptTargetTokenMatch(update, match)
for match in iter_token_matches(prompt, target.token_ids)
]
return [
match for update in prompt_updates for match in get_matches(update)
]
def find_text_matches( def _find_matches(
prompt: str, prompt: _S,
prompt_updates: Sequence[BoundPromptUpdate], mm_prompt_updates: "MultiModalPromptUpdates",
) -> Sequence[PromptTargetMatch]: tokenizer: AnyTokenizer,
"""Return each target of `prompt_updates` found in `prompt`.""" *,
prev_end_idx: int = 0,
current_result: "MultiModalPromptUpdatesApplyResult",
) -> tuple[Optional[UpdateMode], list[_MatchToApply]]:
mode: Optional[UpdateMode] = None
mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]()
def get_matches(update: BoundPromptUpdate): for modality, modality_updates in mm_prompt_updates.items():
target = update.target for item_idx, item_updates in enumerate(modality_updates):
if current_result[modality][item_idx] is not None:
continue # Updates have already been applied for this item
if isinstance(target, PromptIndex): for update_idx, update in enumerate(item_updates):
match_idx = target.get_match_index(update.tokenizer, prompt) if (modality, item_idx) in mm_matches:
if match_idx is None: break # Already found a match for this item
return []
return [_PromptTargetIndexMatch(update, match_idx)] for match in update.iter_matches(
prompt,
tokenizer,
start_idx=prev_end_idx,
):
# All matches should share the same mode
if mode is None:
mode = update.mode
elif mode != update.mode:
continue
return [ mm_matches[(modality, item_idx)] = match, update_idx
_PromptTargetTextMatch(update, match) break # Get only the first valid match per item
for match in re.finditer(re.escape(target.text), prompt)
]
return [ # Prioritize earlier matches
match for update in prompt_updates for match in get_matches(update) matches_to_apply = sorted(mm_matches.items(), key=lambda item: item[1][0])
]
# To avoid conflicts, only replace one non-empty item at a time
if mode == UpdateMode.REPLACE:
matches_to_apply_ = list[_MatchToApply]()
has_non_empty_matches = False
def _resolve_matches( for item in matches_to_apply:
prompt: PromptSeq, _, (match, _) = item
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], if match.start_idx == match.end_idx:
) -> list[PromptTargetMatch]: matches_to_apply_.append(item)
""" elif not has_non_empty_matches:
Resolve `mm_matches` to ensure that there are no overlapping matches, has_non_empty_matches = True
and sort them such that earlier matches take priority over later ones. matches_to_apply_.append(item)
"""
matches = [m for matches in mm_matches.values() for m in matches]
seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt) matches_to_apply = matches_to_apply_
for match in matches: return mode, matches_to_apply
for idx in range(match.start_idx, match.end_idx):
if seen_matches[idx] is not None:
raise ValueError("Found overlapping matches "
f"({seen_matches[idx]} and {match}) "
f"at index={idx} of prompt={prompt}")
seen_matches[idx] = match
return sorted(matches, key=lambda x: x.start_idx)
def _apply_matches( def _apply_matches(
prompt: _S, prompt: _S,
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_prompt_updates: "MultiModalPromptUpdates",
mm_item_counts: Mapping[str, int], tokenizer: AnyTokenizer,
) -> list[_S]: ) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
"""Apply the updates in `mm_matches` to `prompt`.""" prompt_len = len(prompt)
out_seqs = list[Union[str, list[int]]]() out_seqs = list[Union[str, list[int]]]()
prev_end_idx = 0 out_result: MultiModalPromptUpdatesApplyResult = {
next_idx_by_modality = defaultdict[str, int](lambda: 0) m: [None] * len(items)
for m, items in mm_prompt_updates.items()
}
for match in _resolve_matches(prompt, mm_matches): start_idx = prev_end_idx = 0
modality = match.modality while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
found = False
item_start_idx = next_idx_by_modality[modality] mode, matches_to_apply = _find_matches(
max_item_count = mm_item_counts.get(modality, 0) prompt,
if item_start_idx >= max_item_count: mm_prompt_updates,
continue tokenizer,
prev_end_idx=prev_end_idx,
current_result=out_result,
)
start_idx = match.start_idx if mode is not None:
end_idx = match.end_idx for (modality, item_idx), (match, update_idx) in matches_to_apply:
origin = match._origin found = True
mode = origin.mode
if mode == UpdateMode.INSERT: matched_update = mm_prompt_updates[modality][item_idx][
out_seqs.append(prompt[prev_end_idx:end_idx]) update_idx]
num_inserts = max_item_count matched_content = matched_update.content
elif mode == UpdateMode.REPLACE:
out_seqs.append(prompt[prev_end_idx:start_idx])
num_inserts = max_item_count if start_idx == end_idx else 1
else:
assert_never(mode)
item_end_idx = min(item_start_idx + num_inserts, max_item_count) if mode == UpdateMode.INSERT:
end_idx_to_insert = match.end_idx
elif mode == UpdateMode.REPLACE:
end_idx_to_insert = match.start_idx
else:
assert_never(mode)
for item_idx in range(item_start_idx, item_end_idx): out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
content = origin.get_content(item_idx) out_seqs.append(matched_content.full.text if isinstance(
insert_seq = (content.full.text if isinstance(prompt, str) else prompt, str) else matched_content.full.token_ids)
content.full.token_ids) out_result[modality][item_idx] = update_idx
out_seqs.append(insert_seq) # Exclude overlapping matches
start_idx = prev_end_idx = match.end_idx
prev_end_idx = end_idx if not found:
next_idx_by_modality[modality] += item_end_idx - item_start_idx start_idx += 1
out_seqs.append(prompt[prev_end_idx:]) out_seqs.append(prompt[prev_end_idx:])
return cast(list[_S], out_seqs) return cast(list[_S], out_seqs), out_result
def apply_token_matches( def apply_token_matches(
prompt: list[int], prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_prompt_updates: "MultiModalPromptUpdates",
mm_item_counts: Mapping[str, int], tokenizer: AnyTokenizer,
) -> list[int]: ) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
"""Apply the updates in `mm_matches` to `prompt`.""" """
if not mm_matches: Apply the updates in `mm_prompt_updates` to `prompt`.
return prompt
token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts) Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_updates` takes priority.
"""
token_id_seqs, result = _apply_matches(prompt, mm_prompt_updates,
tokenizer)
return flatten_2d_lists(token_id_seqs) return flatten_2d_lists(token_id_seqs), result
def apply_text_matches( def apply_text_matches(
prompt: str, prompt: str,
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_prompt_updates: "MultiModalPromptUpdates",
mm_item_counts: Mapping[str, int], tokenizer: AnyTokenizer,
) -> str: ) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
"""Apply the updates in `mm_matches` to `prompt`.""" """
if not mm_matches: Apply the updates in `mm_prompt_updates` to `prompt`.
return prompt
texts = _apply_matches(prompt, mm_matches, mm_item_counts) Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_updates` takes priority.
"""
texts, result = _apply_matches(prompt, mm_prompt_updates, tokenizer)
return "".join(texts) return "".join(texts), result
def _iter_placeholders( def _iter_placeholders(
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_prompt_updates: "MultiModalPromptUpdates",
) -> Iterable[PlaceholderFeaturesInfo]: ) -> Iterable[PlaceholderFeaturesInfo]:
""" """
Yield each set of placeholder tokens found in `prompt`. Yield each set of placeholder tokens found in `prompt`.
@ -833,6 +853,8 @@ def _iter_placeholders(
Note that empty matches are ignored. Note that empty matches are ignored.
""" """
prompt_len = len(prompt) prompt_len = len(prompt)
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
item_idx_by_modality = defaultdict[str, int](lambda: 0) item_idx_by_modality = defaultdict[str, int](lambda: 0)
start_idx = 0 start_idx = 0
@ -844,8 +866,8 @@ def _iter_placeholders(
if item_idx >= mm_item_counts.get(modality, 0): if item_idx >= mm_item_counts.get(modality, 0):
continue continue
for update_info in modality_updates: for update in modality_updates[item_idx]:
content = update_info.get_content(item_idx) content = update.content
content_tokens_full = content.full.token_ids content_tokens_full = content.full.token_ids
content_len_full = len(content_tokens_full) content_len_full = len(content_tokens_full)
end_idx_full = start_idx + content_len_full end_idx_full = start_idx + content_len_full
@ -880,11 +902,10 @@ def _iter_placeholders(
def find_mm_placeholders( def find_mm_placeholders(
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_prompt_updates: "MultiModalPromptUpdates",
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts) it = _iter_placeholders(prompt, mm_prompt_updates)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
@ -989,12 +1010,20 @@ A collection of hashes with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
""" """
MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]] MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]]
""" """
A collection of prompt updates with a similar structure as A collection of prompt updates with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
""" """
MultiModalPromptUpdatesApplyResult = Mapping[str, list[Optional[int]]]
"""
For an item `MultiModalPromptUpdates[k][i]`,
`MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the
`ResolvedPromptUpdate` instance that has been applied, or `None` if none of the
`ResolvedPromptUpdate` instances have been applied.
"""
class MultiModalProcessingInfo(NamedTuple): class MultiModalProcessingInfo(NamedTuple):
kwargs: MultiModalKwargsItems kwargs: MultiModalKwargsItems
@ -1126,14 +1155,60 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
raise NotImplementedError raise NotImplementedError
def _bind_and_group_updates(
self,
prompt_updates: Sequence[PromptUpdate],
mm_item_counts: Mapping[str, int],
) -> MultiModalPromptUpdates:
tokenizer = self.info.get_tokenizer()
return {
modality:
[[update.resolve(tokenizer, item_idx) for update in updates]
for item_idx in range(mm_item_counts.get(modality, 0))]
for modality, updates in full_groupby_modality(prompt_updates)
}
def _get_mm_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> MultiModalPromptUpdates:
unbound_prompt_updates = self._get_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs,
)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates,
mm_items.get_all_counts(),
)
for modality, prompt_updates in mm_prompt_updates.items():
for item_idx, item_prompt_updates in enumerate(prompt_updates):
if len(item_prompt_updates) > 1:
logger.warning_once(
"Detected %d prompt updates for `mm_items[%r][%s]`. "
"Multiple prompt updates per item is now "
"deprecated and may be removed in v0.13. "
"Instead, please specify dynamic update targets "
"in the same prompt update definition by passing "
"a function to `PromptUpdate.target`.",
len(prompt_updates),
modality,
item_idx,
)
return mm_prompt_updates
def _find_mm_placeholders( def _find_mm_placeholders(
self, self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
return find_mm_placeholders(mm_prompt_updates, new_token_ids, return find_mm_placeholders(new_token_ids, mm_prompt_updates)
mm_item_counts)
def _get_hf_mm_data( def _get_hf_mm_data(
self, self,
@ -1421,13 +1496,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs) tokenization_kwargs)
unbound_prompt_updates = self._get_prompt_updates( mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items, mm_data_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_info = MultiModalProcessingInfo( mm_info = MultiModalProcessingInfo(
kwargs=mm_kwargs, kwargs=mm_kwargs,
@ -1497,13 +1570,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_kwargs=mm_missing_kwargs, mm_missing_kwargs=mm_missing_kwargs,
) )
unbound_prompt_updates = self._get_prompt_updates( mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items, mm_data_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_info = MultiModalProcessingInfo( mm_info = MultiModalProcessingInfo(
kwargs=mm_kwargs, kwargs=mm_kwargs,
@ -1513,47 +1584,33 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_info, is_update_applied return prompt_ids, mm_info, is_update_applied
def _bind_and_group_updates(
self,
prompt_updates: Sequence[PromptUpdate],
) -> dict[str, Sequence[BoundPromptUpdate]]:
tokenizer = self.info.get_tokenizer()
it = (update.bind(tokenizer) for update in prompt_updates)
return dict(full_groupby_modality(it))
def _apply_token_matches( def _apply_token_matches(
self, self,
prompt: list[int], prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_prompt_updates: MultiModalPromptUpdates,
mm_item_counts: Mapping[str, int], ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
) -> list[int]: tokenizer = self.info.get_tokenizer()
return apply_token_matches(prompt, mm_matches, mm_item_counts) return apply_token_matches(prompt, mm_prompt_updates, tokenizer)
def _apply_text_matches( def _apply_text_matches(
self, self,
prompt: str, prompt: str,
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_prompt_updates: MultiModalPromptUpdates,
mm_item_counts: Mapping[str, int], ) -> tuple[str, MultiModalPromptUpdatesApplyResult]:
) -> str: tokenizer = self.info.get_tokenizer()
return apply_text_matches(prompt, mm_matches, mm_item_counts) return apply_text_matches(prompt, mm_prompt_updates, tokenizer)
def _apply_prompt_updates( def _apply_prompt_updates(
self, self,
token_ids: list[int], token_ids: list[int],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], mm_prompt_updates: MultiModalPromptUpdates,
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
mm_token_matches = { new_token_ids, match_result = self._apply_token_matches(
modality: find_token_matches(token_ids, updates) token_ids,
for modality, updates in mm_prompt_updates.items() mm_prompt_updates,
} )
mm_match_counts = {
modality: len(matches)
for modality, matches in mm_token_matches.items()
}
# If the search text does not represent a special token, # If the search text does not represent a special token,
# it may have different token IDs in the prompt, because # it may have different token IDs in the prompt, because
@ -1566,48 +1623,38 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# of the search text in the prompt, we instead perform string-based # of the search text in the prompt, we instead perform string-based
# updates on the decoded token IDs, then encode them back. # updates on the decoded token IDs, then encode them back.
if all( if all(
mm_match_counts.get(modality, 0) >= item_count all(update_idx is not None for update_idx in update_idxs)
for modality, item_count in mm_item_counts.items() for update_idxs in match_result.values()):
): # yapf: disable new_text = decode_tokens(tokenizer, new_token_ids)
token_ids = self._apply_token_matches(
token_ids,
mm_token_matches,
mm_item_counts,
)
text = decode_tokens(tokenizer, token_ids)
matched_updates = {
modality: [match._origin for match in token_matches]
for modality, token_matches in mm_token_matches.items()
}
else: else:
text = decode_tokens(tokenizer, token_ids) new_text, match_result = self._apply_text_matches(
decode_tokens(tokenizer, token_ids),
mm_text_matches = { mm_prompt_updates,
modality: find_text_matches(text, updates)
for modality, updates in mm_prompt_updates.items()
}
text = self._apply_text_matches(
text,
mm_text_matches,
mm_item_counts,
) )
token_ids = encode_tokens(tokenizer, new_token_ids = encode_tokens(
text, tokenizer,
add_special_tokens=False) new_text,
matched_updates = { add_special_tokens=False,
modality: [match._origin for match in token_matches] )
for modality, token_matches in mm_text_matches.items()
} matched_updates = defaultdict[
str, list[Sequence[ResolvedPromptUpdate]]](list)
for modality, update_idxs in match_result.items():
for item_idx, update_idx in enumerate(update_idxs):
assert update_idx is not None, (
"Failed to apply prompt replacement for "
f"mm_items[{modality!r}][{item_idx}]")
matched_updates[modality].append(
[mm_prompt_updates[modality][item_idx][update_idx]])
placeholders = self._find_mm_placeholders( placeholders = self._find_mm_placeholders(
matched_updates, new_token_ids,
token_ids, dict(matched_updates),
mm_item_counts,
) )
return token_ids, text, placeholders return new_token_ids, new_text, placeholders
def _validate_mm_kwargs( def _validate_mm_kwargs(
self, self,
@ -1661,9 +1708,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
if is_update_applied: if is_update_applied:
mm_placeholders = self._find_mm_placeholders( mm_placeholders = self._find_mm_placeholders(
mm_prompt_updates,
prompt_ids, prompt_ids,
mm_item_counts, mm_prompt_updates,
) )
self._validate_mm_placeholders(mm_placeholders, mm_item_counts) self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
@ -1677,7 +1723,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) = self._apply_prompt_updates( ) = self._apply_prompt_updates(
prompt_ids, prompt_ids,
mm_prompt_updates, mm_prompt_updates,
mm_item_counts,
) )
self._validate_mm_placeholders(mm_placeholders, mm_item_counts) self._validate_mm_placeholders(mm_placeholders, mm_item_counts)