mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 13:14:34 +08:00
[Bugfix] Fix O(n²) multimodal string prompt processing (#29667)
Signed-off-by: mertunsall <mertunsal1905@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
6173682b6e
commit
c625d7b1c6
@ -15,6 +15,7 @@ from vllm.multimodal.processing import (
|
||||
PromptIndexTargets,
|
||||
PromptInsertion,
|
||||
PromptReplacement,
|
||||
_apply_matches,
|
||||
apply_text_matches,
|
||||
apply_token_matches,
|
||||
find_mm_placeholders,
|
||||
@ -1075,3 +1076,38 @@ def test_hf_processor_call_kwargs(
|
||||
|
||||
result = ctx.call_hf_processor(processor, {}, inference_kwargs)
|
||||
assert result == expected_kwargs
|
||||
|
||||
|
||||
def test_apply_matches_no_match_exits_quickly():
|
||||
"""
|
||||
Test that _apply_matches exits quickly when no matches are found.
|
||||
|
||||
Previously, _apply_matches had O(n²) behavior when no match was found
|
||||
because it would increment start_idx by 1 each iteration while
|
||||
re-scanning the entire prompt from prev_end_idx=0.
|
||||
|
||||
With the fix, it should exit immediately when no match is found.
|
||||
"""
|
||||
import time
|
||||
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
# Create a long prompt with no placeholder
|
||||
long_prompt = "x" * 10000
|
||||
|
||||
# Create update looking for a placeholder that doesn't exist
|
||||
mm_prompt_updates = {
|
||||
"image": [[PromptReplacement("image", "<image>", "REPLACED").resolve(0)]]
|
||||
}
|
||||
|
||||
start = time.perf_counter()
|
||||
result, _ = _apply_matches(
|
||||
long_prompt,
|
||||
mm_prompt_updates,
|
||||
mock_tokenizer,
|
||||
)
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
# Should complete in < 100ms (was taking seconds before the fix)
|
||||
assert elapsed < 0.1, f"_apply_matches took {elapsed:.2f}s, expected < 0.1s"
|
||||
assert "".join(result) == long_prompt
|
||||
|
||||
@ -742,7 +742,6 @@ def _apply_matches(
|
||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
|
||||
prompt_len = len(prompt)
|
||||
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
|
||||
|
||||
out_seqs = list[str | list[int]]()
|
||||
@ -750,16 +749,15 @@ def _apply_matches(
|
||||
m: [None] * len(items) for m, items in mm_prompt_updates.items()
|
||||
}
|
||||
|
||||
# Early exit if no items to find
|
||||
mm_found_counts = {
|
||||
m: sum(r is not None for r in res) for m, res in out_result.items()
|
||||
}
|
||||
if _all_items_found(mm_item_counts, mm_found_counts):
|
||||
return [prompt], out_result
|
||||
|
||||
start_idx = prev_end_idx = 0
|
||||
while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
|
||||
found = False
|
||||
|
||||
prev_end_idx = 0
|
||||
while True:
|
||||
mode, matches_to_apply = _find_matches(
|
||||
prompt,
|
||||
mm_prompt_updates,
|
||||
@ -768,39 +766,37 @@ def _apply_matches(
|
||||
current_result=out_result,
|
||||
)
|
||||
|
||||
if mode is not None:
|
||||
for (modality, item_idx), (match, update_idx) in matches_to_apply:
|
||||
found = True
|
||||
if mode is None:
|
||||
break # No more matches to find
|
||||
|
||||
matched_update = mm_prompt_updates[modality][item_idx][update_idx]
|
||||
matched_content = matched_update.content.full
|
||||
for (modality, item_idx), (match, update_idx) in matches_to_apply:
|
||||
matched_update = mm_prompt_updates[modality][item_idx][update_idx]
|
||||
matched_content = matched_update.content.full
|
||||
|
||||
if mode == UpdateMode.INSERT:
|
||||
end_idx_to_insert = match.end_idx
|
||||
elif mode == UpdateMode.REPLACE:
|
||||
end_idx_to_insert = match.start_idx
|
||||
else:
|
||||
assert_never(mode)
|
||||
if mode == UpdateMode.INSERT:
|
||||
end_idx_to_insert = match.end_idx
|
||||
elif mode == UpdateMode.REPLACE:
|
||||
end_idx_to_insert = match.start_idx
|
||||
else:
|
||||
assert_never(mode)
|
||||
|
||||
out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
|
||||
out_seqs.append(
|
||||
_seq2text(tokenizer, matched_content)
|
||||
if isinstance(prompt, str)
|
||||
else _seq2tokens(tokenizer, matched_content)
|
||||
)
|
||||
out_result[modality][item_idx] = update_idx
|
||||
out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
|
||||
out_seqs.append(
|
||||
_seq2text(tokenizer, matched_content)
|
||||
if isinstance(prompt, str)
|
||||
else _seq2tokens(tokenizer, matched_content)
|
||||
)
|
||||
out_result[modality][item_idx] = update_idx
|
||||
|
||||
# Exclude overlapping matches
|
||||
start_idx = prev_end_idx = match.end_idx
|
||||
# Exclude overlapping matches
|
||||
prev_end_idx = match.end_idx
|
||||
|
||||
mm_found_counts = {
|
||||
m: sum(r is not None for r in res) for m, res in out_result.items()
|
||||
}
|
||||
if _all_items_found(mm_item_counts, mm_found_counts):
|
||||
break
|
||||
|
||||
if not found:
|
||||
start_idx += 1
|
||||
# Early exit if all items found
|
||||
mm_found_counts = {
|
||||
m: sum(r is not None for r in res) for m, res in out_result.items()
|
||||
}
|
||||
if _all_items_found(mm_item_counts, mm_found_counts):
|
||||
break
|
||||
|
||||
out_seqs.append(prompt[prev_end_idx:])
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user