[Refactor] Dynamic target and content for prompt updates (#23411)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Commit 712d0f88d8 (parent: 49ab23b3cc)
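In broad terms, this change replaces `PromptUpdate.bind(tokenizer)` with `PromptUpdate.resolve(tokenizer, item_idx)` and lets the target (and the replacement content) be a function of the item index, while `apply_token_matches` / `apply_text_matches` now take the resolved updates plus a tokenizer and also report which update was applied per item. A minimal sketch of the new flow, mirroring the updated tests below; the placeholder token ID and the object-based mock tokenizer are illustrative assumptions, not part of this commit:

```python
from typing import cast

from vllm.multimodal.processing import PromptReplacement, apply_token_matches
from vllm.transformers_utils.tokenizer import AnyTokenizer

IMAGE_TOKEN_ID = 32000  # hypothetical placeholder token ID, for illustration only
# Token-only targets/replacements never touch the tokenizer, so a bare object suffices here.
mock_tokenizer = cast(AnyTokenizer, object())

update = PromptReplacement(
    modality="image",
    # The target may now be a callable of the item index (dynamic target).
    target=lambda item_idx: [IMAGE_TOKEN_ID],
    # Likewise for the content: item i expands to i + 1 placeholder tokens.
    replacement=lambda item_idx: [IMAGE_TOKEN_ID] * (item_idx + 1),
)

# `MultiModalPromptUpdates`: for each modality, one list of resolved updates per item.
mm_prompt_updates = {
    "image": [[update.resolve(mock_tokenizer, item_idx)] for item_idx in range(2)],
}

prompt = [1, IMAGE_TOKEN_ID, 5, IMAGE_TOKEN_ID, 9]
# Returns the updated prompt and a `MultiModalPromptUpdatesApplyResult` recording,
# per item, the index of the update that was applied (or None).
new_prompt, applied = apply_token_matches(prompt, mm_prompt_updates, mock_tokenizer)
```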
@@ -17,13 +17,11 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
                                         PromptReplacement, apply_text_matches,
                                         apply_token_matches,
                                         find_mm_placeholders,
-                                        find_text_matches, find_token_matches,
                                         iter_token_matches,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import MultiModalProfiler
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import full_groupby
 
 from .utils import random_image
 
@@ -75,12 +73,15 @@ from .utils import random_image
         ),
     ],
 )
+@pytest.mark.parametrize("start_idx", [0, 4, 8])
 # yapf: enable
-def test_iter_token_matches(token_ids, match_ids, expected):
-    result = list(iter_token_matches(token_ids, match_ids))
+def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
+    result = list(iter_token_matches(token_ids, match_ids,
+                                     start_idx=start_idx))
 
     # Manually constructed results
-    assert [item._asdict() for item in result] == expected
+    assert [item._asdict() for item in result
+            ] == [item for item in expected if item["start_idx"] >= start_idx]
 
     # Invariants
     match_lens = [end - start for start, end in result]
@@ -241,21 +242,23 @@ def test_find_token_matches(
     # Should not be used since there is nothing to convert to token IDs
     mock_tokenizer = cast(AnyTokenizer, object())
 
-    prompt_updates = [
-        update_type(key, target, []).bind(mock_tokenizer)
+    prompt_updates = {
+        key: update_type(key, target, []).resolve(mock_tokenizer, 0)
         for key, target in target_by_key.items()
-    ]
-    result = find_token_matches(prompt, prompt_updates)
+    }
+    result = {
+        key: list(update.iter_token_matches(prompt, mock_tokenizer))
+        for key, update in prompt_updates.items()
+    }
 
     # Only displayed on error
     print("result:", result)
 
     # Manually constructed results
-    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
     assert {
         key: [
             dict(start_idx=item.start_idx, end_idx=item.end_idx)
-            for item in result_groups.get(key, [])
+            for item in result.get(key, [])
         ]
         for key in expected_by_key
     } == expected_by_key
@@ -388,21 +391,23 @@ def test_find_text_matches(
     # Should not be used since there is nothing to convert to text
     mock_tokenizer = cast(AnyTokenizer, object())
 
-    prompt_updates = [
-        update_type(key, target, []).bind(mock_tokenizer)
+    prompt_updates = {
+        key: update_type(key, target, []).resolve(mock_tokenizer, 0)
         for key, target in target_by_key.items()
-    ]
-    result = find_text_matches(prompt, prompt_updates)
+    }
+    result = {
+        key: list(update.iter_text_matches(prompt, mock_tokenizer))
+        for key, update in prompt_updates.items()
+    }
 
     # Only displayed on error
     print("result:", result)
 
     # Manually constructed results
-    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
     assert {
         key: [
             dict(start_idx=item.start_idx, end_idx=item.end_idx)
-            for item in result_groups.get(key, [])
+            for item in result.get(key, [])
         ]
         for key in expected_by_key
     } == expected_by_key
@@ -552,39 +557,37 @@ def test_find_update_text(
         update_type,
         expected_by_mm_count,
     ) in expected_by_update_type_mm_count.items():
-        mm_prompt_updates = {
-            key:
-            [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
-            for key, target in target_by_key.items()
-        }
-        mm_matches = {
-            key: find_text_matches(prompt, updates)
-            for key, updates in mm_prompt_updates.items()
-        }
-
         for mm_count, expected in expected_by_mm_count.items():
-            result = apply_text_matches(
+            mm_prompt_updates = {
+                key: [[
+                    update_type(key, target,
+                                repl_by_key[key]).resolve(mock_tokenizer, i)
+                ] for i in range(mm_count)]
+                for key, target in target_by_key.items()
+            }
+
+            new_prompt, result = apply_text_matches(
                 prompt,
-                mm_matches,
-                {key: mm_count
-                 for key in repl_by_key},
+                mm_prompt_updates,
+                mock_tokenizer,
             )
 
             # Only displayed on error
             print("update_type:", update_type)
             print("mm_count:", mm_count)
-            print("mm_matches:", mm_matches)
+            print("mm_prompt_updates:", mm_prompt_updates)
+            print("new_prompt:", new_prompt)
             print("result:", result)
 
             # Manually constructed results
-            assert result == expected
+            assert new_prompt == expected
 
 
 # yapf: disable
 @pytest.mark.parametrize(
     ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"),  # noqa: E501
     [
-        # Tokenized test cases of `test_find_replace_text`
+        # Tokenized test cases of `test_find_update_text`
         # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
         (
             [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
@@ -726,32 +729,30 @@ def test_find_update_tokens(
         update_type,
         expected_by_mm_count,
     ) in expected_by_update_type_mm_count.items():
-        mm_prompt_updates = {
-            key:
-            [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
-            for key, target in target_by_key.items()
-        }
-        mm_matches = {
-            key: find_token_matches(prompt, updates)
-            for key, updates in mm_prompt_updates.items()
-        }
-
         for mm_count, expected in expected_by_mm_count.items():
-            result = apply_token_matches(
+            mm_prompt_updates = {
+                key: [[
+                    update_type(key, target,
+                                repl_by_key[key]).resolve(mock_tokenizer, i)
+                ] for i in range(mm_count)]
+                for key, target in target_by_key.items()
+            }
+
+            new_prompt, result = apply_token_matches(
                 prompt,
-                mm_matches,
-                {key: mm_count
-                 for key in repl_by_key},
+                mm_prompt_updates,
+                mock_tokenizer,
            )
 
             # Only displayed on error
             print("update_type:", update_type)
             print("mm_count:", mm_count)
-            print("mm_matches:", mm_matches)
+            print("mm_prompt_updates:", mm_prompt_updates)
+            print("new_prompt:", new_prompt)
             print("result:", result)
 
             # Manually constructed results
-            assert result == expected
+            assert new_prompt == expected
 
 
 # yapf: disable
@@ -878,17 +879,12 @@ def test_find_mm_placeholders(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     mm_prompt_updates = {
-        key: [update_type(key, [], repl).bind(mock_tokenizer)]
+        key: [[update_type(key, [], repl).resolve(mock_tokenizer, i)]
+              for i in range(3)]
         for key, repl in repl_by_key.items()
     }
 
-    result = find_mm_placeholders(
-        mm_prompt_updates,
-        prompt,
-        # Effectively match all occurrences in the prompt
-        {key: 3
-         for key in repl_by_key},
-    )
+    result = find_mm_placeholders(prompt, mm_prompt_updates)
 
     # Only displayed on error
     print("result:", result)
@@ -22,10 +22,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 # yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                         BaseProcessingInfo, BoundPromptUpdate,
+                                         BaseProcessingInfo,
+                                         MultiModalPromptUpdates,
+                                         MultiModalPromptUpdatesApplyResult,
                                          PlaceholderFeaturesInfo,
-                                         PromptReplacement, PromptTargetMatch,
-                                         PromptUpdate, PromptUpdateDetails,
+                                         PromptReplacement, PromptUpdate,
+                                         PromptUpdateDetails,
                                          find_mm_placeholders,
                                          replace_token_matches)
 # yapf: enable
@@ -337,14 +339,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
     def _apply_token_matches(
         self,
         prompt: list[int],
-        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
-        mm_item_counts: Mapping[str, int],
-    ) -> list[int]:
-        token_ids = super()._apply_token_matches(
-            prompt,
-            mm_matches,
-            mm_item_counts,
-        )
+        mm_prompt_updates: MultiModalPromptUpdates,
+    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
+        token_ids, res = super()._apply_token_matches(prompt,
+                                                      mm_prompt_updates)
 
         # "\n\n\n" and "\n\n\n\n" are single tokens
         # Since our replacement can insert "\n\n" next to "\n"
@@ -373,13 +371,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             [newline_4],
         )
 
-        return token_ids
+        return token_ids, res
 
     def _find_mm_placeholders(
         self,
-        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
         new_token_ids: list[int],
-        mm_item_counts: Mapping[str, int],
+        mm_prompt_updates: MultiModalPromptUpdates,
     ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
         # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
         tokenizer = self.info.get_tokenizer()
@@ -404,8 +401,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             repl_token_ids.extend(repl_toks)
             repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
 
-        repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
-                                     mm_item_counts)
+        repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
 
         return {
             modality: [
@@ -29,10 +29,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
 # yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                         BaseProcessingInfo, BoundPromptUpdate,
+                                         BaseProcessingInfo,
+                                         MultiModalPromptUpdates,
+                                         MultiModalPromptUpdatesApplyResult,
                                          PlaceholderFeaturesInfo,
-                                         PromptReplacement, PromptTargetMatch,
-                                         PromptUpdate, PromptUpdateDetails,
+                                         PromptReplacement, PromptUpdate,
+                                         PromptUpdateDetails,
                                          find_mm_placeholders,
                                          replace_token_matches)
 # yapf: enable
@@ -254,14 +256,10 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
     def _apply_token_matches(
         self,
         prompt: list[int],
-        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
-        mm_item_counts: Mapping[str, int],
-    ) -> list[int]:
-        token_ids = super()._apply_token_matches(
-            prompt,
-            mm_matches,
-            mm_item_counts,
-        )
+        mm_prompt_updates: MultiModalPromptUpdates,
+    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
+        token_ids, res = super()._apply_token_matches(prompt,
+                                                      mm_prompt_updates)
 
         # "\n\n\n" and "\n\n\n\n" are single tokens
         # Since our replacement can insert "\n\n" next to "\n"
@@ -290,13 +288,12 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
             [newline_4],
         )
 
-        return token_ids
+        return token_ids, res
 
     def _find_mm_placeholders(
         self,
-        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
         new_token_ids: list[int],
-        mm_item_counts: Mapping[str, int],
+        mm_prompt_updates: MultiModalPromptUpdates,
     ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
         # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
         tokenizer = self.info.get_tokenizer()
@@ -321,8 +318,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
             repl_token_ids.extend(repl_toks)
             repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
 
-        repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
-                                     mm_item_counts)
+        repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
 
         return {
             modality: [
@@ -828,26 +828,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
                 target=[image_token_id] * num_image_tokens,
                 replacement=get_replacement_mantis,
             )
-        ])
+        ], mm_item_counts)
 
         prompt_ids, prompt, _ = self._apply_prompt_updates(
             result["prompt_token_ids"],
             mantis_mm_repls,
-            mm_item_counts,
         )
 
-        unbound_orig_repls = self._get_prompt_updates(
+        orig_repls = self._get_mm_prompt_updates(
             mm_items,
             hf_processor_mm_kwargs,
             mm_kwargs,
         )
-        orig_repls = self._bind_and_group_updates(unbound_orig_repls)
-
-        mm_placeholders = self._find_mm_placeholders(
-            orig_repls,
-            prompt_ids,
-            mm_item_counts,
-        )
+        mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
         self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
 
         mm_placeholder_ranges = {
@@ -38,7 +38,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                         BaseProcessingInfo, BoundPromptUpdate,
+                                         BaseProcessingInfo,
+                                         MultiModalPromptUpdates,
                                          PlaceholderFeaturesInfo,
                                          PromptReplacement, PromptUpdate)
 # yapf: enable
@@ -431,24 +432,21 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
 
             return [_IMAGE_TOKEN_ID] * num_image_tokens
 
-        num_images = mm_items.get_count("image", strict=False)
-
         return [
             PromptReplacement(
                 modality="image",
-                target=image_token,
+                target=image_tokens.__getitem__,
                 replacement=get_replacement_phi3v,
-            ) for image_token in image_tokens[:num_images]
+            )
         ]
 
     def _apply_prompt_updates(
         self,
         token_ids: list[int],
-        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
-        mm_item_counts: Mapping[str, int],
+        mm_prompt_updates: MultiModalPromptUpdates,
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         # align to hf behavior when there are images
-        if len(mm_item_counts):
+        if len(mm_prompt_updates):
             tokenizer = self.info.get_tokenizer()
             # to decode token_ids to the original text, we need to
             # 1. remove the first bos token
@@ -484,7 +482,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
         token_ids, text, placeholders = super()._apply_prompt_updates(
             token_ids=token_ids,
             mm_prompt_updates=mm_prompt_updates,
-            mm_item_counts=mm_item_counts,
         )
 
         # Keep the behavior in line with HF processor
@@ -1032,8 +1032,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
         out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         tokenizer = self.info.get_tokenizer()
-        image_token_id = tokenizer.vocab[tokenizer.image_token]
-        audio_token_id = tokenizer.vocab[tokenizer.audio_token]
+        image_token_id: int = tokenizer.vocab[tokenizer.image_token]
+        audio_token_id: int = tokenizer.vocab[tokenizer.audio_token]
 
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         audio_processor = self.info.get_feature_extractor(
@@ -1053,9 +1053,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
                 processor=hf_processor,
             )
 
-            image_tokens = [image_token_id] * num_image_tokens
-
-            return image_tokens
+            return [image_token_id] * num_image_tokens
 
         def get_audio_replacement_phi4mm(item_idx: int):
             audios = mm_items.get_items("audio", AudioProcessorItems)
@@ -1066,9 +1064,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
             audio_embed_size = self.info._compute_audio_embed_size(
                 audio_frames)
 
-            audio_tokens = [audio_token_id] * audio_embed_size
-
-            return audio_tokens
+            return [audio_token_id] * audio_embed_size
 
         return [
             PromptReplacement(
@@ -824,9 +824,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
                 processor=hf_processor,
             )
 
-            image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens
-
-            return image_tokens
+            return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens
 
         def get_audio_replacement_phi4mm(item_idx: int):
             audios = mm_items.get_items("audio", AudioProcessorItems)
@@ -837,28 +835,20 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
             audio_embed_size = self.info._compute_audio_embed_size(
                 audio_frames)
 
-            audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
-
-            return audio_tokens
-
-        num_images = mm_items.get_count("image", strict=False)
-        num_audios = mm_items.get_count("audio", strict=False)
-
-        image_repl = [
+            return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
+
+        return [
             PromptReplacement(
                 modality="image",
-                target=image_token,
+                target=image_tokens.__getitem__,
                 replacement=get_image_replacement_phi4mm,
-            ) for image_token in image_tokens[:num_images]
-        ]
-        audio_repl = [
+            ),
             PromptReplacement(
                 modality="audio",
-                target=audio_token,
+                target=audio_tokens.__getitem__,
                 replacement=get_audio_replacement_phi4mm,
-            ) for audio_token in audio_tokens[:num_audios]
+            ),
         ]
-        return image_repl + audio_repl
 
 
 @MULTIMODAL_REGISTRY.register_processor(
@@ -309,9 +309,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
 
         if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
-                mm_prompt_updates,
                 prompt_ids,
-                mm_item_counts,
+                mm_prompt_updates,
             )
             self._validate_mm_placeholders(
                 mm_placeholders,
@@ -328,7 +327,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         ) = self._apply_prompt_updates(
             prompt_ids,
             mm_prompt_updates,
-            mm_item_counts,
         )
         self._validate_mm_placeholders(
             mm_placeholders,
@@ -44,10 +44,21 @@ PromptSeq = Union[str, list[int]]
 """A token sequence (list of token IDs) or text."""
 
 
+class _GetMatchIndex(Protocol):
+
+    def __call__(
+        self,
+        tokenizer: AnyTokenizer,
+        prompt: PromptSeq,
+        start_idx: int = 0,
+    ) -> Optional[int]:
+        ...
+
+
 @dataclass
 class PromptIndex:
     """Resolves to an index in the prompt."""
-    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]
+    get_match_index: _GetMatchIndex
 
 
 class PromptIndexTargets:
@@ -59,7 +70,7 @@ class PromptIndexTargets:
 
         This results in a match even if the prompt is empty.
         """
-        return PromptIndex(lambda tok, prompt: 0)
+        return PromptIndex(lambda tokenizer, prompt, start_idx=0: 0)
 
     @staticmethod
     def prefix(seq: PromptSeq) -> PromptIndex:
@@ -70,7 +81,11 @@ class PromptIndexTargets:
         def get_match_index(
             tokenizer: AnyTokenizer,
             prompt: PromptSeq,
+            start_idx: int = 0,
         ) -> Optional[int]:
+            if start_idx != 0:
+                return None
+
             prefix = seq
 
             if isinstance(prompt, str):
@@ -96,14 +111,24 @@ class PromptIndexTargets:
 
         This results in a match even if the prompt is empty.
         """
-        return PromptIndex(lambda tok, prompt: len(prompt))
+        return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt))
 
 
-PromptTarget = Union[PromptSeq, PromptIndex]
+UpdateTarget = Union[PromptSeq, PromptIndex]
 """
 The token sequence or text to update.
 """
 
+PromptUpdateTarget = Union[Callable[[int], UpdateTarget], UpdateTarget]
+"""
+Given the index of the processed item within
+[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
+output the corresponding token sequence (or text).
+
+For convenience, you can directly pass in the token sequence (or text)
+instead of a function if it does not depend on the input.
+"""
+
 
 @dataclass
 class PromptUpdateDetails(Generic[_S]):
@@ -190,7 +215,7 @@ class PromptUpdate(ABC):
     modality: str
     """The modality for which the update is made."""
 
-    target: PromptTarget
+    target: PromptUpdateTarget
     """The token sequence (or text) to update."""
 
     @property
@@ -205,10 +230,54 @@ class PromptUpdate(ABC):
         """Defines how to update the prompt."""
         raise NotImplementedError
 
-    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
-        return BoundPromptUpdate(
-            _origin=self,
-            tokenizer=tokenizer,
+    def _resolve_target(
+        self,
+        tokenizer: AnyTokenizer,
+        item_idx: int,
+    ) -> Union["_BoundPromptSequence", PromptIndex]:
+        target = self.target
+        if callable(target):
+            target = target(item_idx)
+
+        if isinstance(target, PromptIndex):
+            return target
+
+        return _BoundPromptSequence.from_seq(tokenizer, target)
+
+    def _resolve_content(
+        self,
+        tokenizer: AnyTokenizer,
+        item_idx: int,
+    ) -> "_BoundPromptContent":
+        content = self.content
+        if callable(content):
+            content = content(item_idx)
+
+        if not isinstance(content, PromptUpdateDetails):
+            content = PromptUpdateDetails.from_seq(content)
+
+        bound_full = _BoundPromptSequence.from_seq(tokenizer, content.full)
+        bound_content = _BoundPromptContent(full=bound_full,
+                                            is_embed=content.is_embed)
+
+        return bound_content
+
+    def resolve(
+        self,
+        tokenizer: AnyTokenizer,
+        item_idx: int,
+    ) -> "ResolvedPromptUpdate":
+        """
+        Given the index of the processed item within
+        [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
+        output a copy of this object with its lazy attributes resolved.
+        """
+        return ResolvedPromptUpdate(
+            modality=self.modality,
+            item_idx=item_idx,
+            mode=self.mode,
+            target=self._resolve_target(tokenizer, item_idx),
+            content=self._resolve_content(tokenizer, item_idx),
         )
 
 
@@ -452,73 +521,90 @@ class _BoundPromptContent:
     is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
 
 
-@dataclass
-class BoundPromptUpdate:
-    """
-    A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound
-    to a tokenizer to automatically convert
-    [`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of
-    [`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content]
-    between token sequence and text representations.
-    """
-    _origin: PromptUpdate
-    tokenizer: AnyTokenizer = field(repr=False)
-
-    def __post_init__(self) -> None:
-        self._content_cache = dict[int, _BoundPromptContent]()
-
-    @property
-    def modality(self) -> str:
-        return self._origin.modality
-
-    @property
-    def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
-        """The token sequence (or text) to update."""
-        target = self._origin.target
-
-        if isinstance(target, PromptIndex):
-            return target
-
-        return _BoundPromptSequence.from_seq(self.tokenizer, target)
-
-    @property
-    def content(self) -> PromptUpdateContent:
-        """The placeholder tokens that are part of the update."""
-        return self._origin.content
-
-    @property
-    def mode(self) -> UpdateMode:
-        """Defines how to update the prompt."""
-        return self._origin.mode
-
-    def get_content(self, item_idx: int) -> _BoundPromptContent:
-        """
-        Given the index of the processed item within
-        [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
-        output the token sequence (or text) to update.
-        """
-        content = self.content
-        if callable(content):
-            cache_key = item_idx
-            if cache_key in self._content_cache:
-                return self._content_cache[cache_key]
-
-            content = content(item_idx)
-        else:
-            cache_key = None
-
-        if not isinstance(content, PromptUpdateDetails):
-            content = PromptUpdateDetails.from_seq(content)
-
-        bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
-                                                   content.full)
-        bound_content = _BoundPromptContent(full=bound_full,
-                                            is_embed=content.is_embed)
-
-        if cache_key is not None:
-            self._content_cache[cache_key] = bound_content
-
-        return bound_content
+class PromptTargetMatch(NamedTuple):
+    start_idx: int
+    end_idx: int
+
+
+@dataclass(frozen=True)
+class ResolvedPromptUpdate:
+    """
+    A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] with its
+    lazy attributes resolved, apart from those related to tokenization.
+    """
+
+    modality: str
+    """The modality for which the update is made."""
+
+    item_idx: int
+    """The index within `modality` of the item this update pertains to."""
+
+    mode: UpdateMode
+    """Defines how to update the prompt."""
+
+    target: Union[_BoundPromptSequence, PromptIndex]
+    """The token sequence (or text) to update."""
+
+    content: _BoundPromptContent = field(repr=False)
+    """The placeholder tokens that are part of the update."""
+
+    def iter_token_matches(
+        self,
+        prompt: list[int],
+        tokenizer: AnyTokenizer,
+        *,
+        start_idx: int = 0,
+    ) -> Generator[PromptTargetMatch]:
+        """Yield each instance of `self.target` found in `prompt`."""
+        target = self.target
+
+        if isinstance(target, PromptIndex):
+            match_idx = target.get_match_index(tokenizer, prompt, start_idx)
+            if match_idx is not None:
+                yield PromptTargetMatch(match_idx, match_idx)
+
+            return
+
+        for match in iter_token_matches(prompt,
+                                        target.token_ids,
+                                        start_idx=start_idx):
+            yield PromptTargetMatch(match.start_idx, match.end_idx)
+
+    def iter_text_matches(
+        self,
+        prompt: str,
+        tokenizer: AnyTokenizer,
+        *,
+        start_idx: int = 0,
+    ) -> Generator[PromptTargetMatch]:
+        """Yield each instance of `self.target` found in `prompt`."""
+        target = self.target
+
+        if isinstance(target, PromptIndex):
+            match_idx = target.get_match_index(tokenizer, prompt, start_idx)
+            if match_idx is not None:
+                yield PromptTargetMatch(match_idx, match_idx)
+
+            return
+
+        for match in re.finditer(re.escape(target.text), prompt,
+                                 pos=start_idx):
+            yield PromptTargetMatch(match.start(), match.end())
+
+    def iter_matches(
+        self,
+        prompt: Union[list[int], str],
+        tokenizer: AnyTokenizer,
+        *,
+        start_idx: int = 0,
+    ) -> Generator[PromptTargetMatch]:
+        """Yield each instance of `self.target` found in `prompt`."""
+        if isinstance(prompt, str):
+            return self.iter_text_matches(prompt,
+                                          tokenizer,
+                                          start_idx=start_idx)
+
+        return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx)
 
 
 class _TokenMatch(NamedTuple):
@@ -529,6 +615,8 @@ class _TokenMatch(NamedTuple):
 def iter_token_matches(
     token_ids: list[int],
     match_ids: list[int],
+    *,
+    start_idx: int = 0,
 ) -> Generator[_TokenMatch]:
     """
     Yield each occurrence of `match_ids` in `token_ids`.
@@ -541,7 +629,6 @@ def iter_token_matches(
     if match_len == 0:
         return
 
-    start_idx = 0
     while start_idx < prompt_len - match_len + 1:
         end_idx = start_idx + match_len
 
@@ -581,68 +668,6 @@ def replace_token_matches(
     return flatten_2d_lists(out_seqs)
 
 
-@dataclass(repr=False)
-class PromptTargetMatch(ABC):
-    _origin: BoundPromptUpdate
-
-    @property
-    def modality(self) -> str:
-        return self._origin.modality
-
-    @property
-    @abstractmethod
-    def start_idx(self) -> int:
-        raise NotImplementedError
-
-    @property
-    @abstractmethod
-    def end_idx(self) -> int:
-        raise NotImplementedError
-
-    def __repr__(self) -> str:
-        return (f"{type(self).__name__}(modality={self.modality!r}, "
-                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
-
-
-@dataclass(repr=False)
-class _PromptTargetIndexMatch(PromptTargetMatch):
-    match_idx: int
-
-    @property
-    def start_idx(self) -> int:
-        return self.match_idx
-
-    @property
-    def end_idx(self) -> int:
-        return self.match_idx
-
-
-@dataclass(repr=False)
-class _PromptTargetTokenMatch(PromptTargetMatch):
-    match: _TokenMatch
-
-    @property
-    def start_idx(self) -> int:
-        return self.match.start_idx
-
-    @property
-    def end_idx(self) -> int:
-        return self.match.end_idx
-
-
-@dataclass(repr=False)
-class _PromptTargetTextMatch(PromptTargetMatch):
-    match: re.Match[str]
-
-    @property
-    def start_idx(self) -> int:
-        return self.match.start()
-
-    @property
-    def end_idx(self) -> int:
-        return self.match.end()
-
-
 @dataclass
 class PlaceholderFeaturesInfo:
     modality: str
@ -665,163 +690,158 @@ class PlaceholderFeaturesInfo:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def find_token_matches(
|
_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
|
||||||
prompt: list[int],
|
|
||||||
prompt_updates: Sequence[BoundPromptUpdate],
|
|
||||||
) -> Sequence[PromptTargetMatch]:
|
|
||||||
"""Return each target of `prompt_updates` found in `prompt`."""
|
|
||||||
|
|
||||||
def get_matches(update: BoundPromptUpdate):
|
|
||||||
target = update.target
|
|
||||||
|
|
||||||
if isinstance(target, PromptIndex):
|
|
||||||
match_idx = target.get_match_index(update.tokenizer, prompt)
|
|
||||||
if match_idx is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [_PromptTargetIndexMatch(update, match_idx)]
|
|
||||||
|
|
||||||
return [
|
|
||||||
_PromptTargetTokenMatch(update, match)
|
|
||||||
for match in iter_token_matches(prompt, target.token_ids)
|
|
||||||
]
|
|
||||||
|
|
||||||
return [
|
|
||||||
match for update in prompt_updates for match in get_matches(update)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def find_text_matches(
|
def _find_matches(
|
||||||
prompt: str,
|
prompt: _S,
|
||||||
prompt_updates: Sequence[BoundPromptUpdate],
|
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||||
) -> Sequence[PromptTargetMatch]:
|
tokenizer: AnyTokenizer,
|
||||||
"""Return each target of `prompt_updates` found in `prompt`."""
|
*,
|
||||||
|
prev_end_idx: int = 0,
|
||||||
|
current_result: "MultiModalPromptUpdatesApplyResult",
|
||||||
|
) -> tuple[Optional[UpdateMode], list[_MatchToApply]]:
|
||||||
|
mode: Optional[UpdateMode] = None
|
||||||
|
mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]()
|
||||||
|
|
||||||
def get_matches(update: BoundPromptUpdate):
|
for modality, modality_updates in mm_prompt_updates.items():
|
||||||
target = update.target
|
for item_idx, item_updates in enumerate(modality_updates):
|
||||||
|
if current_result[modality][item_idx] is not None:
|
||||||
|
continue # Updates have already been applied for this item
|
||||||
|
|
||||||
if isinstance(target, PromptIndex):
|
for update_idx, update in enumerate(item_updates):
|
||||||
match_idx = target.get_match_index(update.tokenizer, prompt)
|
if (modality, item_idx) in mm_matches:
|
||||||
if match_idx is None:
|
break # Already found a match for this item
|
||||||
return []
|
|
||||||
|
|
||||||
return [_PromptTargetIndexMatch(update, match_idx)]
|
for match in update.iter_matches(
|
||||||
|
prompt,
|
||||||
|
tokenizer,
|
||||||
|
start_idx=prev_end_idx,
|
||||||
|
):
|
||||||
|
# All matches should share the same mode
|
||||||
|
if mode is None:
|
||||||
|
mode = update.mode
|
||||||
|
elif mode != update.mode:
|
||||||
|
continue
|
||||||
|
|
||||||
return [
|
mm_matches[(modality, item_idx)] = match, update_idx
|
||||||
_PromptTargetTextMatch(update, match)
|
break # Get only the first valid match per item
|
||||||
for match in re.finditer(re.escape(target.text), prompt)
|
|
||||||
]
|
|
||||||
|
|
||||||
return [
|
# Prioritize earlier matches
|
||||||
match for update in prompt_updates for match in get_matches(update)
|
matches_to_apply = sorted(mm_matches.items(), key=lambda item: item[1][0])
|
||||||
]
|
|
||||||
|
|
||||||
|
# To avoid conflicts, only replace one non-empty item at a time
|
||||||
|
if mode == UpdateMode.REPLACE:
|
||||||
|
matches_to_apply_ = list[_MatchToApply]()
|
||||||
|
has_non_empty_matches = False
|
||||||
|
|
||||||
def _resolve_matches(
|
for item in matches_to_apply:
|
||||||
prompt: PromptSeq,
|
_, (match, _) = item
|
||||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
if match.start_idx == match.end_idx:
|
||||||
) -> list[PromptTargetMatch]:
|
matches_to_apply_.append(item)
|
||||||
"""
|
elif not has_non_empty_matches:
|
||||||
Resolve `mm_matches` to ensure that there are no overlapping matches,
|
has_non_empty_matches = True
|
||||||
and sort them such that earlier matches take priority over later ones.
|
matches_to_apply_.append(item)
|
||||||
"""
|
|
||||||
matches = [m for matches in mm_matches.values() for m in matches]
|
|
||||||
|
|
||||||
seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)
|
matches_to_apply = matches_to_apply_
|
||||||
|
|
||||||
for match in matches:
|
return mode, matches_to_apply
|
||||||
for idx in range(match.start_idx, match.end_idx):
|
|
||||||
if seen_matches[idx] is not None:
|
|
||||||
raise ValueError("Found overlapping matches "
|
|
||||||
f"({seen_matches[idx]} and {match}) "
|
|
||||||
f"at index={idx} of prompt={prompt}")
|
|
||||||
|
|
||||||
seen_matches[idx] = match
|
|
||||||
|
|
||||||
return sorted(matches, key=lambda x: x.start_idx)
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_matches(
|
def _apply_matches(
|
||||||
prompt: _S,
|
prompt: _S,
|
||||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||||
mm_item_counts: Mapping[str, int],
|
tokenizer: AnyTokenizer,
|
||||||
) -> list[_S]:
|
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
|
||||||
"""Apply the updates in `mm_matches` to `prompt`."""
|
prompt_len = len(prompt)
|
||||||
|
|
||||||
out_seqs = list[Union[str, list[int]]]()
|
out_seqs = list[Union[str, list[int]]]()
|
||||||
prev_end_idx = 0
|
out_result: MultiModalPromptUpdatesApplyResult = {
|
||||||
next_idx_by_modality = defaultdict[str, int](lambda: 0)
|
m: [None] * len(items)
|
||||||
|
for m, items in mm_prompt_updates.items()
|
||||||
|
}
|
||||||
|
|
||||||
for match in _resolve_matches(prompt, mm_matches):
|
start_idx = prev_end_idx = 0
|
||||||
modality = match.modality
|
while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
|
||||||
|
found = False
|
||||||
|
|
||||||
item_start_idx = next_idx_by_modality[modality]
|
mode, matches_to_apply = _find_matches(
|
||||||
max_item_count = mm_item_counts.get(modality, 0)
|
prompt,
|
||||||
if item_start_idx >= max_item_count:
|
mm_prompt_updates,
|
||||||
continue
|
tokenizer,
|
||||||
|
prev_end_idx=prev_end_idx,
|
||||||
|
current_result=out_result,
|
||||||
|
)
|
||||||
|
|
||||||
start_idx = match.start_idx
|
if mode is not None:
|
||||||
end_idx = match.end_idx
|
for (modality, item_idx), (match, update_idx) in matches_to_apply:
|
||||||
origin = match._origin
|
found = True
|
||||||
mode = origin.mode
|
|
||||||
|
matched_update = mm_prompt_updates[modality][item_idx][
|
||||||
|
update_idx]
|
||||||
|
matched_content = matched_update.content
|
||||||
|
|
||||||
if mode == UpdateMode.INSERT:
|
if mode == UpdateMode.INSERT:
|
||||||
out_seqs.append(prompt[prev_end_idx:end_idx])
|
end_idx_to_insert = match.end_idx
|
||||||
num_inserts = max_item_count
|
|
||||||
elif mode == UpdateMode.REPLACE:
|
elif mode == UpdateMode.REPLACE:
|
||||||
out_seqs.append(prompt[prev_end_idx:start_idx])
|
end_idx_to_insert = match.start_idx
|
||||||
num_inserts = max_item_count if start_idx == end_idx else 1
|
|
||||||
else:
|
else:
|
||||||
assert_never(mode)
|
assert_never(mode)
|
||||||
|
|
||||||
item_end_idx = min(item_start_idx + num_inserts, max_item_count)
|
out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
|
||||||
|
out_seqs.append(matched_content.full.text if isinstance(
|
||||||
|
prompt, str) else matched_content.full.token_ids)
|
||||||
|
out_result[modality][item_idx] = update_idx
|
||||||
|
|
||||||
for item_idx in range(item_start_idx, item_end_idx):
|
# Exclude overlapping matches
|
||||||
content = origin.get_content(item_idx)
|
start_idx = prev_end_idx = match.end_idx
|
||||||
insert_seq = (content.full.text if isinstance(prompt, str) else
|
|
||||||
content.full.token_ids)
|
|
||||||
|
|
||||||
out_seqs.append(insert_seq)
|
if not found:
|
||||||
|
start_idx += 1
|
||||||
prev_end_idx = end_idx
|
|
||||||
next_idx_by_modality[modality] += item_end_idx - item_start_idx
|
|
||||||
|
|
||||||
out_seqs.append(prompt[prev_end_idx:])
|
out_seqs.append(prompt[prev_end_idx:])
|
||||||
|
|
||||||
return cast(list[_S], out_seqs)
|
return cast(list[_S], out_seqs), out_result
|
||||||
|
|
||||||
|
|
||||||
def apply_token_matches(
|
def apply_token_matches(
|
||||||
prompt: list[int],
|
prompt: list[int],
|
||||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||||
mm_item_counts: Mapping[str, int],
|
tokenizer: AnyTokenizer,
|
||||||
) -> list[int]:
|
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
|
||||||
"""Apply the updates in `mm_matches` to `prompt`."""
|
"""
|
||||||
if not mm_matches:
|
Apply the updates in `mm_prompt_updates` to `prompt`.
|
||||||
return prompt
|
|
||||||
|
|
||||||
token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts)
|
Matches are exclusive even when multiple modalities share
|
||||||
|
the same placeholder tokens. In that case, the modality that
|
||||||
|
appears earlier in `mm_prompt_updates` takes priority.
|
||||||
|
"""
|
||||||
|
token_id_seqs, result = _apply_matches(prompt, mm_prompt_updates,
|
||||||
|
tokenizer)
|
||||||
|
|
||||||
return flatten_2d_lists(token_id_seqs)
|
return flatten_2d_lists(token_id_seqs), result
|
||||||
|
|
||||||
|
|
||||||
def apply_text_matches(
|
def apply_text_matches(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||||
mm_item_counts: Mapping[str, int],
|
tokenizer: AnyTokenizer,
|
||||||
) -> str:
|
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
|
||||||
"""Apply the updates in `mm_matches` to `prompt`."""
|
"""
|
||||||
if not mm_matches:
|
Apply the updates in `mm_prompt_updates` to `prompt`.
|
||||||
return prompt
|
|
||||||
|
|
||||||
texts = _apply_matches(prompt, mm_matches, mm_item_counts)
|
Matches are exclusive even when multiple modalities share
|
||||||
|
the same placeholder tokens. In that case, the modality that
|
||||||
|
appears earlier in `mm_prompt_updates` takes priority.
|
||||||
|
"""
|
||||||
|
texts, result = _apply_matches(prompt, mm_prompt_updates, tokenizer)
|
||||||
|
|
||||||
return "".join(texts)
|
return "".join(texts), result
|
||||||
|
|
||||||
|
|
||||||
def _iter_placeholders(
|
def _iter_placeholders(
|
||||||
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
|
|
||||||
prompt: list[int],
|
prompt: list[int],
|
||||||
mm_item_counts: Mapping[str, int],
|
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||||
) -> Iterable[PlaceholderFeaturesInfo]:
|
) -> Iterable[PlaceholderFeaturesInfo]:
|
||||||
"""
|
"""
|
||||||
Yield each set of placeholder tokens found in `prompt`.
|
Yield each set of placeholder tokens found in `prompt`.
|
||||||
@@ -833,6 +853,8 @@ def _iter_placeholders(
     Note that empty matches are ignored.
     """
     prompt_len = len(prompt)
+    mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
+
     item_idx_by_modality = defaultdict[str, int](lambda: 0)
 
     start_idx = 0
@@ -844,8 +866,8 @@ def _iter_placeholders(
         if item_idx >= mm_item_counts.get(modality, 0):
             continue
 
-        for update_info in modality_updates:
-            content = update_info.get_content(item_idx)
+        for update in modality_updates[item_idx]:
+            content = update.content
             content_tokens_full = content.full.token_ids
             content_len_full = len(content_tokens_full)
             end_idx_full = start_idx + content_len_full
@@ -880,11 +902,10 @@ def _iter_placeholders(
 
 
 def find_mm_placeholders(
-    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
     prompt: list[int],
-    mm_item_counts: Mapping[str, int],
+    mm_prompt_updates: MultiModalPromptUpdates,
 ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
-    it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts)
+    it = _iter_placeholders(prompt, mm_prompt_updates)
     return dict(full_groupby_modality(it))
 
 
@@ -989,12 +1010,20 @@ A collection of hashes with a similar structure as
 [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
 """
 
-MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]]
+MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]]
 """
 A collection of prompt updates with a similar structure as
 [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
 """
 
+MultiModalPromptUpdatesApplyResult = Mapping[str, list[Optional[int]]]
+"""
+For an item `MultiModalPromptUpdates[k][i]`,
+`MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the
+`ResolvedPromptUpdate` instance that has been applied, or `None` if none of the
+`ResolvedPromptUpdate` instances have been applied.
+"""
+
 
 class MultiModalProcessingInfo(NamedTuple):
     kwargs: MultiModalKwargsItems
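To make the new `MultiModalPromptUpdatesApplyResult` type concrete, a hypothetical value (not taken from this diff) could look like the following; each entry mirrors `MultiModalPromptUpdates[modality][item_idx]` and records which of that item's updates was applied:

```python
# Hypothetical apply result for two image items and one audio item:
# image item 0 was handled by its update at index 0, image item 1 was not
# updated at all, and the single audio item was handled by its update 0.
apply_result = {
    "image": [0, None],
    "audio": [0],
}
```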
@@ -1126,14 +1155,60 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         """
         raise NotImplementedError
 
+    def _bind_and_group_updates(
+        self,
+        prompt_updates: Sequence[PromptUpdate],
+        mm_item_counts: Mapping[str, int],
+    ) -> MultiModalPromptUpdates:
+        tokenizer = self.info.get_tokenizer()
+
+        return {
+            modality:
+            [[update.resolve(tokenizer, item_idx) for update in updates]
+             for item_idx in range(mm_item_counts.get(modality, 0))]
+            for modality, updates in full_groupby_modality(prompt_updates)
+        }
+
+    def _get_mm_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargsItems,
+    ) -> MultiModalPromptUpdates:
+        unbound_prompt_updates = self._get_prompt_updates(
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            out_mm_kwargs=out_mm_kwargs,
+        )
+
+        mm_prompt_updates = self._bind_and_group_updates(
+            unbound_prompt_updates,
+            mm_items.get_all_counts(),
+        )
+
+        for modality, prompt_updates in mm_prompt_updates.items():
+            for item_idx, item_prompt_updates in enumerate(prompt_updates):
+                if len(item_prompt_updates) > 1:
+                    logger.warning_once(
+                        "Detected %d prompt updates for `mm_items[%r][%s]`. "
+                        "Multiple prompt updates per item is now "
+                        "deprecated and may be removed in v0.13. "
+                        "Instead, please specify dynamic update targets "
+                        "in the same prompt update definition by passing "
+                        "a function to `PromptUpdate.target`.",
+                        len(prompt_updates),
+                        modality,
+                        item_idx,
+                    )
+
+        return mm_prompt_updates
+
     def _find_mm_placeholders(
         self,
-        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
         new_token_ids: list[int],
-        mm_item_counts: Mapping[str, int],
+        mm_prompt_updates: MultiModalPromptUpdates,
     ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
-        return find_mm_placeholders(mm_prompt_updates, new_token_ids,
-                                    mm_item_counts)
+        return find_mm_placeholders(new_token_ids, mm_prompt_updates)
 
     def _get_hf_mm_data(
         self,
@@ -1421,13 +1496,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
                                         tokenization_kwargs)
 
-        unbound_prompt_updates = self._get_prompt_updates(
+        mm_prompt_updates = self._get_mm_prompt_updates(
             mm_data_items,
             hf_processor_mm_kwargs,
             mm_kwargs,
         )
-        mm_prompt_updates = self._bind_and_group_updates(
-            unbound_prompt_updates)
 
         mm_info = MultiModalProcessingInfo(
             kwargs=mm_kwargs,
@@ -1497,13 +1570,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_missing_kwargs=mm_missing_kwargs,
         )
 
-        unbound_prompt_updates = self._get_prompt_updates(
+        mm_prompt_updates = self._get_mm_prompt_updates(
             mm_data_items,
             hf_processor_mm_kwargs,
             mm_kwargs,
         )
-        mm_prompt_updates = self._bind_and_group_updates(
-            unbound_prompt_updates)
 
         mm_info = MultiModalProcessingInfo(
             kwargs=mm_kwargs,
@@ -1513,47 +1584,33 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 
         return prompt_ids, mm_info, is_update_applied
 
-    def _bind_and_group_updates(
-        self,
-        prompt_updates: Sequence[PromptUpdate],
-    ) -> dict[str, Sequence[BoundPromptUpdate]]:
-        tokenizer = self.info.get_tokenizer()
-
-        it = (update.bind(tokenizer) for update in prompt_updates)
-        return dict(full_groupby_modality(it))
-
     def _apply_token_matches(
         self,
         prompt: list[int],
-        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
-        mm_item_counts: Mapping[str, int],
-    ) -> list[int]:
-        return apply_token_matches(prompt, mm_matches, mm_item_counts)
+        mm_prompt_updates: MultiModalPromptUpdates,
+    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
+        tokenizer = self.info.get_tokenizer()
+        return apply_token_matches(prompt, mm_prompt_updates, tokenizer)
 
     def _apply_text_matches(
         self,
         prompt: str,
-        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
-        mm_item_counts: Mapping[str, int],
-    ) -> str:
-        return apply_text_matches(prompt, mm_matches, mm_item_counts)
+        mm_prompt_updates: MultiModalPromptUpdates,
+    ) -> tuple[str, MultiModalPromptUpdatesApplyResult]:
+        tokenizer = self.info.get_tokenizer()
+        return apply_text_matches(prompt, mm_prompt_updates, tokenizer)
 
     def _apply_prompt_updates(
         self,
         token_ids: list[int],
-        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
-        mm_item_counts: Mapping[str, int],
+        mm_prompt_updates: MultiModalPromptUpdates,
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         tokenizer = self.info.get_tokenizer()
 
-        mm_token_matches = {
-            modality: find_token_matches(token_ids, updates)
-            for modality, updates in mm_prompt_updates.items()
-        }
-        mm_match_counts = {
-            modality: len(matches)
-            for modality, matches in mm_token_matches.items()
-        }
+        new_token_ids, match_result = self._apply_token_matches(
+            token_ids,
+            mm_prompt_updates,
+        )
 
         # If the search text does not represent a special token,
         # it may have different token IDs in the prompt, because
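With this change, `_apply_token_matches` and `_apply_text_matches` report, per modality and per item, which candidate update (if any) was applied, instead of the caller comparing match counts against `mm_item_counts`. A small sketch of how that result can be interpreted; the data is made up and `apply_token_matches` itself is not re-implemented here:

from typing import Optional

def all_items_matched(match_result: dict[str, list[Optional[int]]]) -> bool:
    # True only if every item of every modality had some update applied;
    # this mirrors the `if all(...)` check in the next hunk.
    return all(
        all(update_idx is not None for update_idx in update_idxs)
        for update_idxs in match_result.values()
    )

print(all_items_matched({"image": [0, 0], "audio": [1]}))  # True
print(all_items_matched({"image": [0, None]}))             # False -> fall back to text matching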
@@ -1566,48 +1623,38 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         # of the search text in the prompt, we instead perform string-based
         # updates on the decoded token IDs, then encode them back.
         if all(
-            mm_match_counts.get(modality, 0) >= item_count
-            for modality, item_count in mm_item_counts.items()
-        ):  # yapf: disable
-            token_ids = self._apply_token_matches(
-                token_ids,
-                mm_token_matches,
-                mm_item_counts,
-            )
-
-            text = decode_tokens(tokenizer, token_ids)
-            matched_updates = {
-                modality: [match._origin for match in token_matches]
-                for modality, token_matches in mm_token_matches.items()
-            }
+                all(update_idx is not None for update_idx in update_idxs)
+                for update_idxs in match_result.values()):
+            new_text = decode_tokens(tokenizer, new_token_ids)
         else:
-            text = decode_tokens(tokenizer, token_ids)
-
-            mm_text_matches = {
-                modality: find_text_matches(text, updates)
-                for modality, updates in mm_prompt_updates.items()
-            }
-            text = self._apply_text_matches(
-                text,
-                mm_text_matches,
-                mm_item_counts,
+            new_text, match_result = self._apply_text_matches(
+                decode_tokens(tokenizer, token_ids),
+                mm_prompt_updates,
             )
 
-            token_ids = encode_tokens(tokenizer,
-                                      text,
-                                      add_special_tokens=False)
-            matched_updates = {
-                modality: [match._origin for match in token_matches]
-                for modality, token_matches in mm_text_matches.items()
-            }
+            new_token_ids = encode_tokens(
+                tokenizer,
+                new_text,
+                add_special_tokens=False,
+            )
+
+        matched_updates = defaultdict[
+            str, list[Sequence[ResolvedPromptUpdate]]](list)
+        for modality, update_idxs in match_result.items():
+            for item_idx, update_idx in enumerate(update_idxs):
+                assert update_idx is not None, (
+                    "Failed to apply prompt replacement for "
+                    f"mm_items[{modality!r}][{item_idx}]")
+
+                matched_updates[modality].append(
+                    [mm_prompt_updates[modality][item_idx][update_idx]])
 
         placeholders = self._find_mm_placeholders(
-            matched_updates,
-            token_ids,
-            mm_item_counts,
+            new_token_ids,
+            dict(matched_updates),
         )
 
-        return token_ids, text, placeholders
+        return new_token_ids, new_text, placeholders
 
     def _validate_mm_kwargs(
         self,
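The `matched_updates` rebuild above selects, for each item, the single update that actually applied, so the placeholder search only has to consider those. The same loop over plain data, with strings standing in for `ResolvedPromptUpdate` objects:

from collections import defaultdict

# Toy data: two image items, each with its own candidate updates.
mm_prompt_updates = {"image": [["img-update-a", "img-update-b"], ["img-update-a"]]}
match_result = {"image": [1, 0]}  # per item: index of the update that applied

matched_updates = defaultdict(list)
for modality, update_idxs in match_result.items():
    for item_idx, update_idx in enumerate(update_idxs):
        assert update_idx is not None, (
            f"Failed to apply prompt replacement for mm_items[{modality!r}][{item_idx}]")
        matched_updates[modality].append(
            [mm_prompt_updates[modality][item_idx][update_idx]])

print(dict(matched_updates))  # {'image': [['img-update-b'], ['img-update-a']]}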
@@ -1661,9 +1708,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 
         if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
-                mm_prompt_updates,
                 prompt_ids,
-                mm_item_counts,
+                mm_prompt_updates,
             )
             self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
 
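`mm_item_counts` is no longer threaded through `_find_mm_placeholders` and `_apply_prompt_updates` because the per-item nesting of the updates mapping already encodes it; a one-line sketch with a plain dict standing in for the real type:

mm_prompt_updates = {"image": [["u0"], ["u1"], ["u2"]], "audio": [["a0"]]}
mm_item_counts = {m: len(per_item) for m, per_item in mm_prompt_updates.items()}
print(mm_item_counts)  # {'image': 3, 'audio': 1}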
@@ -1677,7 +1723,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             ) = self._apply_prompt_updates(
                 prompt_ids,
                 mm_prompt_updates,
-                mm_item_counts,
             )
             self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
 