mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 16:14:37 +08:00
[Optimization] Early return for _apply_matches and _iter_placeholders (#29668)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
8e7a891602
commit
1168768a2d
@ -727,18 +727,35 @@ def _find_matches(
|
||||
return mode, matches_to_apply
|
||||
|
||||
|
||||
def _all_items_found(
|
||||
mm_item_counts: dict[str, int],
|
||||
mm_found_counts: dict[str, int],
|
||||
) -> bool:
|
||||
return all(
|
||||
item_idx >= mm_item_counts[modality]
|
||||
for modality, item_idx in mm_found_counts.items()
|
||||
)
|
||||
|
||||
|
||||
def _apply_matches(
|
||||
prompt: _S,
|
||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
|
||||
prompt_len = len(prompt)
|
||||
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
|
||||
|
||||
out_seqs = list[str | list[int]]()
|
||||
out_result: MultiModalPromptUpdatesApplyResult = {
|
||||
m: [None] * len(items) for m, items in mm_prompt_updates.items()
|
||||
}
|
||||
|
||||
mm_found_counts = {
|
||||
m: sum(r is not None for r in res) for m, res in out_result.items()
|
||||
}
|
||||
if _all_items_found(mm_item_counts, mm_found_counts):
|
||||
return [prompt], out_result
|
||||
|
||||
start_idx = prev_end_idx = 0
|
||||
while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
|
||||
found = False
|
||||
@ -776,6 +793,12 @@ def _apply_matches(
|
||||
# Exclude overlapping matches
|
||||
start_idx = prev_end_idx = match.end_idx
|
||||
|
||||
mm_found_counts = {
|
||||
m: sum(r is not None for r in res) for m, res in out_result.items()
|
||||
}
|
||||
if _all_items_found(mm_item_counts, mm_found_counts):
|
||||
break
|
||||
|
||||
if not found:
|
||||
start_idx += 1
|
||||
|
||||
@ -832,12 +855,15 @@ def _iter_placeholders(
|
||||
|
||||
Note that empty matches are ignored.
|
||||
"""
|
||||
prompt_len = len(prompt)
|
||||
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
|
||||
item_idx_by_modality = {modality: 0 for modality in mm_prompt_updates}
|
||||
|
||||
item_idx_by_modality = defaultdict[str, int](lambda: 0)
|
||||
if _all_items_found(mm_item_counts, item_idx_by_modality):
|
||||
return
|
||||
|
||||
prompt_len = len(prompt)
|
||||
start_idx = 0
|
||||
|
||||
while start_idx < prompt_len:
|
||||
found = False
|
||||
|
||||
@ -875,6 +901,9 @@ def _iter_placeholders(
|
||||
break
|
||||
|
||||
if found:
|
||||
if _all_items_found(mm_item_counts, item_idx_by_modality):
|
||||
return
|
||||
|
||||
break # Go back to the outer while loop
|
||||
|
||||
if not found:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user