mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 23:07:03 +08:00
[Optimization] Early return for _apply_matches and _iter_placeholders (#29668)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
8e7a891602
commit
1168768a2d
@ -727,18 +727,35 @@ def _find_matches(
|
|||||||
return mode, matches_to_apply
|
return mode, matches_to_apply
|
||||||
|
|
||||||
|
|
||||||
|
def _all_items_found(
|
||||||
|
mm_item_counts: dict[str, int],
|
||||||
|
mm_found_counts: dict[str, int],
|
||||||
|
) -> bool:
|
||||||
|
return all(
|
||||||
|
item_idx >= mm_item_counts[modality]
|
||||||
|
for modality, item_idx in mm_found_counts.items()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _apply_matches(
|
def _apply_matches(
|
||||||
prompt: _S,
|
prompt: _S,
|
||||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
|
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
|
||||||
prompt_len = len(prompt)
|
prompt_len = len(prompt)
|
||||||
|
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
|
||||||
|
|
||||||
out_seqs = list[str | list[int]]()
|
out_seqs = list[str | list[int]]()
|
||||||
out_result: MultiModalPromptUpdatesApplyResult = {
|
out_result: MultiModalPromptUpdatesApplyResult = {
|
||||||
m: [None] * len(items) for m, items in mm_prompt_updates.items()
|
m: [None] * len(items) for m, items in mm_prompt_updates.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mm_found_counts = {
|
||||||
|
m: sum(r is not None for r in res) for m, res in out_result.items()
|
||||||
|
}
|
||||||
|
if _all_items_found(mm_item_counts, mm_found_counts):
|
||||||
|
return [prompt], out_result
|
||||||
|
|
||||||
start_idx = prev_end_idx = 0
|
start_idx = prev_end_idx = 0
|
||||||
while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
|
while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
|
||||||
found = False
|
found = False
|
||||||
@ -776,6 +793,12 @@ def _apply_matches(
|
|||||||
# Exclude overlapping matches
|
# Exclude overlapping matches
|
||||||
start_idx = prev_end_idx = match.end_idx
|
start_idx = prev_end_idx = match.end_idx
|
||||||
|
|
||||||
|
mm_found_counts = {
|
||||||
|
m: sum(r is not None for r in res) for m, res in out_result.items()
|
||||||
|
}
|
||||||
|
if _all_items_found(mm_item_counts, mm_found_counts):
|
||||||
|
break
|
||||||
|
|
||||||
if not found:
|
if not found:
|
||||||
start_idx += 1
|
start_idx += 1
|
||||||
|
|
||||||
@ -832,12 +855,15 @@ def _iter_placeholders(
|
|||||||
|
|
||||||
Note that empty matches are ignored.
|
Note that empty matches are ignored.
|
||||||
"""
|
"""
|
||||||
prompt_len = len(prompt)
|
|
||||||
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
|
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
|
||||||
|
item_idx_by_modality = {modality: 0 for modality in mm_prompt_updates}
|
||||||
|
|
||||||
item_idx_by_modality = defaultdict[str, int](lambda: 0)
|
if _all_items_found(mm_item_counts, item_idx_by_modality):
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt_len = len(prompt)
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
|
|
||||||
while start_idx < prompt_len:
|
while start_idx < prompt_len:
|
||||||
found = False
|
found = False
|
||||||
|
|
||||||
@ -875,6 +901,9 @@ def _iter_placeholders(
|
|||||||
break
|
break
|
||||||
|
|
||||||
if found:
|
if found:
|
||||||
|
if _all_items_found(mm_item_counts, item_idx_by_modality):
|
||||||
|
return
|
||||||
|
|
||||||
break # Go back to the outer while loop
|
break # Go back to the outer while loop
|
||||||
|
|
||||||
if not found:
|
if not found:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user