[Optimization] Early return for _apply_matches and _iter_placeholders (#29668)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-11-28 21:26:47 +08:00 committed by GitHub
parent 8e7a891602
commit 1168768a2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -727,18 +727,35 @@ def _find_matches(
return mode, matches_to_apply
def _all_items_found(
mm_item_counts: dict[str, int],
mm_found_counts: dict[str, int],
) -> bool:
return all(
item_idx >= mm_item_counts[modality]
for modality, item_idx in mm_found_counts.items()
)
def _apply_matches(
prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
prompt_len = len(prompt)
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
out_seqs = list[str | list[int]]()
out_result: MultiModalPromptUpdatesApplyResult = {
m: [None] * len(items) for m, items in mm_prompt_updates.items()
}
mm_found_counts = {
m: sum(r is not None for r in res) for m, res in out_result.items()
}
if _all_items_found(mm_item_counts, mm_found_counts):
return [prompt], out_result
start_idx = prev_end_idx = 0
while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
found = False
@ -776,6 +793,12 @@ def _apply_matches(
# Exclude overlapping matches
start_idx = prev_end_idx = match.end_idx
mm_found_counts = {
m: sum(r is not None for r in res) for m, res in out_result.items()
}
if _all_items_found(mm_item_counts, mm_found_counts):
break
if not found:
start_idx += 1
@ -832,12 +855,15 @@ def _iter_placeholders(
Note that empty matches are ignored.
"""
prompt_len = len(prompt)
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
item_idx_by_modality = {modality: 0 for modality in mm_prompt_updates}
item_idx_by_modality = defaultdict[str, int](lambda: 0)
if _all_items_found(mm_item_counts, item_idx_by_modality):
return
prompt_len = len(prompt)
start_idx = 0
while start_idx < prompt_len:
found = False
@ -875,6 +901,9 @@ def _iter_placeholders(
break
if found:
if _all_items_found(mm_item_counts, item_idx_by_modality):
return
break # Go back to the outer while loop
if not found: