[Bugfix] Fix O(n²) multimodal string prompt processing (#29667)

Signed-off-by: mertunsall <mertunsal1905@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Mert Unsal 2025-11-28 16:10:39 -08:00 committed by GitHub
parent 6173682b6e
commit c625d7b1c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 65 additions and 33 deletions

View File

@@ -15,6 +15,7 @@ from vllm.multimodal.processing import (
PromptIndexTargets,
PromptInsertion,
PromptReplacement,
_apply_matches,
apply_text_matches,
apply_token_matches,
find_mm_placeholders,
@@ -1075,3 +1076,38 @@ def test_hf_processor_call_kwargs(
result = ctx.call_hf_processor(processor, {}, inference_kwargs)
assert result == expected_kwargs
def test_apply_matches_no_match_exits_quickly():
    """
    Test that _apply_matches exits quickly when no matches are found.

    Previously, _apply_matches had O(n²) behavior when no match was found,
    because it would increment start_idx by 1 each iteration while
    re-scanning the entire prompt from prev_end_idx=0.

    With the fix, it should exit immediately when no match is found.
    """
    import time

    # A bare object() stands in for the tokenizer: with no match found,
    # _apply_matches should never actually call tokenizer methods.
    mock_tokenizer = cast(AnyTokenizer, object())

    # Create a long prompt with no placeholder
    long_prompt = "x" * 10000

    # Create update looking for a placeholder that doesn't exist
    mm_prompt_updates = {
        "image": [[PromptReplacement("image", "<image>", "REPLACED").resolve(0)]]
    }

    start = time.perf_counter()
    result, _ = _apply_matches(
        long_prompt,
        mm_prompt_updates,
        mock_tokenizer,
    )
    elapsed = time.perf_counter() - start

    # Should complete in < 100ms (was taking seconds before the fix)
    assert elapsed < 0.1, f"_apply_matches took {elapsed:.2f}s, expected < 0.1s"
    # With no placeholder matched, the prompt must come back unmodified.
    assert "".join(result) == long_prompt

View File

@@ -742,7 +742,6 @@ def _apply_matches(
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
prompt_len = len(prompt)
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
out_seqs = list[str | list[int]]()
@@ -750,16 +749,15 @@ def _apply_matches(
m: [None] * len(items) for m, items in mm_prompt_updates.items()
}
# Early exit if no items to find
mm_found_counts = {
m: sum(r is not None for r in res) for m, res in out_result.items()
}
if _all_items_found(mm_item_counts, mm_found_counts):
return [prompt], out_result
start_idx = prev_end_idx = 0
while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
found = False
prev_end_idx = 0
while True:
mode, matches_to_apply = _find_matches(
prompt,
mm_prompt_updates,
@@ -768,39 +766,37 @@ def _apply_matches(
current_result=out_result,
)
if mode is not None:
for (modality, item_idx), (match, update_idx) in matches_to_apply:
found = True
if mode is None:
break # No more matches to find
matched_update = mm_prompt_updates[modality][item_idx][update_idx]
matched_content = matched_update.content.full
for (modality, item_idx), (match, update_idx) in matches_to_apply:
matched_update = mm_prompt_updates[modality][item_idx][update_idx]
matched_content = matched_update.content.full
if mode == UpdateMode.INSERT:
end_idx_to_insert = match.end_idx
elif mode == UpdateMode.REPLACE:
end_idx_to_insert = match.start_idx
else:
assert_never(mode)
if mode == UpdateMode.INSERT:
end_idx_to_insert = match.end_idx
elif mode == UpdateMode.REPLACE:
end_idx_to_insert = match.start_idx
else:
assert_never(mode)
out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
out_seqs.append(
_seq2text(tokenizer, matched_content)
if isinstance(prompt, str)
else _seq2tokens(tokenizer, matched_content)
)
out_result[modality][item_idx] = update_idx
out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
out_seqs.append(
_seq2text(tokenizer, matched_content)
if isinstance(prompt, str)
else _seq2tokens(tokenizer, matched_content)
)
out_result[modality][item_idx] = update_idx
# Exclude overlapping matches
start_idx = prev_end_idx = match.end_idx
# Exclude overlapping matches
prev_end_idx = match.end_idx
mm_found_counts = {
m: sum(r is not None for r in res) for m, res in out_result.items()
}
if _all_items_found(mm_item_counts, mm_found_counts):
break
if not found:
start_idx += 1
# Early exit if all items found
mm_found_counts = {
m: sum(r is not None for r in res) for m, res in out_result.items()
}
if _all_items_found(mm_item_counts, mm_found_counts):
break
out_seqs.append(prompt[prev_end_idx:])