mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 14:04:25 +08:00
[Bugfix] Fix O(n²) multimodal string prompt processing (#29667)
Signed-off-by: mertunsall <mertunsal1905@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
6173682b6e
commit
c625d7b1c6
@ -15,6 +15,7 @@ from vllm.multimodal.processing import (
|
|||||||
PromptIndexTargets,
|
PromptIndexTargets,
|
||||||
PromptInsertion,
|
PromptInsertion,
|
||||||
PromptReplacement,
|
PromptReplacement,
|
||||||
|
_apply_matches,
|
||||||
apply_text_matches,
|
apply_text_matches,
|
||||||
apply_token_matches,
|
apply_token_matches,
|
||||||
find_mm_placeholders,
|
find_mm_placeholders,
|
||||||
@ -1075,3 +1076,38 @@ def test_hf_processor_call_kwargs(
|
|||||||
|
|
||||||
result = ctx.call_hf_processor(processor, {}, inference_kwargs)
|
result = ctx.call_hf_processor(processor, {}, inference_kwargs)
|
||||||
assert result == expected_kwargs
|
assert result == expected_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_matches_no_match_exits_quickly():
|
||||||
|
"""
|
||||||
|
Test that _apply_matches exits quickly when no matches are found.
|
||||||
|
|
||||||
|
Previously, _apply_matches had O(n²) behavior when no match was found
|
||||||
|
because it would increment start_idx by 1 each iteration while
|
||||||
|
re-scanning the entire prompt from prev_end_idx=0.
|
||||||
|
|
||||||
|
With the fix, it should exit immediately when no match is found.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
mock_tokenizer = cast(AnyTokenizer, object())
|
||||||
|
|
||||||
|
# Create a long prompt with no placeholder
|
||||||
|
long_prompt = "x" * 10000
|
||||||
|
|
||||||
|
# Create update looking for a placeholder that doesn't exist
|
||||||
|
mm_prompt_updates = {
|
||||||
|
"image": [[PromptReplacement("image", "<image>", "REPLACED").resolve(0)]]
|
||||||
|
}
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
result, _ = _apply_matches(
|
||||||
|
long_prompt,
|
||||||
|
mm_prompt_updates,
|
||||||
|
mock_tokenizer,
|
||||||
|
)
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
|
||||||
|
# Should complete in < 100ms (was taking seconds before the fix)
|
||||||
|
assert elapsed < 0.1, f"_apply_matches took {elapsed:.2f}s, expected < 0.1s"
|
||||||
|
assert "".join(result) == long_prompt
|
||||||
|
|||||||
@ -742,7 +742,6 @@ def _apply_matches(
|
|||||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
|
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
|
||||||
prompt_len = len(prompt)
|
|
||||||
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
|
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
|
||||||
|
|
||||||
out_seqs = list[str | list[int]]()
|
out_seqs = list[str | list[int]]()
|
||||||
@ -750,16 +749,15 @@ def _apply_matches(
|
|||||||
m: [None] * len(items) for m, items in mm_prompt_updates.items()
|
m: [None] * len(items) for m, items in mm_prompt_updates.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Early exit if no items to find
|
||||||
mm_found_counts = {
|
mm_found_counts = {
|
||||||
m: sum(r is not None for r in res) for m, res in out_result.items()
|
m: sum(r is not None for r in res) for m, res in out_result.items()
|
||||||
}
|
}
|
||||||
if _all_items_found(mm_item_counts, mm_found_counts):
|
if _all_items_found(mm_item_counts, mm_found_counts):
|
||||||
return [prompt], out_result
|
return [prompt], out_result
|
||||||
|
|
||||||
start_idx = prev_end_idx = 0
|
prev_end_idx = 0
|
||||||
while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
|
while True:
|
||||||
found = False
|
|
||||||
|
|
||||||
mode, matches_to_apply = _find_matches(
|
mode, matches_to_apply = _find_matches(
|
||||||
prompt,
|
prompt,
|
||||||
mm_prompt_updates,
|
mm_prompt_updates,
|
||||||
@ -768,39 +766,37 @@ def _apply_matches(
|
|||||||
current_result=out_result,
|
current_result=out_result,
|
||||||
)
|
)
|
||||||
|
|
||||||
if mode is not None:
|
if mode is None:
|
||||||
for (modality, item_idx), (match, update_idx) in matches_to_apply:
|
break # No more matches to find
|
||||||
found = True
|
|
||||||
|
|
||||||
matched_update = mm_prompt_updates[modality][item_idx][update_idx]
|
for (modality, item_idx), (match, update_idx) in matches_to_apply:
|
||||||
matched_content = matched_update.content.full
|
matched_update = mm_prompt_updates[modality][item_idx][update_idx]
|
||||||
|
matched_content = matched_update.content.full
|
||||||
|
|
||||||
if mode == UpdateMode.INSERT:
|
if mode == UpdateMode.INSERT:
|
||||||
end_idx_to_insert = match.end_idx
|
end_idx_to_insert = match.end_idx
|
||||||
elif mode == UpdateMode.REPLACE:
|
elif mode == UpdateMode.REPLACE:
|
||||||
end_idx_to_insert = match.start_idx
|
end_idx_to_insert = match.start_idx
|
||||||
else:
|
else:
|
||||||
assert_never(mode)
|
assert_never(mode)
|
||||||
|
|
||||||
out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
|
out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
|
||||||
out_seqs.append(
|
out_seqs.append(
|
||||||
_seq2text(tokenizer, matched_content)
|
_seq2text(tokenizer, matched_content)
|
||||||
if isinstance(prompt, str)
|
if isinstance(prompt, str)
|
||||||
else _seq2tokens(tokenizer, matched_content)
|
else _seq2tokens(tokenizer, matched_content)
|
||||||
)
|
)
|
||||||
out_result[modality][item_idx] = update_idx
|
out_result[modality][item_idx] = update_idx
|
||||||
|
|
||||||
# Exclude overlapping matches
|
# Exclude overlapping matches
|
||||||
start_idx = prev_end_idx = match.end_idx
|
prev_end_idx = match.end_idx
|
||||||
|
|
||||||
mm_found_counts = {
|
# Early exit if all items found
|
||||||
m: sum(r is not None for r in res) for m, res in out_result.items()
|
mm_found_counts = {
|
||||||
}
|
m: sum(r is not None for r in res) for m, res in out_result.items()
|
||||||
if _all_items_found(mm_item_counts, mm_found_counts):
|
}
|
||||||
break
|
if _all_items_found(mm_item_counts, mm_found_counts):
|
||||||
|
break
|
||||||
if not found:
|
|
||||||
start_idx += 1
|
|
||||||
|
|
||||||
out_seqs.append(prompt[prev_end_idx:])
|
out_seqs.append(prompt[prev_end_idx:])
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user