[Misc] Clean up multi-modal processor (#11207)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent a1c02058ba
commit b10609e6a1
examples/offline_inference_vision_language.py

```diff
@@ -92,10 +92,7 @@ def run_fuyu(question: str, modality: str):
 def run_phi3v(question: str, modality: str):
     assert modality == "image"
 
-    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
-    # Note: The default setting of max_num_seqs (256) and
-    # max_model_len (128k) for this model may cause OOM.
-    # You may lower either to run this example on lower-end GPUs.
+    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"
 
     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
```
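The comment block removed here warns that the model's defaults can cause OOM; the example presumably applies that advice where the engine is constructed. A minimal usage sketch, assuming the model ID and the lowered limits suggested by the removed comment (not part of this diff):

```python
from vllm import LLM

# Hypothetical settings per the removed comment's advice: the defaults of
# max_num_seqs (256) and max_model_len (128k) may cause OOM, so lower both.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",  # assumed model ID
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
)
```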
tests/multimodal/test_processing.py

```diff
@@ -2,10 +2,9 @@ from typing import cast
 
 import pytest
 
-from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
-                                        _PlaceholderInfo, find_text_matches,
-                                        find_token_matches, iter_placeholders,
-                                        iter_token_matches,
+from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
+                                        find_text_matches, find_token_matches,
+                                        iter_placeholders, iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
```
```diff
@@ -314,8 +313,8 @@ def test_find_replace_text(
     result = replace_text_matches(
         prompt,
         matches,
-        MultiModalDataItems({key: [None] * mm_count
-                             for key in repl_by_key}),
+        {key: mm_count
+         for key in repl_by_key},
     )
 
     # Only displayed on error
```
```diff
@@ -380,8 +379,8 @@ def test_find_replace_tokens(
     result = replace_token_matches(
         prompt,
         matches,
-        MultiModalDataItems({key: [None] * mm_count
-                             for key in repl_by_key}),
+        {key: mm_count
+         for key in repl_by_key},
     )
 
     # Only displayed on error
```
```diff
@@ -476,7 +475,7 @@ def test_iter_placeholders(
         prompt_repls,
         prompt,
         # Effectively match all occurrences in the prompt
-        MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
+        {key: 3 for key in repl_by_key},
     ))
 
     # Only displayed on error
```
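Across these test hunks, the dummy `MultiModalDataItems({key: [None] * n, ...})` argument collapses into a plain count mapping. A tiny sketch of the equivalence, with hypothetical keys:

```python
repl_by_key = {"image": "<image>", "video": "<video>"}  # hypothetical keys
mm_count = 2

# Before: a MultiModalDataItems stuffed with [None] placeholders just to
# convey a length. After: the length itself.
mm_item_counts = {key: mm_count for key in repl_by_key}
assert mm_item_counts == {"image": 2, "video": 2}
```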
vllm/multimodal/processing.py

```diff
@@ -403,18 +403,17 @@ def _resolve_matches(
 def _replace_matches(
     prompt: _S,
     matches: Sequence[_PromptReplacementMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> list[_S]:
     out_seqs = list[_S]()
     prev_end_idx = 0
-    next_idx_by_modality = {modality: 0 for modality in mm_items}
+    next_idx_by_modality = {modality: 0 for modality in mm_item_counts}
 
     for match in _resolve_matches(prompt, matches):
         modality = match.modality
-        modal_items = mm_items[modality]
 
         item_idx = next_idx_by_modality[modality]
-        if item_idx >= len(modal_items):
+        if item_idx >= mm_item_counts[modality]:
             continue
 
         start_idx = match.start_idx
```
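The function no longer needs the multi-modal items themselves: a per-modality counter is compared against the count mapping, so surplus matches are skipped. A self-contained sketch of that gating logic (a simplification, not vLLM's actual API):

```python
from typing import Mapping


def take_matches(match_modalities: list[str],
                 mm_item_counts: Mapping[str, int]) -> list[str]:
    """Keep at most mm_item_counts[m] matches per modality, in order."""
    next_idx_by_modality = {m: 0 for m in mm_item_counts}
    kept: list[str] = []
    for modality in match_modalities:
        if next_idx_by_modality[modality] >= mm_item_counts[modality]:
            continue  # more matches than items: ignore the extras
        kept.append(modality)
        next_idx_by_modality[modality] += 1
    return kept


assert take_matches(["image", "image", "image"],
                    {"image": 2}) == ["image", "image"]
```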
```diff
@@ -441,13 +440,13 @@ def _replace_matches(
 
 def replace_token_matches(
     prompt: list[int],
     matches: Sequence[_PromptReplacementTokenMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> list[int]:
     """Apply :code:`prompt_repls` to :code:`prompt`."""
     if not matches:
         return prompt
 
-    token_id_seqs = _replace_matches(prompt, matches, mm_items)
+    token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)
 
     return flatten_2d_lists(token_id_seqs)
```
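`_replace_matches` returns one token-ID sequence per segment (unmatched spans interleaved with replacements), and `flatten_2d_lists` stitches them back into a single prompt. An illustration of that flattening step, with made-up token IDs:

```python
# Hypothetical segments: [before match], [replacement tokens], [after match]
token_id_seqs = [[1, 2], [32000, 32000], [3]]

flat = [tok for seq in token_id_seqs for tok in seq]  # what flattening does
assert flat == [1, 2, 32000, 32000, 3]
```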
```diff
@@ -455,13 +454,13 @@ def replace_token_matches(
 
 def replace_text_matches(
     prompt: str,
     matches: Sequence[_PromptReplacementTextMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> str:
     """Apply :code:`prompt_repls` to :code:`prompt`."""
     if not matches:
         return prompt
 
-    texts = _replace_matches(prompt, matches, mm_items)
+    texts = _replace_matches(prompt, matches, mm_item_counts)
 
     return "".join(texts)
```
```diff
@@ -470,9 +469,9 @@ def _iter_modality_placeholders(
     prompt: list[int],
     modality: str,
     modality_repls: Sequence[_BoundPromptReplacement],
-    modal_items: list[Any],
+    modal_item_count: int,
 ) -> Iterable[_PlaceholderInfo]:
-    if len(modal_items) == 0:
+    if modal_item_count == 0:
         return
 
     prompt_len = len(prompt)
```
```diff
@@ -499,7 +498,7 @@ def _iter_modality_placeholders(
                 )
 
                 item_index += 1
-                if item_index >= len(modal_items):
+                if item_index >= modal_item_count:
                     return
 
         # Exclude overlapping matches
```
```diff
@@ -514,7 +513,7 @@ def _iter_modality_placeholders(
 def iter_placeholders(
     prompt_repls: Sequence[_BoundPromptReplacement],
     prompt: list[int],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> Iterable[_PlaceholderInfo]:
     """
     Yield each set of placeholder tokens found in :code:`prompt`.
```
```diff
@@ -523,13 +522,13 @@ def iter_placeholders(
     """
     repls_by_modality = dict(full_groupby_modality(prompt_repls))
 
-    for modality, modal_items in mm_items.items():
+    for modality, modal_item_count in mm_item_counts.items():
         if modality in repls_by_modality:
             yield from _iter_modality_placeholders(
                 prompt,
                 modality,
                 repls_by_modality[modality],
-                modal_items,
+                modal_item_count,
             )
 
 
```
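Placeholder discovery is now driven purely by counts: for each modality with registered replacements, at most `modal_item_count` placeholder runs are yielded. A simplified, self-contained sketch of that control flow (hypothetical helper, not vLLM's actual types):

```python
from typing import Iterable


def iter_placeholder_slots(
        mm_item_counts: dict[str, int]) -> Iterable[tuple[str, int]]:
    """Yield (modality, item_index) for every expected placeholder."""
    for modality, modal_item_count in mm_item_counts.items():
        if modal_item_count == 0:
            continue  # mirrors the early return for empty modalities
        for item_index in range(modal_item_count):
            yield modality, item_index


assert list(iter_placeholder_slots({"image": 2, "audio": 0})) == [
    ("image", 0),
    ("image", 1),
]
```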
```diff
@@ -590,10 +589,10 @@ class BaseMultiModalProcessor(ABC):
         self,
         all_prompt_repls: Sequence[_BoundPromptReplacement],
         new_token_ids: list[int],
-        mm_items: MultiModalDataItems,
+        mm_item_counts: Mapping[str, int],
     ) -> list[_PlaceholderInfo]:
         return list(
-            iter_placeholders(all_prompt_repls, new_token_ids, mm_items))
+            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
 
     def _apply_hf_processor(
         self,
```
```diff
@@ -655,10 +654,9 @@ class BaseMultiModalProcessor(ABC):
 
     def _apply_prompt_replacements(
         self,
-        mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
         token_ids: list[int],
         prompt_repls: Sequence[_BoundPromptReplacement],
+        mm_item_counts: Mapping[str, int],
     ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
         tokenizer = self._get_tokenizer()
 
```
```diff
@@ -675,13 +673,13 @@ class BaseMultiModalProcessor(ABC):
         # of the search text in the prompt, we instead perform string
         # replacement on the decoded token IDs, then encode them back.
         if all(
-            len(matches) >= len(mm_items[modality])
+            len(matches) >= mm_item_counts[modality]
             for modality, matches in full_groupby_modality(token_matches)
         ):  # yapf: disable
             token_ids = replace_token_matches(
                 token_ids,
                 token_matches,
-                mm_items,
+                mm_item_counts,
             )
 
             text = _decode(tokenizer, token_ids)
```
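This guard only falls back to text-level replacement when some modality found fewer token-level matches than it has items. A small illustration of the check, with hypothetical data and `full_groupby_modality` stood in for by a plain dict:

```python
mm_item_counts = {"image": 2}
matches_by_modality = {"image": ["m0", "m1", "m2"]}  # hypothetical matches

token_replacement_ok = all(
    len(matches) >= mm_item_counts[modality]
    for modality, matches in matches_by_modality.items()
)
assert token_replacement_ok  # 3 matches cover the 2 image items
```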
```diff
@@ -693,14 +691,14 @@ class BaseMultiModalProcessor(ABC):
             text = replace_text_matches(
                 text,
                 text_matches,
-                mm_items,
+                mm_item_counts,
             )
 
             token_ids = _encode(tokenizer, text)
             matched_repls = [match.prompt_repl for match in text_matches]
 
         placeholders = self._find_placeholders(matched_repls, token_ids,
-                                               mm_items)
+                                               mm_item_counts)
 
         return token_ids, text, placeholders
 
```
```diff
@@ -737,8 +735,9 @@ class BaseMultiModalProcessor(ABC):
 
         # If HF processor already inserts placeholder tokens,
         # there is no need for us to insert them
+        mm_item_counts = {m: len(items) for m, items in mm_items.items()}
         all_placeholders = self._find_placeholders(all_prompt_repls,
-                                                   prompt_ids, mm_items)
+                                                   prompt_ids, mm_item_counts)
 
         if all_placeholders:
             prompt_text = _decode(tokenizer, prompt_ids)
```
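The counts are computed once, at the boundary where the parsed items are still in hand, and threaded through everything downstream. A tiny sketch of that derivation, with stand-in data:

```python
# Hypothetical parsed multi-modal items (stand-ins for real image/video data)
mm_items = {"image": ["img0", "img1"], "video": ["vid0"]}

mm_item_counts = {m: len(items) for m, items in mm_items.items()}
assert mm_item_counts == {"image": 2, "video": 1}
```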
```diff
@@ -748,10 +747,9 @@ class BaseMultiModalProcessor(ABC):
             prompt_text,
             all_placeholders,
         ) = self._apply_prompt_replacements(
-            mm_items,
-            hf_inputs,
             prompt_ids,
             all_prompt_repls,
+            mm_item_counts,
         )
 
         mm_placeholders = {
```