Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 17:25:01 +08:00)
[Misc] Clean up multi-modal processor (#11207)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Parent: a1c02058ba
Commit: b10609e6a1
@ -92,10 +92,7 @@ def run_fuyu(question: str, modality: str):
|
|||||||
def run_phi3v(question: str, modality: str):
|
def run_phi3v(question: str, modality: str):
|
||||||
assert modality == "image"
|
assert modality == "image"
|
||||||
|
|
||||||
prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n" # noqa: E501
|
prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"
|
||||||
# Note: The default setting of max_num_seqs (256) and
|
|
||||||
# max_model_len (128k) for this model may cause OOM.
|
|
||||||
# You may lower either to run this example on lower-end GPUs.
|
|
||||||
|
|
||||||
# num_crops is an override kwarg to the multimodal image processor;
|
# num_crops is an override kwarg to the multimodal image processor;
|
||||||
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
|
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
|
||||||
|
|||||||
@ -2,10 +2,9 @@ from typing import cast
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
|
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
|
||||||
_PlaceholderInfo, find_text_matches,
|
find_text_matches, find_token_matches,
|
||||||
find_token_matches, iter_placeholders,
|
iter_placeholders, iter_token_matches,
|
||||||
iter_token_matches,
|
|
||||||
replace_text_matches,
|
replace_text_matches,
|
||||||
replace_token_matches)
|
replace_token_matches)
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
@ -314,8 +313,8 @@ def test_find_replace_text(
|
|||||||
result = replace_text_matches(
|
result = replace_text_matches(
|
||||||
prompt,
|
prompt,
|
||||||
matches,
|
matches,
|
||||||
MultiModalDataItems({key: [None] * mm_count
|
{key: mm_count
|
||||||
for key in repl_by_key}),
|
for key in repl_by_key},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only displayed on error
|
# Only displayed on error
|
||||||
@ -380,8 +379,8 @@ def test_find_replace_tokens(
|
|||||||
result = replace_token_matches(
|
result = replace_token_matches(
|
||||||
prompt,
|
prompt,
|
||||||
matches,
|
matches,
|
||||||
MultiModalDataItems({key: [None] * mm_count
|
{key: mm_count
|
||||||
for key in repl_by_key}),
|
for key in repl_by_key},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only displayed on error
|
# Only displayed on error
|
||||||
@ -476,7 +475,7 @@ def test_iter_placeholders(
|
|||||||
prompt_repls,
|
prompt_repls,
|
||||||
prompt,
|
prompt,
|
||||||
# Effectively match all occurrences in the prompt
|
# Effectively match all occurrences in the prompt
|
||||||
MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
|
{key: 3 for key in repl_by_key},
|
||||||
))
|
))
|
||||||
|
|
||||||
# Only displayed on error
|
# Only displayed on error
|
||||||
|
|||||||
@ -403,18 +403,17 @@ def _resolve_matches(
|
|||||||
def _replace_matches(
|
def _replace_matches(
|
||||||
prompt: _S,
|
prompt: _S,
|
||||||
matches: Sequence[_PromptReplacementMatch],
|
matches: Sequence[_PromptReplacementMatch],
|
||||||
mm_items: MultiModalDataItems,
|
mm_item_counts: Mapping[str, int],
|
||||||
) -> list[_S]:
|
) -> list[_S]:
|
||||||
out_seqs = list[_S]()
|
out_seqs = list[_S]()
|
||||||
prev_end_idx = 0
|
prev_end_idx = 0
|
||||||
next_idx_by_modality = {modality: 0 for modality in mm_items}
|
next_idx_by_modality = {modality: 0 for modality in mm_item_counts}
|
||||||
|
|
||||||
for match in _resolve_matches(prompt, matches):
|
for match in _resolve_matches(prompt, matches):
|
||||||
modality = match.modality
|
modality = match.modality
|
||||||
modal_items = mm_items[modality]
|
|
||||||
|
|
||||||
item_idx = next_idx_by_modality[modality]
|
item_idx = next_idx_by_modality[modality]
|
||||||
if item_idx >= len(modal_items):
|
if item_idx >= mm_item_counts[modality]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
start_idx = match.start_idx
|
start_idx = match.start_idx
|
||||||
@ -441,13 +440,13 @@ def _replace_matches(
|
|||||||
def replace_token_matches(
|
def replace_token_matches(
|
||||||
prompt: list[int],
|
prompt: list[int],
|
||||||
matches: Sequence[_PromptReplacementTokenMatch],
|
matches: Sequence[_PromptReplacementTokenMatch],
|
||||||
mm_items: MultiModalDataItems,
|
mm_item_counts: Mapping[str, int],
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
"""Apply :code:`prompt_repls` to :code:`prompt`."""
|
"""Apply :code:`prompt_repls` to :code:`prompt`."""
|
||||||
if not matches:
|
if not matches:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
token_id_seqs = _replace_matches(prompt, matches, mm_items)
|
token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)
|
||||||
|
|
||||||
return flatten_2d_lists(token_id_seqs)
|
return flatten_2d_lists(token_id_seqs)
|
||||||
|
|
||||||
@ -455,13 +454,13 @@ def replace_token_matches(
|
|||||||
def replace_text_matches(
|
def replace_text_matches(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
matches: Sequence[_PromptReplacementTextMatch],
|
matches: Sequence[_PromptReplacementTextMatch],
|
||||||
mm_items: MultiModalDataItems,
|
mm_item_counts: Mapping[str, int],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Apply :code:`prompt_repls` to :code:`prompt`."""
|
"""Apply :code:`prompt_repls` to :code:`prompt`."""
|
||||||
if not matches:
|
if not matches:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
texts = _replace_matches(prompt, matches, mm_items)
|
texts = _replace_matches(prompt, matches, mm_item_counts)
|
||||||
|
|
||||||
return "".join(texts)
|
return "".join(texts)
|
||||||
|
|
||||||
@ -470,9 +469,9 @@ def _iter_modality_placeholders(
|
|||||||
prompt: list[int],
|
prompt: list[int],
|
||||||
modality: str,
|
modality: str,
|
||||||
modality_repls: Sequence[_BoundPromptReplacement],
|
modality_repls: Sequence[_BoundPromptReplacement],
|
||||||
modal_items: list[Any],
|
modal_item_count: int,
|
||||||
) -> Iterable[_PlaceholderInfo]:
|
) -> Iterable[_PlaceholderInfo]:
|
||||||
if len(modal_items) == 0:
|
if modal_item_count == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt_len = len(prompt)
|
prompt_len = len(prompt)
|
||||||
@ -499,7 +498,7 @@ def _iter_modality_placeholders(
|
|||||||
)
|
)
|
||||||
|
|
||||||
item_index += 1
|
item_index += 1
|
||||||
if item_index >= len(modal_items):
|
if item_index >= modal_item_count:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Exclude overlapping matches
|
# Exclude overlapping matches
|
||||||
@ -514,7 +513,7 @@ def _iter_modality_placeholders(
|
|||||||
def iter_placeholders(
|
def iter_placeholders(
|
||||||
prompt_repls: Sequence[_BoundPromptReplacement],
|
prompt_repls: Sequence[_BoundPromptReplacement],
|
||||||
prompt: list[int],
|
prompt: list[int],
|
||||||
mm_items: MultiModalDataItems,
|
mm_item_counts: Mapping[str, int],
|
||||||
) -> Iterable[_PlaceholderInfo]:
|
) -> Iterable[_PlaceholderInfo]:
|
||||||
"""
|
"""
|
||||||
Yield each set of placeholder tokens found in :code:`prompt`.
|
Yield each set of placeholder tokens found in :code:`prompt`.
|
||||||
@ -523,13 +522,13 @@ def iter_placeholders(
|
|||||||
"""
|
"""
|
||||||
repls_by_modality = dict(full_groupby_modality(prompt_repls))
|
repls_by_modality = dict(full_groupby_modality(prompt_repls))
|
||||||
|
|
||||||
for modality, modal_items in mm_items.items():
|
for modality, modal_item_count in mm_item_counts.items():
|
||||||
if modality in repls_by_modality:
|
if modality in repls_by_modality:
|
||||||
yield from _iter_modality_placeholders(
|
yield from _iter_modality_placeholders(
|
||||||
prompt,
|
prompt,
|
||||||
modality,
|
modality,
|
||||||
repls_by_modality[modality],
|
repls_by_modality[modality],
|
||||||
modal_items,
|
modal_item_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -590,10 +589,10 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
self,
|
self,
|
||||||
all_prompt_repls: Sequence[_BoundPromptReplacement],
|
all_prompt_repls: Sequence[_BoundPromptReplacement],
|
||||||
new_token_ids: list[int],
|
new_token_ids: list[int],
|
||||||
mm_items: MultiModalDataItems,
|
mm_item_counts: Mapping[str, int],
|
||||||
) -> list[_PlaceholderInfo]:
|
) -> list[_PlaceholderInfo]:
|
||||||
return list(
|
return list(
|
||||||
iter_placeholders(all_prompt_repls, new_token_ids, mm_items))
|
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
|
||||||
|
|
||||||
def _apply_hf_processor(
|
def _apply_hf_processor(
|
||||||
self,
|
self,
|
||||||
@ -655,10 +654,9 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
|
|
||||||
def _apply_prompt_replacements(
|
def _apply_prompt_replacements(
|
||||||
self,
|
self,
|
||||||
mm_items: MultiModalDataItems,
|
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
token_ids: list[int],
|
token_ids: list[int],
|
||||||
prompt_repls: Sequence[_BoundPromptReplacement],
|
prompt_repls: Sequence[_BoundPromptReplacement],
|
||||||
|
mm_item_counts: Mapping[str, int],
|
||||||
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
|
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
|
||||||
tokenizer = self._get_tokenizer()
|
tokenizer = self._get_tokenizer()
|
||||||
|
|
||||||
@ -675,13 +673,13 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
# of the search text in the prompt, we instead perform string
|
# of the search text in the prompt, we instead perform string
|
||||||
# replacement on the decoded token IDs, then encode them back.
|
# replacement on the decoded token IDs, then encode them back.
|
||||||
if all(
|
if all(
|
||||||
len(matches) >= len(mm_items[modality])
|
len(matches) >= mm_item_counts[modality]
|
||||||
for modality, matches in full_groupby_modality(token_matches)
|
for modality, matches in full_groupby_modality(token_matches)
|
||||||
): # yapf: disable
|
): # yapf: disable
|
||||||
token_ids = replace_token_matches(
|
token_ids = replace_token_matches(
|
||||||
token_ids,
|
token_ids,
|
||||||
token_matches,
|
token_matches,
|
||||||
mm_items,
|
mm_item_counts,
|
||||||
)
|
)
|
||||||
|
|
||||||
text = _decode(tokenizer, token_ids)
|
text = _decode(tokenizer, token_ids)
|
||||||
@ -693,14 +691,14 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
text = replace_text_matches(
|
text = replace_text_matches(
|
||||||
text,
|
text,
|
||||||
text_matches,
|
text_matches,
|
||||||
mm_items,
|
mm_item_counts,
|
||||||
)
|
)
|
||||||
|
|
||||||
token_ids = _encode(tokenizer, text)
|
token_ids = _encode(tokenizer, text)
|
||||||
matched_repls = [match.prompt_repl for match in text_matches]
|
matched_repls = [match.prompt_repl for match in text_matches]
|
||||||
|
|
||||||
placeholders = self._find_placeholders(matched_repls, token_ids,
|
placeholders = self._find_placeholders(matched_repls, token_ids,
|
||||||
mm_items)
|
mm_item_counts)
|
||||||
|
|
||||||
return token_ids, text, placeholders
|
return token_ids, text, placeholders
|
||||||
|
|
||||||
@ -737,8 +735,9 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
|
|
||||||
# If HF processor already inserts placeholder tokens,
|
# If HF processor already inserts placeholder tokens,
|
||||||
# there is no need for us to insert them
|
# there is no need for us to insert them
|
||||||
|
mm_item_counts = {m: len(items) for m, items in mm_items.items()}
|
||||||
all_placeholders = self._find_placeholders(all_prompt_repls,
|
all_placeholders = self._find_placeholders(all_prompt_repls,
|
||||||
prompt_ids, mm_items)
|
prompt_ids, mm_item_counts)
|
||||||
|
|
||||||
if all_placeholders:
|
if all_placeholders:
|
||||||
prompt_text = _decode(tokenizer, prompt_ids)
|
prompt_text = _decode(tokenizer, prompt_ids)
|
||||||
@ -748,10 +747,9 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
prompt_text,
|
prompt_text,
|
||||||
all_placeholders,
|
all_placeholders,
|
||||||
) = self._apply_prompt_replacements(
|
) = self._apply_prompt_replacements(
|
||||||
mm_items,
|
|
||||||
hf_inputs,
|
|
||||||
prompt_ids,
|
prompt_ids,
|
||||||
all_prompt_repls,
|
all_prompt_repls,
|
||||||
|
mm_item_counts,
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_placeholders = {
|
mm_placeholders = {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user