diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 990eac82d516c..c8046d2485060 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -720,13 +720,13 @@ def _get_mm_fields_config( ::::: -### Prompt replacements +### Prompt updates -Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to -return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances. +Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` to +return a list of {class}`~vllm.multimodal.processing.PromptUpdate` instances. -Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace -operation performed by the HF processor. +Each {class}`~vllm.multimodal.processing.PromptUpdate` instance specifies an update operation +(e.g.: insertion, replacement) performed by the HF processor. ::::{tab-set} :::{tab-item} Basic example: LLaVA @@ -743,15 +743,15 @@ for sample in text: ``` It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`). -Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` as follows: +Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` as follows: ```python -def _get_prompt_replacements( +def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, -) -> list[PromptReplacement]: +) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -859,7 +859,7 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( ) ``` -To accommodate this, instead of a string you can return an instance of `PromptReplacementDetails` +To accommodate this, instead of a string you can return an instance of `PromptUpdateDetails` with different `full` and `feature` attributes: ```python @@ -878,7 +878,7 @@ def get_replacement_fuyu(item_idx: int): image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows - return PromptReplacementDetails( + return PromptUpdateDetails( full=image_tokens + [bos_token_id], features=image_tokens, ) @@ -888,12 +888,12 @@ Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the we can search for it to conduct the replacement at the start of the string: ```python -def _get_prompt_replacements( +def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, -) -> list[PromptReplacement]: +) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id assert isinstance(bos_token_id, int) @@ -913,7 +913,7 @@ def _get_prompt_replacements( image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows - return PromptReplacementDetails( + return PromptUpdateDetails( full=image_tokens + [bos_token_id], features=image_tokens, ) diff --git a/docs/source/design/mm_processing.md b/docs/source/design/mm_processing.md index a0d01205e638c..2a4dac786d4bc 100644 --- a/docs/source/design/mm_processing.md +++ b/docs/source/design/mm_processing.md @@ -6,11 +6,16 @@ To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefi Here are the main features of 
{class}`~vllm.multimodal.processing.BaseMultiModalProcessor`: -## Prompt Replacement Detection +## Prompt Update Detection -One of the main responsibilies of HF processor is to replace input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image>...<image>`, the number of which equals to the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs. +One of the main responsibilities of the HF processor is to update the prompt with placeholder tokens. For example: -In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt. +- Insert feature placeholder tokens (e.g. `<image>...<image>`, the number of which equals the feature size) at the start of the string. +- Replace existing input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image>...<image>`, the number of which equals the feature size). + +The information about which tokens have been updated is key to finding the correspondence between placeholder feature tokens and multi-modal inputs. + +In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptUpdate` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. We can automatically detect whether HF has updated the prompt by checking for the existence of the updated tokens in the prompt. ## Tokenized Prompt Inputs @@ -22,7 +27,7 @@ Consider that HF processors follow these main steps: 1. Tokenize the text 2. Process multi-modal inputs -3. Perform prompt replacement +3. Perform prompt updates And we require that: @@ -44,16 +49,16 @@ Moreover, since the tokenized text has not passed through the HF processor, we h We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. -(mm-automatic-prompt-replacement)= +(mm-automatic-prompt-updating)= -### Automatic prompt replacement +### Automatic prompt updating We address the second issue by implementing model-agnostic code in -{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. +{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_updates` to automatically update the prompt with feature placeholder tokens based on the specification returned by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. ### Summary -With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`. 
+With the help of dummy text and automatic prompt updating, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`. ## Processor Output Caching @@ -61,4 +66,4 @@ Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238) When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache. -Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other. +Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of multi-modal inputs, so they can't be passed alongside the text prompt to the HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating](#mm-automatic-prompt-updating) afterwards to keep the output tokens and multi-modal data consistent with each other. diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index c2fbe83abc837..878b159250063 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -14,12 +14,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (PlaceholderFeaturesInfo, - PromptReplacement, + PromptInsertion, PromptReplacement, + apply_text_matches, + apply_token_matches, find_mm_placeholders, find_text_matches, find_token_matches, - iter_token_matches, - replace_text_matches, - replace_token_matches) + iter_token_matches) # yapf: enable from vllm.multimodal.profiling import MultiModalProfiler from vllm.transformers_utils.tokenizer import (AnyTokenizer, @@ -102,7 +102,7 @@ def test_iter_token_matches(token_ids, match_ids, expected): { "pattern_1": [], "pattern_2": [], - } + }, ), ( [32000, 32000, 32000, 32000], @@ -147,16 +147,22 @@ def test_iter_token_matches(token_ids, match_ids, expected): ), ], ) +@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) # yapf: enable -def test_find_token_matches(prompt, target_by_key, expected_by_key): +def test_find_token_matches( + prompt, + target_by_key, + expected_by_key, + update_type, +): # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - prompt_repls = [ - PromptReplacement(key, target, []).bind(mock_tokenizer) + prompt_updates = [ + update_type(key, target, []).bind(mock_tokenizer) for key, target in target_by_key.items() ] - result = find_token_matches(prompt, prompt_repls) + result = find_token_matches(prompt, prompt_updates) # Only displayed on error print("result:", result) @@ -254,16 +260,22 @@ def 
test_find_token_matches(prompt, target_by_key, expected_by_key): ), ], ) +@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) # yapf: enable -def test_find_text_matches(prompt, target_by_key, expected_by_key): +def test_find_text_matches( + prompt, + target_by_key, + expected_by_key, + update_type, +): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - prompt_repls = [ - PromptReplacement(key, target, []).bind(mock_tokenizer) + prompt_updates = [ + update_type(key, target, []).bind(mock_tokenizer) for key, target in target_by_key.items() ] - result = find_text_matches(prompt, prompt_repls) + result = find_text_matches(prompt, prompt_updates) # Only displayed on error print("result:", result) @@ -281,7 +293,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # yapf: disable @pytest.mark.parametrize( - ("prompt", "target_by_key", "repl_by_key"), + ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ ( "Image:Image:!", @@ -300,58 +312,66 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Test dynamic replacement (beyond the form of `unit * count`) "pattern_3": "?!?", }, + { + PromptInsertion: { + 0: "Image:Image:!", + 1: "Image:Image:!?!?", + 2: "Image:Image:!?!??!?", # noqa: E501 + }, + PromptReplacement: { + 0: "Image:Image:!", + 1: "Image:?!?", + 2: "?!?", + }, + }, ), ] ) -@pytest.mark.parametrize( - ("mm_count", "expected"), - [ - (0, "Image:Image:!"), - (1, "Image:?!?"), - (2, "?!?"), - ] -) # yapf: enable -def test_find_replace_text( +def test_find_update_text( prompt, target_by_key, repl_by_key, - mm_count, - expected, + expected_by_update_type_mm_count, ): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - mm_prompt_repls = { - key: [ - PromptReplacement(key, target, - repl_by_key[key]).bind(mock_tokenizer) - ] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_text_matches(prompt, prompt_repls) - for key, prompt_repls in mm_prompt_repls.items() - } + for ( + update_type, + expected_by_mm_count, + ) in expected_by_update_type_mm_count.items(): + mm_prompt_updates = { + key: + [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] + for key, target in target_by_key.items() + } + mm_matches = { + key: find_text_matches(prompt, updates) + for key, updates in mm_prompt_updates.items() + } - result = replace_text_matches( - prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, - ) + for mm_count, expected in expected_by_mm_count.items(): + result = apply_text_matches( + prompt, + mm_matches, + {key: mm_count + for key in repl_by_key}, + ) - # Only displayed on error - print("mm_matches:", mm_matches) - print("result:", result) + # Only displayed on error + print("update_type:", update_type) + print("mm_count:", mm_count) + print("mm_matches:", mm_matches) + print("result:", result) - # Manually constructed results - assert result == expected + # Manually constructed results + assert result == expected # yapf: disable @pytest.mark.parametrize( - ("prompt", "target_by_key", "repl_by_key"), + ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ # Tokenized test cases of `test_find_replace_text` # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf @@ -372,53 +392,61 @@ def test_find_replace_text( # Test dynamic replacement (beyond the form of `unit * count`) "pattern_3": 
[1550, 918, 1550], }, + { + PromptInsertion: { + 0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + 1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501 + 2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501 + }, + PromptReplacement: { + 0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + 1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], # noqa: E501 + 2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], + }, + }, ), ] ) -@pytest.mark.parametrize( - ("mm_count", "expected"), - [ - (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]), - (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]), - (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]), - ] -) # yapf: enable -def test_find_replace_tokens( +def test_find_update_tokens( prompt, target_by_key, repl_by_key, - mm_count, - expected, + expected_by_update_type_mm_count, ): # Should not be used since there is nothing to convert to tokens mock_tokenizer = cast(AnyTokenizer, object()) - mm_prompt_repls = { - key: [ - PromptReplacement(key, target, - repl_by_key[key]).bind(mock_tokenizer) - ] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_token_matches(prompt, prompt_repls) - for key, prompt_repls in mm_prompt_repls.items() - } + for ( + update_type, + expected_by_mm_count, + ) in expected_by_update_type_mm_count.items(): + mm_prompt_updates = { + key: + [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] + for key, target in target_by_key.items() + } + mm_matches = { + key: find_token_matches(prompt, updates) + for key, updates in mm_prompt_updates.items() + } - result = replace_token_matches( - prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, - ) + for mm_count, expected in expected_by_mm_count.items(): + result = apply_token_matches( + prompt, + mm_matches, + {key: mm_count + for key in repl_by_key}, + ) - # Only displayed on error - print("mm_matches:", mm_matches) - print("result:", result) + # Only displayed on error + print("update_type:", update_type) + print("mm_count:", mm_count) + print("mm_matches:", mm_matches) + print("result:", result) - # Manually constructed results - assert result == expected + # Manually constructed results + assert result == expected # yapf: disable @@ -524,22 +552,24 @@ def test_find_replace_tokens( ), ] ) +@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) # yapf: enable def test_find_mm_placeholders( repl_by_key, prompt, expected, + update_type, ): # Should not be used since there is nothing to convert to tokens mock_tokenizer = cast(AnyTokenizer, object()) - mm_prompt_repls = { - key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)] + mm_prompt_updates = { + key: [update_type(key, [], repl).bind(mock_tokenizer)] for key, repl in repl_by_key.items() } result = find_mm_placeholders( - mm_prompt_repls, + mm_prompt_updates, prompt, # Effectively match all occurrences in the prompt {key: 3 diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 656e9b037d969..061a9a5bd2bcc 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - -from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, - Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import 
List, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -26,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -457,12 +457,12 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): pixel_mask=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 23bb3cd07f1d4..61f2f8974d91e 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + BaseProcessingInfo, PromptInsertion, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -474,30 +474,24 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() - bos_token_id = tokenizer.bos_token_id - assert isinstance(bos_token_id, int) - image_token_id = vocab[""] num_image_tokens = self.info.get_num_image_tokens() image_tokens = [image_token_id] * num_image_tokens return [ - PromptReplacement( + PromptInsertion( modality="image", - target=[bos_token_id], - replacement=PromptReplacementDetails( - full=image_tokens + [bos_token_id], - features=image_tokens, - ), + target="", + insertion=image_tokens, ) ] diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index e91399b2674df..9d597e240951a 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -35,7 
+35,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -141,12 +141,12 @@ class ChameleonMultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values=MultiModalFieldConfig.batched("image")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -162,7 +162,7 @@ class ChameleonMultiModalProcessor( PromptReplacement( modality="image", target=[image_token_id], - replacement=PromptReplacementDetails( + replacement=PromptUpdateDetails( full=([image_start_id] + image_tokens + [image_end_id]), features=image_tokens, ), @@ -371,7 +371,7 @@ class ChameleonDecoderLayer(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if residual is None: residual = hidden_states diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index ea217e2444040..3d2e452bb50ec 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -3,9 +3,9 @@ # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py """Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" import math +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -26,7 +26,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, @@ -281,12 +281,12 @@ class DeepseekVL2MultiModalProcessor( image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token_id = hf_processor.image_token_id diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index b71d0de8d707d..c51fcf3d438bc 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -1,9 +1,10 @@ # 
SPDX-License-Identifier: Apache-2.0 import math +from collections import OrderedDict +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict, - Set, Tuple, TypedDict, Union) +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -24,8 +25,7 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, - PromptReplacement, - PromptReplacementDetails) + PromptInsertion, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -803,7 +803,7 @@ class Florence2DummyInputsBuilder( class Florence2MultiModalProcessor( EncDecMultiModalProcessor[Florence2ProcessingInfo]): - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, @@ -850,26 +850,22 @@ class Florence2MultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values=MultiModalFieldConfig.batched("image")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() pad_token_id = hf_config.pad_token_id - bos_token_id = hf_config.bos_token_id num_image_tokens = self.info.get_max_image_tokens() image_tokens = [pad_token_id] * num_image_tokens return [ - PromptReplacement( + PromptInsertion( modality="image", - target=[bos_token_id], - replacement=PromptReplacementDetails( - full=image_tokens + [bos_token_id], - features=image_tokens, - ), + target="", + insertion=image_tokens, ) ] diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7e4cc6bac5e61..581ec54b2cab7 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -17,8 +17,8 @@ # limitations under the License. 
""" PyTorch Fuyu model.""" import math -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict) +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Literal, Optional, Set, Tuple, TypedDict import torch import torch.nn as nn @@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -203,12 +203,12 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ) -> Mapping[str, MultiModalFieldConfig]: return dict(image_patches=MultiModalFieldConfig.batched("image")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id assert isinstance(bos_token_id, int) @@ -228,7 +228,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows - return PromptReplacementDetails( + return PromptUpdateDetails( full=image_tokens + [bos_token_id], features=image_tokens, ) diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 48543c5642ea4..ca34c4f8d53f4 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -4,7 +4,8 @@ # https://github.com/THUDM/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace -from typing import Literal, Mapping, Optional, TypedDict, Union +from collections.abc import Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union import torch from torch import nn @@ -32,7 +33,7 @@ from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, BatchFeature, MultiModalFieldConfig, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig @@ -480,7 +481,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, @@ -495,12 +496,12 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values=MultiModalFieldConfig.batched("image")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() boi_token_id = hf_config.boi_token_id diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index bab9c256b9aa0..d336d7521a271 100644 --- a/vllm/model_executor/models/h2ovl.py +++ 
b/vllm/model_executor/models/h2ovl.py @@ -7,7 +7,8 @@ # Copyright (c) 2024 H2O.AI # Licensed under Apache 2.0 License [see LICENSE for details] # -------------------------------------------------------- -from typing import Mapping, Optional +from collections.abc import Mapping, Sequence +from typing import Optional import torch from PIL import Image @@ -20,7 +21,7 @@ from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -487,12 +488,12 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] f"{type(self).__name__} does not support processing cache with " "multi-image support enabled.") - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) if "image_num_patches" in out_mm_kwargs: @@ -527,7 +528,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] if num_patches is not None: assert isinstance(num_patches, int) - return PromptReplacementDetails( + return PromptUpdateDetails( full=hf_processor.get_image_repl_full(feature_size, num_patches), features=hf_processor.get_image_repl_features( diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 0a8763cf910ca..286a75339d20e 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -16,8 +16,8 @@ """Inference-only Idefics3 model compatible with HuggingFace weights.""" import math -from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.utils.checkpoint @@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, MultiModalDataItems, MultiModalFieldConfig, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -274,12 +274,12 @@ class Idefics3MultimodalProcessor( image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.image_token.content diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 52ddb279cca39..48c2eb8c9f6e2 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -7,9 +7,10 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping, Sequence from functools import 
cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, TypeVar, Union) +from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar, + Union) import torch import torch.nn as nn @@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -599,12 +600,12 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): image_token_id=MultiModalFieldConfig.shared("image", num_images), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) if "image_num_patches" in out_mm_kwargs: @@ -636,7 +637,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): if num_patches is not None: assert isinstance(num_patches, int) - return PromptReplacementDetails( + return PromptUpdateDetails( full=hf_processor.get_image_repl_full(feature_size, num_patches), features=hf_processor.get_image_repl_features( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 72b1591306f26..8318a496e6088 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Final, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, TypedDict, TypeVar, Union) +from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, + TypedDict, TypeVar, Union) import torch import torch.nn as nn @@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -222,12 +223,12 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): ) -> Mapping[str, MultiModalFieldConfig]: raise NotImplementedError - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -328,12 +329,12 @@ class PixtralHFMultiModalProcessor( image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_config = self.info.get_hf_config() tokenizer 
= self.info.get_tokenizer() @@ -789,7 +790,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ")", # 3 tokens ]) - mantis_mm_repls = self._bind_and_group_repls([ + mantis_mm_repls = self._bind_and_group_updates([ PromptReplacement( modality="image", target=[image_token_id] * num_image_tokens, @@ -797,18 +798,18 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ) ]) - prompt_ids, prompt, _ = self._apply_prompt_replacements( + prompt_ids, prompt, _ = self._apply_prompt_updates( result["prompt_token_ids"], mantis_mm_repls, mm_item_counts, ) - unbound_orig_repls = self._get_prompt_replacements( + unbound_orig_repls = self._get_prompt_updates( mm_items, hf_processor_mm_kwargs, mm_kwargs, ) - orig_repls = self._bind_and_group_repls(unbound_orig_repls) + orig_repls = self._bind_and_group_updates(unbound_orig_repls) mm_placeholders = self._find_mm_placeholders( orig_repls, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 807d6977ed409..ca9406657df58 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import math +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -21,7 +21,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -183,12 +184,12 @@ class LlavaNextVideoMultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values_videos=MultiModalFieldConfig.batched("video")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_index diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index e57eea4286e94..e87ef24ce2ca0 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import math +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Final, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, TypedDict, Union) +from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -22,7 +23,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) -from vllm.multimodal.processing import PromptReplacement +from vllm.multimodal.processing import PromptReplacement, 
PromptUpdate from vllm.multimodal.profiling import ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -347,13 +348,13 @@ class LlavaOnevisionMultiModalProcessor( ) return BatchFeature(combined_outputs) - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], ) -> bool: - base_result = super()._hf_processor_applies_repl( + base_result = super()._hf_processor_applies_updates( prompt_text=prompt_text, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, @@ -361,13 +362,13 @@ class LlavaOnevisionMultiModalProcessor( return base_result and mm_items.get_count("video", strict=False) == 0 - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: - image_repls = super()._get_prompt_replacements( + ) -> Sequence[PromptUpdate]: + image_repls = super()._get_prompt_updates( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, out_mm_kwargs=out_mm_kwargs, @@ -392,7 +393,8 @@ class LlavaOnevisionMultiModalProcessor( return [video_token_id] * num_video_tokens - return image_repls + [ + return [ + *image_repls, PromptReplacement( modality="video", target=[video_token_id], diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index e6111f46143db..f35c230c0cea2 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -22,9 +22,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiniCPM-O model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Set, Tuple, TypedDict, Union) +from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, + TypedDict, Union) import torch from torch import nn @@ -356,10 +357,10 @@ class MiniCPMOMultiModalProcessor( inputs["audio"]["audio_lens"][index]) return super().get_prompt_texts_by_modality(inputs, modality, index) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]: + out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]: placeholder = { "image": self.info.image_pattern, "video": self.info.video_pattern, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2699958331f3d..fb6ea53acf9e4 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -25,9 +25,10 @@ import math import re from collections import Counter +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property, partial -from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Set, Tuple, TypedDict, Union) +from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, + TypedDict, Union) import numpy as np import torch @@ -732,7 +733,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): } } - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, @@ 
-740,10 +741,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ) -> bool: return False - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]: + out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]: placeholder = { "image": self.info.image_pattern, "video": self.info.video_pattern, diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 459928fe3fb0e..36e653e41e1bf 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -15,8 +15,8 @@ # limitations under the License. """PyTorch Mllama model.""" import math -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import numpy as np import torch @@ -59,7 +59,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataDict, MultiModalDataItems) from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .clip import CLIPMLP @@ -243,12 +243,12 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] image_token_id = self.info.get_hf_config().image_token_index return [image_token_id] * num_images - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: token_per_chunk = self.info.get_token_per_chunk_from_config() image_token_id = self.info.get_hf_config().image_token_index diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index cc4d38d8740b2..60af103189f84 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import math +from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from functools import cached_property, partial -from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, - Union, cast) +from typing import List, Optional, Set, Tuple, TypedDict, Union, cast import numpy as np import torch @@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + BaseProcessingInfo, PromptInsertion, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import JSONTree, json_map_leaves @@ -1190,6 +1190,8 @@ class MolmoProcessingInfo(BaseProcessingInfo): return MolmoProcessorWrapper(processor) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + # TODO: Investigate different `embed_is_patch` between cache/no-cache + # in multi-image case return {"image": 1} def get_mm_max_tokens_per_item( @@ -1328,25 +1330,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): 
img_patch_id=MultiModalFieldConfig.shared("image", num_images), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - tokenizer = self.info.get_tokenizer() image_token_length_w = processor.image_token_length_w image_token_length_h = processor.image_token_length_h pooling_size = processor.pooling_size - user_str = "User:" - if processor.always_start_with_space: - user_str = " " + user_str - - user_tokens = tokenizer.encode(user_str, add_special_tokens=False) - img_patch_id = processor.image_patch_id img_col_id = processor.im_col_id img_start_id = processor.im_start_id @@ -1356,7 +1351,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): extra_joint = ([img_start_id] + extra_row * image_token_length_h + [img_end_id]) - def get_replacement_molmo(item_idx: int): + def get_insertion_molmo(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) @@ -1371,17 +1366,13 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ((nrows + 1) // pooling_size) + [img_end_id]) image_tokens = extra_joint + joint - - return PromptReplacementDetails( - full=image_tokens + user_tokens, - features=image_tokens, - ) + return image_tokens return [ - PromptReplacement( + PromptInsertion( modality="image", - target=user_str, - replacement=get_replacement_molmo, + target="<|endoftext|>", + insertion=get_insertion_molmo, ) ] diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 5de8eeb3fffed..1e1760491a974 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -6,7 +6,8 @@ # Copyright (c) 2024 NVIDIA # Licensed under Apache 2.0 License [see LICENSE for details] # -------------------------------------------------------- -from typing import Mapping, Optional +from collections.abc import Mapping, Sequence +from typing import Optional import torch import torch.nn as nn @@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) -from vllm.multimodal.processing import (PromptReplacement, - PromptReplacementDetails) +from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.multimodal.profiling import ProcessorInputs from .intern_vit import InternVisionModel @@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]): class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) if "image_num_patches" in out_mm_kwargs: @@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): if num_patches is not None: assert isinstance(num_patches, int) - return PromptReplacementDetails( + return PromptUpdateDetails( full=hf_processor.get_image_repl_full(feature_size, num_patches) + "\n", 
features=hf_processor.get_image_repl_features( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 0f45f131065a8..0fd4b3c702111 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - BoundPromptReplacement, + BaseProcessingInfo, BoundPromptUpdate, PlaceholderFeaturesInfo, - PromptReplacement, - PromptReplacementDetails) + PromptReplacement, PromptUpdate, + PromptUpdateDetails) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_tokens: list[str] = hf_processor.img_tokens # type: ignore @@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens - return PromptReplacementDetails( + return PromptUpdateDetails( full=image_tokens + [bos_token_id], features=image_tokens, ) @@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ) for image_token in image_tokens[:num_images] ] - def _apply_prompt_replacements( + def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], mm_item_counts: Mapping[str, int], ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: - token_ids, text, placeholders = super()._apply_prompt_replacements( + token_ids, text, placeholders = super()._apply_prompt_updates( token_ids=token_ids, - mm_prompt_repls=mm_prompt_repls, + mm_prompt_updates=mm_prompt_updates, mm_item_counts=mm_item_counts, ) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 3d95e949e71da..bfa90e42733db 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -15,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only IBM/NASA Prithvi Geospatial model.""" -from typing import Iterable, Mapping, Optional, Set, Tuple, Union +from collections.abc import Iterable, Mapping, Sequence +from typing import Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import (IntermediateTensors, PoolerOutput, PoolingSequenceGroupOutput) @@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): return {"image": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - pass + return {"image": 0} class PrithviGeoSpatialMAEInputBuilder( @@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): location_coords=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: - pass - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - pass + ) -> Sequence[PromptUpdate]: + return [] def apply( self, @@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): """ Prithvi Masked Autoencoder""" - def _instantiate_model(self, config: dict) -> nn.Module | None: + def _instantiate_model(self, config: dict) -> Optional[nn.Module]: # We might be able/need to support different tasks with this same model if config["task_args"]["task"] == "SemanticSegmentationTask": @@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): "by PrithviGeospatialMAE.") def _parse_and_validate_multimodal_data( - self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]: + self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: pixel_values = kwargs.pop("pixel_values", None) if not isinstance(pixel_values, torch.Tensor): diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index f0dc8573ee14e..1c3107e76eb6a 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -21,9 +21,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict, - Union) +from typing import Any, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor( feature_attention_mask=MultiModalFieldConfig.batched("audio"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor( audio_tokens = [audio_token_id] * num_features - return PromptReplacementDetails( + return PromptUpdateDetails( full=[audio_bos_id] + audio_tokens + [audio_eos_id], features=audio_tokens, ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 849ef7293bb7f..cb92fcbe9fa1a 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -23,9 +23,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property, partial -from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set, - Tuple, Type, TypedDict, Union) +from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, + Union) import torch import torch.nn as nn @@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors @@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module): self, in_features: int, hidden_features: int, - act_layer: Type[nn.Module] = QuickGELU, + act_layer: type[nn.Module] = QuickGELU, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module): dim: int, num_heads: int, mlp_ratio: float, - act_layer: Type[nn.Module] = QuickGELU, + act_layer: type[nn.Module] = QuickGELU, norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] self.info._get_image_processor_kwargs(**mm_kwargs), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( **hf_processor_mm_kwargs) diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index e0d8bf2fa3d25..b8aaa7f1db1bf 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -9,9 +9,10 @@ import copy import math import re import unicodedata +from collections.abc import Collection, Mapping, Sequence +from collections.abc import Set as AbstractSet from functools import lru_cache, partial -from typing import (AbstractSet, Callable, Collection, List, Literal, Mapping, - Optional, TypedDict, Union) +from typing import Callable, List, Literal, Optional, TypedDict, Union import torch from torch import nn @@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): mm_kwargs=mm_kwargs, ) - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, @@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): image_embeds=MultiModalFieldConfig.batched("image"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, 
mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() special_tokens: dict[str, int] = tokenizer.special_tokens # type: ignore @@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): PromptReplacement( modality="image", target=[img_start_id, img_end_id], - replacement=PromptReplacementDetails( + replacement=PromptUpdateDetails( full=[img_start_id] + image_tokens + [img_end_id], features=image_tokens, ), diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b8d4aef252e5f..d47f924ea19e8 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,9 +3,9 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" import math +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.utils.checkpoint @@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor( audio_embeds=MultiModalFieldConfig.batched("audio"), ) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index a2eefbc6d8991..2da8c5c8b0e2e 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, - Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import List, Optional, Set, Tuple, TypedDict, Union import torch from torch import nn @@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, - PromptReplacement) + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .interfaces import SupportsMultiModal, SupportsTranscription @@ -623,12 +623,12 @@ class WhisperMultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: return dict(input_features=MultiModalFieldConfig.batched("audio")) - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, 
hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> Sequence[PromptUpdate]: num_tokens = self.info.get_max_audio_tokens() return [ PromptReplacement( diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 60b000e2b34ff..ac33af7c10c77 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -6,11 +6,14 @@ from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, Sequence) from dataclasses import dataclass, field +from enum import Enum from functools import lru_cache +from itertools import groupby from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, - TypeVar, Union) + TypeVar, Union, cast) from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from typing_extensions import assert_never import vllm.envs as envs from vllm.inputs import InputProcessingContext @@ -38,35 +41,129 @@ PromptSeq = Union[str, list[int]] @dataclass -class PromptReplacementDetails: - """Details about the replacement token sequence or text.""" +class PromptUpdateDetails: + """Details about the token sequence or text that are part of the update.""" full: PromptSeq - """The full replacement.""" + """The full content.""" features: PromptSeq """ - The part of the replacement that corresponds to feature placeholders; + The part of the content that corresponds to feature placeholders; this will be replaced by the output of the vision encoder during model inference. """ @staticmethod - def from_seq(seq: PromptSeq) -> "PromptReplacementDetails": - return PromptReplacementDetails(full=seq, features=seq) + def from_seq(seq: PromptSeq) -> "PromptUpdateDetails": + return PromptUpdateDetails(full=seq, features=seq) -PromptRepl = Union[PromptSeq, PromptReplacementDetails] +PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] """ -The replacement token sequence or text. +The token sequence or text that are part of the update. -If only part of the replacement corresponds to feature placeholders, you can -use :class:`PromptReplacementDetails` to specify which part. +If only part of the content corresponds to feature placeholders, you can +use :class:`PromptUpdateDetails` to specify which part. """ +PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], + PromptUpdateInfo] +""" +Given the index of the processed item within :attr:`modality`, +output the corresponding token sequence (or text). + +For convenience, you can directly pass in the token sequence (or text) +instead of a function if it does not depend on the input. +""" + + +class UpdateMode(str, Enum): + INSERT = "insert" + REPLACE = "replace" + @dataclass -class PromptReplacement: +class PromptUpdate: + """ + Defines how to update a prompt with placeholder tokens. + """ + + modality: str + """The modality for which the update is made.""" + + target: PromptSeq + """The token sequence (or text) to update.""" + + @property + @abstractmethod + def content(self) -> PromptUpdateContent: + """The placeholder tokens that are part of the update.""" + raise NotImplementedError + + @property + @abstractmethod + def mode(self) -> UpdateMode: + """Defines how to update the prompt.""" + raise NotImplementedError + + def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate": + return BoundPromptUpdate( + _origin=self, + tokenizer=tokenizer, + ) + + +@dataclass +class PromptInsertion(PromptUpdate): + """ + Defines how to insert placeholder tokens into a prompt. 
+ + Example: + + For each image, insert a number of ```` feature placeholders + equal to the feature size of the vision encoder at the start of the + prompt: + + .. code-block:: python + + PromptInsertion( + modality="image", + target="", + insertion="" * image_feature_size, + ) + + As above, but insert after the ```` token: + + .. code-block:: python + + PromptInsertion( + modality="image", + target="", + insertion="" * image_feature_size, + ) + """ + + insertion: PromptUpdateContent = field(repr=False) + """ + Given the index of the processed item within :attr:`modality`, + output the token sequence (or text) to insert right after :attr:`target`. + + For convenience, you can directly pass in the token sequence (or text) + instead of a function if it does not depend on the input. + """ + + @property + def content(self) -> PromptUpdateContent: + return self.insertion + + @property + def mode(self) -> UpdateMode: + return UpdateMode.INSERT + + +@dataclass +class PromptReplacement(PromptUpdate): """ Defines how to replace portions of an input prompt with placeholder tokens. @@ -93,7 +190,7 @@ class PromptReplacement: PromptReplacement( modality="image", target="", - replacement=PromptReplacementDetails( + replacement=PromptUpdateDetails( full="".join([ "", "" * image_feature_size, @@ -111,7 +208,7 @@ class PromptReplacement: PromptReplacement( modality="image", target=[image_token_id], - replacement=PromptReplacementDetails( + replacement=PromptUpdateDetails( full=([image_bos_id] + [image_token_id] * image_feature_size + [image_eos_id]), features=[image_token_id] * image_feature_size, @@ -119,29 +216,22 @@ class PromptReplacement: ) """ - modality: str - """The modality for which the replacement is made.""" - - target: PromptSeq - """The token sequence (or text) to find and replace.""" - - replacement: Union[Callable[[int], PromptRepl], - PromptRepl] = field(repr=False) + replacement: PromptUpdateContent = field(repr=False) """ Given the index of the processed item within :attr:`modality`, - output the replacement token sequence (or text). + output the token sequence (or text) to replace :attr:`target`. - For convenience, you can directly pass in the replacement token sequence - (or text) instead of a function if it does not depend on the input. + For convenience, you can directly pass in the token sequence (or text) + instead of a function if it does not depend on the input. """ - def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement": - return BoundPromptReplacement( - tokenizer=tokenizer, - modality=self.modality, - _target=self.target, - _replacement=self.replacement, - ) + @property + def content(self) -> PromptUpdateContent: + return self.replacement + + @property + def mode(self) -> UpdateMode: + return UpdateMode.REPLACE @lru_cache(maxsize=2048) @@ -232,64 +322,73 @@ class _BoundPromptSequence: @dataclass -class _BoundPromptReplacementGroup: +class _BoundPromptContent: full: _BoundPromptSequence features: _BoundPromptSequence @dataclass -class BoundPromptReplacement: +class BoundPromptUpdate: """ - A :class:`PromptReplacement` bound to a tokenizer to automatically - convert :attr:`target` and the result of :meth:`get_replacement` between + A :class:`PromptUpdate` bound to a tokenizer to automatically convert + :attr:`target` and the result of :meth:`get_content` between token sequence and text representations. 
""" + _origin: PromptUpdate tokenizer: AnyTokenizer = field(repr=False) - modality: str - - _target: PromptSeq - _replacement: Union[Callable[[int], PromptRepl], - PromptRepl] = field(repr=False) def __post_init__(self) -> None: - self._replacement_cache = dict[int, _BoundPromptReplacementGroup]() + self._content_cache = dict[int, _BoundPromptContent]() + + @property + def modality(self) -> str: + return self._origin.modality @property def target(self) -> _BoundPromptSequence: - """The token sequence (or text) to find and replace.""" - return _BoundPromptSequence.from_seq(self.tokenizer, self._target) + """The token sequence (or text) to update.""" + return _BoundPromptSequence.from_seq(self.tokenizer, + self._origin.target) - def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup: + @property + def content(self) -> PromptUpdateContent: + """The placeholder tokens that are part of the update.""" + return self._origin.content + + @property + def mode(self) -> UpdateMode: + """Defines how to update the prompt.""" + return self._origin.mode + + def get_content(self, item_idx: int) -> _BoundPromptContent: """ Given the index of the processed item within :attr:`modality`, - output the replacement token sequence (or text). + output the token sequence (or text) to update. """ - replacement = self._replacement - if callable(replacement): + content = self.content + if callable(content): cache_key = item_idx - if cache_key in self._replacement_cache: - return self._replacement_cache[cache_key] + if cache_key in self._content_cache: + return self._content_cache[cache_key] - replacement = replacement(item_idx) + content = content(item_idx) else: cache_key = None - if not isinstance(replacement, PromptReplacementDetails): - replacement = PromptReplacementDetails.from_seq(replacement) + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) bound_full = _BoundPromptSequence.from_seq(self.tokenizer, - replacement.full) + content.full) bound_features = _BoundPromptSequence.from_seq(self.tokenizer, - replacement.features) - bound_replacement = _BoundPromptReplacementGroup( - full=bound_full, - features=bound_features, - ) + content.features) + bound_content = _BoundPromptContent(full=bound_full, + features=bound_features) if cache_key is not None: - self._replacement_cache[cache_key] = bound_replacement + self._content_cache[cache_key] = bound_content - return bound_replacement + return bound_content class _TokenMatch(NamedTuple): @@ -326,12 +425,12 @@ def iter_token_matches( @dataclass(repr=False) -class _PromptReplacementMatch(ABC): - prompt_repl: BoundPromptReplacement +class _PromptTargetMatch(ABC): + _origin: BoundPromptUpdate @property def modality(self) -> str: - return self.prompt_repl.modality + return self._origin.modality @property @abstractmethod @@ -349,7 +448,7 @@ class _PromptReplacementMatch(ABC): @dataclass(repr=False) -class _PromptReplacementTokenMatch(_PromptReplacementMatch): +class _PromptTargetTokenMatch(_PromptTargetMatch): match: _TokenMatch @property @@ -362,7 +461,7 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch): @dataclass(repr=False) -class _PromptReplacementTextMatch(_PromptReplacementMatch): +class _PromptTargetTextMatch(_PromptTargetMatch): match: re.Match[str] @property @@ -394,40 +493,37 @@ class PlaceholderFeaturesInfo: def find_token_matches( prompt: list[int], - prompt_repls: Sequence[BoundPromptReplacement], -) -> list[_PromptReplacementTokenMatch]: - """Return each target of 
:code:`prompt_repls` found in :code:`prompt`.""" + prompt_updates: Sequence[BoundPromptUpdate], +) -> Sequence[_PromptTargetMatch]: + """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" return [ - _PromptReplacementTokenMatch(prompt_repl, match) - for prompt_repl in prompt_repls - for match in iter_token_matches(prompt, prompt_repl.target.token_ids) + _PromptTargetTokenMatch(update, match) for update in prompt_updates + for match in iter_token_matches(prompt, update.target.token_ids) ] def find_text_matches( prompt: str, - prompt_repls: Sequence[BoundPromptReplacement], -) -> list[_PromptReplacementTextMatch]: - """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" + prompt_updates: Sequence[BoundPromptUpdate], +) -> Sequence[_PromptTargetMatch]: + """Return each target of :code:`prompt_updates` found in :code:`prompt`.""" return [ - _PromptReplacementTextMatch(prompt_repl, match) - for prompt_repl in prompt_repls - for match in re.finditer(re.escape(prompt_repl.target.text), prompt) + _PromptTargetTextMatch(update, match) for update in prompt_updates + for match in re.finditer(re.escape(update.target.text), prompt) ] def _resolve_matches( prompt: PromptSeq, - mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]], -) -> list[_PromptReplacementMatch]: + mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], +) -> list[_PromptTargetMatch]: """ Resolve :code:`mm_matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. """ matches = [m for matches in mm_matches.values() for m in matches] - seen_matches: list[Optional[_PromptReplacementMatch]] = [None - ] * len(prompt) + seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt) for match in matches: for idx in range(match.start_idx, match.end_idx): @@ -441,74 +537,91 @@ def _resolve_matches( return sorted(matches, key=lambda x: x.start_idx) -def _replace_matches( +def _apply_matches( prompt: _S, - mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]], + mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> list[_S]: - """Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" - out_seqs = list[_S]() + """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" + out_seqs = list[Union[str, list[int]]]() prev_end_idx = 0 next_idx_by_modality = defaultdict[str, int](lambda: 0) - for match in _resolve_matches(prompt, mm_matches): - modality = match.modality + for (start_idx, end_idx), group in groupby( + _resolve_matches(prompt, mm_matches), + key=lambda x: (x.start_idx, x.end_idx), + ): + matches = tuple(group) + assert len(matches) == 1 - item_idx = next_idx_by_modality[modality] - if item_idx >= mm_item_counts.get(modality, 0): - continue + for match in matches: + modality = match.modality - start_idx = match.start_idx - end_idx = match.end_idx + item_idx = next_idx_by_modality[modality] + if item_idx >= mm_item_counts.get(modality, 0): + continue - repl_info = match.prompt_repl - replacement = repl_info.get_replacement(item_idx) + origin = match._origin + content = origin.get_content(item_idx) + mode = origin.mode - if isinstance(prompt, str): - repl_seq = replacement.full.text - out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) - else: - repl_seq = replacement.full.token_ids - out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) + if mode == UpdateMode.INSERT: + out_seqs.append(prompt[prev_end_idx:end_idx]) + num_inserts = 
mm_item_counts.get(modality, 0) + elif mode == UpdateMode.REPLACE: + out_seqs.append(prompt[prev_end_idx:start_idx]) + num_inserts = 1 + else: + assert_never(mode) - prev_end_idx = end_idx - next_idx_by_modality[modality] += 1 + for _ in range(num_inserts): + if item_idx >= mm_item_counts.get(modality, 0): + continue + + if isinstance(prompt, str): + out_seqs.append(content.full.text) + else: + out_seqs.append(content.full.token_ids) + + next_idx_by_modality[modality] += 1 + + prev_end_idx = end_idx out_seqs.append(prompt[prev_end_idx:]) - return out_seqs + return cast(list[_S], out_seqs) -def replace_token_matches( +def apply_token_matches( prompt: list[int], - mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]], + mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> list[int]: - """Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" + """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" if not mm_matches: return prompt - token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts) + token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts) return flatten_2d_lists(token_id_seqs) -def replace_text_matches( +def apply_text_matches( prompt: str, - mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]], + mm_matches: Mapping[str, Sequence[_PromptTargetMatch]], mm_item_counts: Mapping[str, int], ) -> str: - """Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" + """Apply the updates in :code:`mm_matches` to :code:`prompt`.""" if not mm_matches: return prompt - texts = _replace_matches(prompt, mm_matches, mm_item_counts) + texts = _apply_matches(prompt, mm_matches, mm_item_counts) return "".join(texts) def _iter_placeholders( - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], mm_item_counts: Mapping[str, int], ) -> Iterable[PlaceholderFeaturesInfo]: @@ -517,7 +630,7 @@ def _iter_placeholders( Matches are exclusive even when multiple modalities share the same placeholder tokens. In that case, the modality that - appears earlier in `mm_prompt_repls` takes priority. + appears earlier in `mm_prompt_updates` takes priority. Note that empty matches are ignored. 
""" @@ -528,37 +641,37 @@ def _iter_placeholders( while start_idx < prompt_len: found = False - for modality, modality_repls in mm_prompt_repls.items(): + for modality, modality_updates in mm_prompt_updates.items(): item_idx = item_idx_by_modality[modality] if item_idx >= mm_item_counts.get(modality, 0): continue - for repl_info in modality_repls: - replacement = repl_info.get_replacement(item_idx) - repl_tokens_full = replacement.full.token_ids - repl_len_full = len(repl_tokens_full) - end_idx_full = start_idx + repl_len_full + for update_info in modality_updates: + content = update_info.get_content(item_idx) + content_tokens_full = content.full.token_ids + content_len_full = len(content_tokens_full) + end_idx_full = start_idx + content_len_full - if repl_len_full == 0 or end_idx_full > prompt_len: + if content_len_full == 0 or end_idx_full > prompt_len: continue - if prompt[start_idx:end_idx_full] == repl_tokens_full: - repl_tokens_feat = replacement.features.token_ids + if prompt[start_idx:end_idx_full] == content_tokens_full: + content_tokens_feat = content.features.token_ids try: match = next( - iter_token_matches(repl_tokens_full, - repl_tokens_feat)) + iter_token_matches(content_tokens_full, + content_tokens_feat)) yield PlaceholderFeaturesInfo( modality=modality, item_idx=item_idx, start_idx=start_idx + match.start_idx, - tokens=repl_tokens_feat, + tokens=content_tokens_feat, ) except StopIteration: raise AssertionError( - f"{repl_tokens_feat=} should be a " - f"subsequence of {repl_tokens_full=}") from None + f"{content_tokens_feat=} should be a " + f"subsequence of {content_tokens_full=}") from None # Exclude overlapping matches start_idx = end_idx_full @@ -574,11 +687,11 @@ def _iter_placeholders( def find_mm_placeholders( - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], mm_item_counts: Mapping[str, int], ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts) + it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts) return dict(full_groupby_modality(it)) @@ -712,6 +825,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): *, cache: Optional[ProcessingCache] = None, enable_sanity_checks: bool = True) -> None: + if get_repls := getattr(self, "_get_prompt_replacements", None): + logger.warning_once("`_get_prompt_replacements` has been renamed " + "to `_get_prompt_updates`. The old name will " + "be removed in an upcoming release.") + self._get_prompt_updates = get_repls # type: ignore[method-assign] + super().__init__() self.info = info @@ -770,34 +889,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): raise NotImplementedError @abstractmethod - def _get_prompt_replacements( + def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: + ) -> list[PromptUpdate]: """ Given the original multi-modal items for this modality - and HF-processed data, output the replacements to perform. + and HF-processed data, output the updates to perform. Notes: - You should not assume that HF processor always performs prompt - replacement: in :meth:`_apply_hf_processor_missing`, this method + updates: in :meth:`_apply_hf_processor_missing`, this method is called on text-only and multimodal-only inputs separately, instead of passing them in the same call. 
- - The replacement information returned by this method is also used - to determine the placeholder token positions for each multi-modal + - The update information returned by this method is also used to + determine the placeholder token positions for each multi-modal item. """ raise NotImplementedError def _find_mm_placeholders( self, - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], mm_item_counts: Mapping[str, int], ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - return find_mm_placeholders(mm_prompt_repls, new_token_ids, + return find_mm_placeholders(mm_prompt_updates, new_token_ids, mm_item_counts) def _get_hf_mm_data( @@ -831,14 +950,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_kwargs, ) - def _hf_processor_applies_repl( + def _hf_processor_applies_updates( self, prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], ) -> bool: """ - Return whether the HF processor applies prompt replacements. + Return whether the HF processor applies prompt updates. For most HF processors, this should be :code:`True` when multi-modal data items are passed, but :code:`False` when multi-modal embeddings @@ -858,7 +977,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Apply the HF processor on the prompt text and multi-modal data together. - In addition, return whether prompt replacements have been applied. + In addition, return whether prompt updates have been applied. """ processor_data, passthrough_data = self._get_hf_mm_data(mm_items) @@ -876,13 +995,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), ) - is_repl_applied = self._hf_processor_applies_repl( + is_update_applied = self._hf_processor_applies_updates( prompt_text=prompt_text, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, ) - return prompt_ids, mm_kwargs, is_repl_applied + return prompt_ids, mm_kwargs, is_update_applied def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: """ @@ -948,21 +1067,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], *, - enable_hf_prompt_replacement: bool, + enable_hf_prompt_update: bool, ) -> tuple[list[int], MultiModalKwargs, bool]: """ Apply the HF processor on the prompt text and multi-modal data. - In addition, return whether prompt replacements have been applied + In addition, return whether prompt updates have been applied (for most HF processors, this should be :code:`True`). Note: - If :code:`enable_hf_prompt_replacement=False`, we use HF processor - to perform prompt replacement if available; HF processor requires + If :code:`enable_hf_prompt_update=False`, we use HF processor + to perform prompt updates if available; HF processor requires that the prompt corresponds to multi-modal items. 
""" if isinstance(prompt, str): - if enable_hf_prompt_replacement: + if enable_hf_prompt_update: return self._apply_hf_processor_text_mm( prompt_text=prompt, mm_items=mm_items, @@ -999,7 +1118,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt=prompt, mm_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, - enable_hf_prompt_replacement=True, + enable_hf_prompt_update=True, ) mm_maybe_cached_kw_items = { @@ -1022,17 +1141,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_missing_data_items = self._to_mm_items(mm_missing_data) # NOTE: `prompt` does not correspond to `mm_missing_data_items`, - # so we can't apply prompt replacements until the new multimodal + # so we can't apply prompt updates until the new multimodal # items are combined with the cached multimodal items ( prompt_ids, mm_missing_kwargs, - is_repl_applied, + is_update_applied, ) = self._apply_hf_processor_main( prompt=prompt, mm_items=mm_missing_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, - enable_hf_prompt_replacement=False, + enable_hf_prompt_update=False, ) mm_missing_next_idx = { @@ -1071,28 +1190,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) - return prompt_ids, mm_kwargs, is_repl_applied + return prompt_ids, mm_kwargs, is_update_applied - def _bind_and_group_repls( + def _bind_and_group_updates( self, - prompt_repls: list[PromptReplacement], - ) -> dict[str, list[BoundPromptReplacement]]: + prompt_updates: list[PromptUpdate], + ) -> dict[str, list[BoundPromptUpdate]]: tokenizer = self.info.get_tokenizer() - it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) + it = (update.bind(tokenizer) for update in prompt_updates) return dict(full_groupby_modality(it)) - def _apply_prompt_replacements( + def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], mm_item_counts: Mapping[str, int], ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: tokenizer = self.info.get_tokenizer() mm_token_matches = { - modality: find_token_matches(token_ids, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() + modality: find_token_matches(token_ids, updates) + for modality, updates in mm_prompt_updates.items() } mm_match_counts = { modality: len(matches) @@ -1107,31 +1226,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # up a token, then the token ID of "foo" will not appear at all # ---- # Since it is inefficient to search for all possible tokenizations - # of the search text in the prompt, we instead perform string - # replacement on the decoded token IDs, then encode them back. + # of the search text in the prompt, we instead perform string-based + # updates on the decoded token IDs, then encode them back. 
if all( mm_match_counts.get(modality, 0) >= item_count for modality, item_count in mm_item_counts.items() ): # yapf: disable - token_ids = replace_token_matches( + token_ids = apply_token_matches( token_ids, mm_token_matches, mm_item_counts, ) text = decode_tokens(tokenizer, token_ids) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] + matched_updates = { + modality: [match._origin for match in token_matches] for modality, token_matches in mm_token_matches.items() } else: text = decode_tokens(tokenizer, token_ids) mm_text_matches = { - modality: find_text_matches(text, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() + modality: find_text_matches(text, updates) + for modality, updates in mm_prompt_updates.items() } - text = replace_text_matches( + text = apply_text_matches( text, mm_text_matches, mm_item_counts, @@ -1140,13 +1259,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): token_ids = encode_tokens(tokenizer, text, add_special_tokens=False) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] + matched_updates = { + modality: [match._origin for match in token_matches] for modality, token_matches in mm_text_matches.items() } placeholders = self._find_mm_placeholders( - matched_repls, + matched_updates, token_ids, mm_item_counts, ) @@ -1184,14 +1303,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): if len(placeholders) != item_count: raise RuntimeError( - f"Expected there to be {item_count} prompt replacements " + f"Expected there to be {item_count} prompt updates " f"corresponding to {item_count} {modality} items, but " - f"instead found {len(placeholders)} prompt replacements! " + f"instead found {len(placeholders)} prompt updates! " "Either the prompt text has missing/incorrect tokens for " "multi-modal inputs, or there is a problem with your " "implementation of merged multi-modal processor for this " "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_prompt_replacements`).") + "`_call_hf_processor` and `_get_prompt_updates`).") def apply( self, @@ -1206,7 +1325,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): 1. Apply HF Processor on prompt text and multi-modal data together, outputting token IDs and processed tensors. - 2. Find and replace sequences in the token IDs with placeholder tokens. + 2. Find and update sequences in the token IDs with placeholder tokens. The number of placeholder tokens equals the feature size of the multi-modal data outputted by the multi-modal encoder. 3. 
Extract information about the placeholder tokens from the @@ -1235,26 +1354,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ( prompt_ids, mm_kwargs, - is_repl_applied, + is_update_applied, ) = self._cached_apply_hf_processor( prompt, mm_items, hf_processor_mm_kwargs, ) - unbound_prompt_repls = self._get_prompt_replacements( + unbound_prompt_updates = self._get_prompt_updates( mm_items, hf_processor_mm_kwargs, mm_kwargs, ) - mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls) + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates) mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - if is_repl_applied: + if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_repls, + mm_prompt_updates, prompt_ids, mm_item_counts, ) @@ -1267,9 +1387,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_ids, prompt, mm_placeholders, - ) = self._apply_prompt_replacements( + ) = self._apply_prompt_updates( prompt_ids, - mm_prompt_repls, + mm_prompt_updates, mm_item_counts, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
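
Note (illustration only, not part of the patch): every processor converted above still returns `PromptReplacement`, so nothing in this diff exercises the new `PromptInsertion` path. The sketch below shows how a processor could opt into it, using only names introduced or shown in this patch (`PromptInsertion`, `PromptUpdateDetails`, `BaseMultiModalProcessor`). The processor class name, the `audio_token_index` config attribute, and the `get_num_audio_tokens` helper are hypothetical placeholders for this example, not real vLLM APIs.

```python
from collections.abc import Mapping, Sequence

from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        PromptInsertion, PromptUpdate,
                                        PromptUpdateDetails)


class MyAudioMultiModalProcessor(BaseMultiModalProcessor):
    # Hypothetical processor: only _get_prompt_updates is sketched here.

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        # `audio_token_index` is assumed to exist on this hypothetical config
        audio_token_id = hf_config.audio_token_index

        def get_insertion(item_idx: int) -> PromptUpdateDetails:
            # `get_num_audio_tokens` is a hypothetical helper returning the
            # feature size of the audio encoder for the given item
            num_tokens = self.info.get_num_audio_tokens(item_idx)
            return PromptUpdateDetails.from_seq([audio_token_id] * num_tokens)

        return [
            # INSERT mode: the placeholders are added right after `target`
            # (here, the BOS token) rather than replacing it
            PromptInsertion(
                modality="audio",
                target=[bos_token_id],
                insertion=get_insertion,
            ),
        ]
```

Whether insertion or replacement happens is decided by `PromptUpdate.mode`: in `_apply_matches`, `UpdateMode.INSERT` keeps the matched target in the output and appends the content after it, while `UpdateMode.REPLACE` drops the target and substitutes the content in its place.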