[Doc] [1/N] Initial guide for merged multi-modal processor (#11925)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-01-10 22:30:25 +08:00 committed by GitHub
parent 241ad7b301
commit 12664ddda5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 433 additions and 198 deletions

View File

@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
myst-parser==3.0.1
sphinx-argparse==0.4.0
sphinx-design==0.6.1
sphinx-togglebutton==0.3.2
msgspec
cloudpickle

View File

@ -7,7 +7,7 @@ vLLM provides experimental support for multi-modal models through the {mod}`vllm
Multi-modal inputs can be passed alongside text and token prompts to [supported models](#supported-mm-models)
via the `multi_modal_data` field in {class}`vllm.inputs.PromptType`.
Looking to add your own multi-modal model? Please follow the instructions listed [here](#enabling-multimodal-inputs).
Looking to add your own multi-modal model? Please follow the instructions listed [here](#supports-multimodal).
## Module Contents

View File

@ -3,7 +3,7 @@
## User-facing inputs
```{eval-rst}
.. autodata:: vllm.multimodal.MultiModalDataDict
.. autodata:: vllm.multimodal.inputs.MultiModalDataDict
```
## Internal data structures

View File

@ -43,6 +43,7 @@ extensions = [
"sphinx.ext.autosummary",
"myst_parser",
"sphinxarg.ext",
"sphinx_design",
"sphinx_togglebutton",
]
myst_enable_extensions = [

View File

@ -2,7 +2,7 @@
# Adding a New Model
This section provides more information on how to integrate a [HuggingFace Transformers](https://github.com/huggingface/transformers) model into vLLM.
This section provides more information on how to integrate a [PyTorch](https://pytorch.org/) model into vLLM.
```{toctree}
:caption: Contents

View File

@ -1,6 +1,6 @@
(enabling-multimodal-inputs)=
(supports-multimodal)=
# Enabling Multimodal Inputs
# Multi-Modal Support
This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](#multimodal-inputs).
@ -37,103 +37,355 @@ Further update the model as follows:
) -> SamplerOutput:
```
## 2. Register input mappers
## 2. Specify processing information
For each modality type that the model accepts as input, decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in {meth}`~torch.nn.Module.forward`.
Next, create a subclass of {class}`~vllm.multimodal.processing.BaseProcessingInfo`
to provide basic information related to HF processing.
### Maximum number of input items
You need to override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_supported_mm_limits`
to return the maximum number of input items for each modality supported by the model.
For example, if the model supports any number of images but only one video per prompt:
```python
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": 1}
```
### Maximum number of placeholder feature tokens
Also, override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item`
to return the maximum number of placeholder feature tokens per input item for each modality.
When calling the model, the output embeddings from the visual encoder are assigned to the input positions
containing placeholder feature tokens. Therefore, the number of placeholder feature tokens should be equal
to the size of the output embeddings.
::::{tab-set}
:::{tab-item} Basic example: LLaVA
:sync: llava
Looking at the code of HF's `LlavaForConditionalGeneration`:
```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
```
The number of placeholder feature tokens per image is `image_features.shape[1]`.
`image_features` is calculated inside the `get_image_features` method:
```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
```
We can infer that `image_features.shape[1]` is based on `image_outputs.hidden_states.shape[1]` from the vision tower
(`CLIPVisionModel` for the [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model).
Moreover, we only need the sequence length (the second dimension of the tensor) to get `image_features.shape[1]`.
The sequence length is determined by the initial hidden states in `CLIPVisionTransformer` since the attention
mechanism doesn't change the sequence length of the output hidden states.
```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L1094-L1102
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
```
To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`:
```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
```
We can infer that `embeddings.shape[1] == self.num_positions`, where
```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L195-L196
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
```
Overall, the number of placeholder feature tokens for an image can be calculated as:
```python
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
hf_processor = self.get_hf_processor()
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
num_image_tokens = (image_size // patch_size) ** 2 + 1
if hf_processor.vision_feature_select_strategy == "default":
num_image_tokens -= 1
return num_image_tokens
```
Notice that the number of image tokens doesn't depend on the image width and height.
So, we can calculate the maximum number of image tokens using any image size:
```python
def get_image_size_with_most_features(self) -> ImageSize:
hf_config = self.get_hf_config()
width = height = hf_config.image_size
return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
```
And thus, we can override the method as:
```python
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
```
```{note}
Our [actual code](gh-file:vllm/model_executor/models/llava.py) is more abstracted to support vision encoders other than CLIP.
```
:::
::::
## 3. Specify dummy inputs
Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for
HF processing as well as memory profiling.
### For memory profiling
Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`
to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of
the model so that vLLM can reserve the correct amount of memory for it.
Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed based
on the code for {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item`.
::::{tab-set}
:::{tab-item} Basic example: LLaVA
:sync: llava
Making use of the `get_image_size_with_most_features` method implemented in the previous section:
```python
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
hf_config = self.get_hf_config()
target_width, target_height = self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
```
:::
::::
## 4. Specify processing details
Afterwards, create a subclass of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`
to fill in the missing details about HF processing.
```{seealso}
[Multi-Modal Data Processing](#mm-processing)
```
### Multi-modal fields
Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` to
return a schema of the tensors outputted by the HF processor that are related to the input multi-modal items.
::::{tab-set}
:::{tab-item} Basic example: LLaVA
:sync: llava
Looking at the model's `forward` method:
```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L387-L404
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
```
The only related keyword argument is `pixel_values` which directly corresponds to input images.
The shape of `pixel_values` is `(N, C, H, W)` where `N` is the number of images.
So, we override the method as follows:
```python
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
)
```
```{note}
Our [actual code](gh-file:vllm/model_executor/models/llava.py) additionally supports
pre-computed image embeddings, which can be passed to be model via the `image_embeds` argument.
```
:::
::::
### Prompt replacements
Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to
return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances.
Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace
operation performed by the HF processor.
::::{tab-set}
:::{tab-item} Basic example: LLaVA
:sync: llava
Looking at HF's `LlavaProcessor`:
```python
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/processing_llava.py#L167-L170
prompt_strings = []
for sample in text:
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
prompt_strings.append(sample)
```
It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
Based on this, we override the method as follows:
```python
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
def get_replacement(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
return [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement,
),
]
```
:::
::::
## 5. Register processor-related classes
After you have defined {class}`~vllm.multimodal.processing.BaseProcessingInfo` (Step 2),
{class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` (Step 3),
and {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` (Step 4),
decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_processor <vllm.multimodal.registry.MultiModalRegistry.register_processor>`
to register them to the multi-modal registry:
```diff
from vllm.model_executor.models.interfaces import SupportsMultiModal
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_processor(YourMultiModalProcessor,
+ info=YourProcessingInfo,
+ dummy_inputs=YourDummyInputsBuilder)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```
## 3. Register maximum number of multi-modal tokens
For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
and register it via {meth}`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.
```diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
Here are some examples:
- Image inputs (static feature size): [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py)
- Image inputs (dynamic feature size): [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py)
```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```
## 4. (Optional) Register dummy data
During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via {meth}`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
```diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
```{note}
The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
```
Here are some examples:
- Image inputs (static feature size): [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py)
- Image inputs (dynamic feature size): [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py)
```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```
## 5. (Optional) Register input processor
Sometimes, there is a need to process inputs at the {class}`~vllm.LLMEngine` level before they are passed to the model executor.
This is often due to the fact that unlike implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside model's {meth}`~torch.nn.Module.forward` call.
You can register input processors via {meth}`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.
```diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:
- Insert static number of image tokens: [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py)
- Insert dynamic number of image tokens: [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py)
```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```

View File

@ -48,7 +48,7 @@ ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCaus
```{important}
If your model is a multimodal model, ensure the model class implements the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
Read more about that [here](#enabling-multimodal-inputs).
Read more about that [here](#supports-multimodal).
```
```{note}

View File

@ -1,19 +0,0 @@
(input-processing-pipeline)=
# Input Processing Pipeline
1. Input data is passed to {class}`~vllm.LLMEngine` (or {class}`~vllm.AsyncLLMEngine`).
2. Tokenize the data if necessary.
3. Process the inputs using {meth}`INPUT_REGISTRY.process_input <vllm.inputs.registry.InputRegistry.process_input>`.
- For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.
4. Send the processed inputs to {class}`~vllm.executor.executor_base.ExecutorBase`.
5. Distribute the inputs via {class}`~vllm.worker.worker_base.WorkerBase` to {class}`~vllm.worker.model_runner_base.ModelRunnerBase`.
6. If the data contains multi-modal data, convert it into keyword arguments using {meth}`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.
- For example, convert a {class}`PIL.Image.Image` input to its pixel values for a vision model.

View File

@ -1,43 +0,0 @@
(input-processing)=
# Input Processing
```{eval-rst}
.. currentmodule:: vllm.inputs
```
Each model can override parts of vLLM's [input processing pipeline](#input-processing-pipeline) via
{data}`~vllm.inputs.INPUT_REGISTRY` and {data}`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Currently, this mechanism is only utilized in [multi-modal](#multi-modality) models for preprocessing multi-modal input
data in addition to input prompt, but it can be extended to text-only language models when needed.
## Guides
```{toctree}
:maxdepth: 1
input_processing_pipeline
```
## Module Contents
### LLM Engine Inputs
```{eval-rst}
.. autoclass:: vllm.inputs.DecoderOnlyInputs
:members:
:show-inheritance:
```
### Registry
```{eval-rst}
.. autodata:: vllm.inputs.INPUT_REGISTRY
```
```{eval-rst}
.. automodule:: vllm.inputs.registry
:members:
:show-inheritance:
```

View File

@ -0,0 +1,64 @@
(mm-processing)=
# Multi-Modal Data Processing
To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefill) and [prefix caching](#automatic-prefix-caching), we use {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` to provide the correspondence between placeholder feature tokens (e.g. `<image>`) and multi-modal inputs (e.g. the raw input image) based on the outputs of HF processor.
Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`:
## Prompt Replacement Detection
One of the main responsibilies of HF processor is to replace input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt.
## Tokenized Prompt Inputs
To enable tokenization in a separate process, we support passing input token IDs alongside multi-modal data.
### The problem
Consider that HF processors follow these main steps:
1. Tokenize the text
2. Process multi-modal inputs
3. Perform prompt replacement
And we require that:
- For text + multi-modal inputs, apply all steps 1--3.
- For tokenized + multi-modal inputs, apply only steps 2--3.
How can we achieve this without rewriting HF processors? We can try to call the HF processor several times on different inputs:
- For text + multi-modal inputs, simply call the HF processor directly.
- For tokenized + multi-modal inputs, call the processor only on the multi-modal inputs.
While HF processors support text + multi-modal inputs natively, this is not so for tokenized + multi-modal inputs: an error is thrown if the number of input placeholder tokens do not correspond to the number of multi-modal inputs.
Moreover, since the tokenized text has not passed through the HF processor, we have to apply Step 3 by ourselves to keep the output tokens and multi-modal data consistent with each other.
(mm-dummy-text)=
### Dummy text
We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data.
(mm-automatic-prompt-replacement)=
### Automatic prompt replacement
We address the second issue by implementing model-agnostic code in
{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`.
### Summary
With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
## Processor Output Caching
Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238). To alleviate this problem, we cache the multi-modal outputs of HF processor to avoid processing the same multi-modal input (e.g. image) again.
When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache.
Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other.

View File

@ -145,7 +145,7 @@ design/arch_overview
design/huggingface_integration
design/plugin_system
design/kernel/paged_attention
design/input_processing/model_inputs_index
design/mm_processing
design/automatic_prefix_caching
design/multiprocessing
```

View File

@ -14,7 +14,7 @@ and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/ch
To input multi-modal data, follow this schema in {class}`vllm.inputs.PromptType`:
- `prompt`: The prompt should follow the format that is documented on HuggingFace.
- `multi_modal_data`: This is a dictionary that follows the schema defined in {class}`vllm.multimodal.MultiModalDataDict`.
- `multi_modal_data`: This is a dictionary that follows the schema defined in {class}`vllm.multimodal.inputs.MultiModalDataDict`.
### Image

View File

@ -2124,8 +2124,7 @@ class MultiModalConfig:
limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
"""
The maximum number of multi-modal input instances allowed per prompt
for each :class:`~vllm.multimodal.MultiModalPlugin`.
The maximum number of input items allowed per prompt for each modality.
"""
def compute_hash(self) -> str:

View File

@ -11,9 +11,6 @@ INPUT_REGISTRY = InputRegistry()
"""
The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
to dispatch data processing according to the target model.
See also:
:ref:`input-processing-pipeline`
"""
__all__ = [

View File

@ -313,9 +313,6 @@ class InputRegistry:
The model is identified by ``model_config``.
See also:
:ref:`enabling-multimodal-inputs`
Note:
This should be called after
:meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
@ -384,10 +381,8 @@ class InputRegistry:
Register an input processor to a model class.
The provided function is invoked on each input to the model. This
happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.
See also:
:ref:`input-processing-pipeline`
happens before
:meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`.
"""
def wrapper(model_cls: N) -> N:
@ -429,9 +424,6 @@ class InputRegistry:
Apply an input processor to an instance of model inputs.
The model is identified by ``model_config``.
See also:
:ref:`input-processing-pipeline`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture

View File

@ -8,10 +8,10 @@ from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global :class:`~MultiModalRegistry` is used by model runners to
dispatch data processing according to its modality and the target model.
dispatch data processing according to the target model.
See also:
:ref:`input-processing-pipeline`
:ref:`mm-processing`
"""
__all__ = [

View File

@ -90,10 +90,6 @@ class MultiModalPlugin(ABC):
invoked to transform the data into a dictionary of model inputs.
If `None` is provided, then the default input mapper is used instead.
See also:
- :ref:`input-processing-pipeline`
- :ref:`enabling-multimodal-inputs`
"""
def wrapper(model_cls: N) -> N:
@ -126,10 +122,6 @@ class MultiModalPlugin(ABC):
Raises:
TypeError: If the data type is not supported.
See also:
- :ref:`input-processing-pipeline`
- :ref:`enabling-multimodal-inputs`
"""
# Avoid circular import
@ -186,9 +178,6 @@ class MultiModalPlugin(ABC):
for a model class.
If `None` is provided, then the default calculation is used instead.
See also:
:ref:`enabling-multimodal-inputs`
"""
def wrapper(model_cls: N) -> N:
@ -218,9 +207,6 @@ class MultiModalPlugin(ABC):
If this registry is not applicable to the model, `0` is returned.
The model is identified by ``model_config``.
See also:
:ref:`enabling-multimodal-inputs`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture

View File

@ -493,7 +493,8 @@ A dictionary containing placeholder ranges for each modality.
class MultiModalInputsV2(TypedDict):
"""
Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`,
Represents the outputs of
:class:`vllm.multimodal.processing.BaseMultiModalProcessor`,
ready to be passed to vLLM internals.
"""

View File

@ -100,8 +100,7 @@ class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
class MultiModalRegistry:
"""
A registry that dispatches data processing to the
:class:`~vllm.multimodal.MultiModalPlugin` for each modality.
A registry that dispatches data processing according to the model.
"""
DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
@ -367,8 +366,7 @@ class MultiModalRegistry:
invoked to transform the data into a dictionary of model inputs.
See also:
- :ref:`input-processing-pipeline`
- :ref:`enabling-multimodal-inputs`
:ref:`mm-processing`
"""
def wrapper(model_cls: N) -> N:
@ -398,6 +396,9 @@ class MultiModalRegistry:
def has_processor(self, model_config: "ModelConfig") -> bool:
"""
Test whether a multi-modal processor is defined for a specific model.
See also:
:ref:`mm-processing`
"""
return self._get_model_cls(model_config) in self._processor_factories
@ -408,6 +409,9 @@ class MultiModalRegistry:
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
See also:
:ref:`mm-processing`
"""
model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]