[Doc] [1/N] Initial guide for merged multi-modal processor (#11925)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent 241ad7b301
commit 12664ddda5
@@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1
 sphinx-copybutton==0.5.2
 myst-parser==3.0.1
 sphinx-argparse==0.4.0
+sphinx-design==0.6.1
 sphinx-togglebutton==0.3.2
 msgspec
 cloudpickle
@@ -7,7 +7,7 @@ vLLM provides experimental support for multi-modal models through the {mod}`vllm
 Multi-modal inputs can be passed alongside text and token prompts to [supported models](#supported-mm-models)
 via the `multi_modal_data` field in {class}`vllm.inputs.PromptType`.
 
-Looking to add your own multi-modal model? Please follow the instructions listed [here](#enabling-multimodal-inputs).
+Looking to add your own multi-modal model? Please follow the instructions listed [here](#supports-multimodal).
 
 ## Module Contents
 
@@ -3,7 +3,7 @@
 ## User-facing inputs
 
 ```{eval-rst}
-.. autodata:: vllm.multimodal.MultiModalDataDict
+.. autodata:: vllm.multimodal.inputs.MultiModalDataDict
 ```
 
 ## Internal data structures
@@ -43,6 +43,7 @@ extensions = [
     "sphinx.ext.autosummary",
     "myst_parser",
     "sphinxarg.ext",
+    "sphinx_design",
     "sphinx_togglebutton",
 ]
 myst_enable_extensions = [
@@ -2,7 +2,7 @@
 
 # Adding a New Model
 
-This section provides more information on how to integrate a [HuggingFace Transformers](https://github.com/huggingface/transformers) model into vLLM.
+This section provides more information on how to integrate a [PyTorch](https://pytorch.org/) model into vLLM.
 
 ```{toctree}
 :caption: Contents
@@ -1,6 +1,6 @@
-(enabling-multimodal-inputs)=
+(supports-multimodal)=
 
-# Enabling Multimodal Inputs
+# Multi-Modal Support
 
 This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](#multimodal-inputs).
 
@@ -37,103 +37,355 @@ Further update the model as follows:
 ) -> SamplerOutput:
 ```
 
-## 2. Register input mappers
+## 2. Specify processing information
 
-For each modality type that the model accepts as input, decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
-This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in {meth}`~torch.nn.Module.forward`.
+Next, create a subclass of {class}`~vllm.multimodal.processing.BaseProcessingInfo`
+to provide basic information related to HF processing.
 
+### Maximum number of input items
+
+You need to override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_supported_mm_limits`
+to return the maximum number of input items for each modality supported by the model.
+
+For example, if the model supports any number of images but only one video per prompt:
+
+```python
+def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+    return {"image": None, "video": 1}
+```
+
+### Maximum number of placeholder feature tokens
+
+Also, override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item`
+to return the maximum number of placeholder feature tokens per input item for each modality.
+
+When calling the model, the output embeddings from the visual encoder are assigned to the input positions
+containing placeholder feature tokens. Therefore, the number of placeholder feature tokens should be equal
+to the size of the output embeddings.
+
+::::{tab-set}
+:::{tab-item} Basic example: LLaVA
+:sync: llava
+
+Looking at the code of HF's `LlavaForConditionalGeneration`:
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544
+n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+n_image_features = image_features.shape[0] * image_features.shape[1]
+
+if n_image_tokens != n_image_features:
+    raise ValueError(
+        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+    )
+special_image_mask = (
+    (input_ids == self.config.image_token_index)
+    .unsqueeze(-1)
+    .expand_as(inputs_embeds)
+    .to(inputs_embeds.device)
+)
+image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+```
+
+The number of placeholder feature tokens per image is `image_features.shape[1]`.
+`image_features` is calculated inside the `get_image_features` method:
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300
+image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+
+selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+if vision_feature_select_strategy == "default":
+    selected_image_feature = selected_image_feature[:, 1:]
+elif vision_feature_select_strategy == "full":
+    selected_image_feature = selected_image_feature
+else:
+    raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+image_features = self.multi_modal_projector(selected_image_feature)
+return image_features
+```
+
+We can infer that `image_features.shape[1]` is based on `image_outputs.hidden_states.shape[1]` from the vision tower
+(`CLIPVisionModel` for the [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model).
+Moreover, we only need the sequence length (the second dimension of the tensor) to get `image_features.shape[1]`.
+The sequence length is determined by the initial hidden states in `CLIPVisionTransformer` since the attention
+mechanism doesn't change the sequence length of the output hidden states.
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L1094-L1102
+hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+hidden_states = self.pre_layrnorm(hidden_states)
+
+encoder_outputs = self.encoder(
+    inputs_embeds=hidden_states,
+    output_attentions=output_attentions,
+    output_hidden_states=output_hidden_states,
+    return_dict=return_dict,
+)
+```
+
+To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`:
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257
+target_dtype = self.patch_embedding.weight.dtype
+patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
+patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+if interpolate_pos_encoding:
+    embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+else:
+    embeddings = embeddings + self.position_embedding(self.position_ids)
+return embeddings
+```
+
+We can infer that `embeddings.shape[1] == self.num_positions`, where
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L195-L196
+self.num_patches = (self.image_size // self.patch_size) ** 2
+self.num_positions = self.num_patches + 1
+```
+
+Overall, the number of placeholder feature tokens for an image can be calculated as:
+
+```python
+def get_num_image_tokens(
+    self,
+    *,
+    image_width: int,
+    image_height: int,
+) -> int:
+    hf_config = self.get_hf_config()
+    hf_processor = self.get_hf_processor()
+
+    image_size = hf_config.vision_config.image_size
+    patch_size = hf_config.vision_config.patch_size
+
+    num_image_tokens = (image_size // patch_size) ** 2 + 1
+    if hf_processor.vision_feature_select_strategy == "default":
+        num_image_tokens -= 1
+
+    return num_image_tokens
+```
+
+Notice that the number of image tokens doesn't depend on the image width and height.
+So, we can calculate the maximum number of image tokens using any image size:
+
+```python
+def get_image_size_with_most_features(self) -> ImageSize:
+    hf_config = self.get_hf_config()
+    width = height = hf_config.image_size
+    return ImageSize(width=width, height=height)
+
+def get_max_image_tokens(self) -> int:
+    target_width, target_height = self.get_image_size_with_most_features()
+
+    return self.get_num_image_tokens(
+        image_width=target_width,
+        image_height=target_height,
+    )
+```
+
+And thus, we can override the method as:
+
+```python
+def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+    return {"image": self.get_max_image_tokens()}
+```
+
+```{note}
+Our [actual code](gh-file:vllm/model_executor/models/llava.py) is more abstracted to support vision encoders other than CLIP.
+```
+:::
+::::
+
+## 3. Specify dummy inputs
+
+Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for
+HF processing as well as memory profiling.
+
+### For memory profiling
+
+Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`
+to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of
+the model so that vLLM can reserve the correct amount of memory for it.
+
+Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed based
+on the code for {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item`.
+
+::::{tab-set}
+:::{tab-item} Basic example: LLaVA
+:sync: llava
+
+Making use of the `get_image_size_with_most_features` method implemented in the previous section:
+
+```python
+def get_dummy_processor_inputs(
+    self,
+    seq_len: int,
+    mm_counts: Mapping[str, int],
+) -> ProcessorInputs:
+    num_images = mm_counts.get("image", 0)
+
+    processor = self.info.get_hf_processor()
+    image_token = processor.image_token
+
+    hf_config = self.get_hf_config()
+    target_width, target_height = self.info.get_image_size_with_most_features()
+
+    mm_data = {
+        "image":
+        self._get_dummy_images(width=target_width,
+                               height=target_height,
+                               num_images=num_images)
+    }
+
+    return ProcessorInputs(
+        prompt_text=image_token * num_images,
+        mm_data=mm_data,
+    )
+```
+:::
+::::
+
+## 4. Specify processing details
+
+Afterwards, create a subclass of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`
+to fill in the missing details about HF processing.
+
+```{seealso}
+[Multi-Modal Data Processing](#mm-processing)
+```
+
+### Multi-modal fields
+
+Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` to
+return a schema of the tensors outputted by the HF processor that are related to the input multi-modal items.
+
+::::{tab-set}
+:::{tab-item} Basic example: LLaVA
+:sync: llava
+
+Looking at the model's `forward` method:
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L387-L404
+def forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    pixel_values: torch.FloatTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    vision_feature_layer: Optional[int] = None,
+    vision_feature_select_strategy: Optional[str] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+```
+
+The only related keyword argument is `pixel_values` which directly corresponds to input images.
+The shape of `pixel_values` is `(N, C, H, W)` where `N` is the number of images.
+So, we override the method as follows:
+
+```python
+def _get_mm_fields_config(
+    self,
+    hf_inputs: BatchFeature,
+    hf_processor_mm_kwargs: Mapping[str, object],
+) -> Mapping[str, MultiModalFieldConfig]:
+    return dict(
+        pixel_values=MultiModalFieldConfig.batched("image"),
+    )
+```
+
+```{note}
+Our [actual code](gh-file:vllm/model_executor/models/llava.py) additionally supports
+pre-computed image embeddings, which can be passed to the model via the `image_embeds` argument.
+```
+:::
+::::
+
+### Prompt replacements
+
+Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to
+return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances.
+
+Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace
+operation performed by the HF processor.
+
+::::{tab-set}
+:::{tab-item} Basic example: LLaVA
+:sync: llava
+
+Looking at HF's `LlavaProcessor`:
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/processing_llava.py#L167-L170
+prompt_strings = []
+for sample in text:
+    sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
+    prompt_strings.append(sample)
+```
+
+It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
+Based on this, we override the method as follows:
+
+```python
+def _get_prompt_replacements(
+    self,
+    mm_items: MultiModalDataItems,
+    hf_processor_mm_kwargs: Mapping[str, object],
+    out_mm_kwargs: MultiModalKwargs,
+) -> list[PromptReplacement]:
+    hf_config = self.info.get_hf_config()
+    image_token_id = hf_config.image_token_index
+
+    def get_replacement(item_idx: int):
+        images = mm_items.get_items("image", ImageProcessorItems)
+
+        image_size = images.get_image_size(item_idx)
+        num_image_tokens = self.info.get_num_image_tokens(
+            image_width=image_size.width,
+            image_height=image_size.height,
+        )
+
+        return [image_token_id] * num_image_tokens
+
+    return [
+        PromptReplacement(
+            modality="image",
+            target=[image_token_id],
+            replacement=get_replacement,
+        ),
+    ]
+```
+:::
+::::
+
+## 5. Register processor-related classes
+
+After you have defined {class}`~vllm.multimodal.processing.BaseProcessingInfo` (Step 2),
+{class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` (Step 3),
+and {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` (Step 4),
+decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_processor <vllm.multimodal.registry.MultiModalRegistry.register_processor>`
+to register them to the multi-modal registry:
+
 ```diff
   from vllm.model_executor.models.interfaces import SupportsMultiModal
 + from vllm.multimodal import MULTIMODAL_REGISTRY
 
-+ @MULTIMODAL_REGISTRY.register_image_input_mapper()
++ @MULTIMODAL_REGISTRY.register_processor(YourMultiModalProcessor,
++                                         info=YourProcessingInfo,
++                                         dummy_inputs=YourDummyInputsBuilder)
   class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 ```
 
-A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
-
-```{seealso}
-[Input Processing Pipeline](#input-processing-pipeline)
-```
-
-## 3. Register maximum number of multi-modal tokens
-
-For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
-and register it via {meth}`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.
-
-```diff
-  from vllm.inputs import INPUT_REGISTRY
-  from vllm.model_executor.models.interfaces import SupportsMultiModal
-  from vllm.multimodal import MULTIMODAL_REGISTRY
-
-  @MULTIMODAL_REGISTRY.register_image_input_mapper()
-+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
-  @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
-  class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
-```
-
-Here are some examples:
-
-- Image inputs (static feature size): [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py)
-- Image inputs (dynamic feature size): [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py)
-
-```{seealso}
-[Input Processing Pipeline](#input-processing-pipeline)
-```
-
-## 4. (Optional) Register dummy data
-
-During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
-In such cases, you can define your own dummy data by registering a factory method via {meth}`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
-
-```diff
-  from vllm.inputs import INPUT_REGISTRY
-  from vllm.model_executor.models.interfaces import SupportsMultiModal
-  from vllm.multimodal import MULTIMODAL_REGISTRY
-
-  @MULTIMODAL_REGISTRY.register_image_input_mapper()
-  @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
-+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
-  class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
-```
-
-```{note}
-The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
-```
-
-Here are some examples:
-
-- Image inputs (static feature size): [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py)
-- Image inputs (dynamic feature size): [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py)
-
-```{seealso}
-[Input Processing Pipeline](#input-processing-pipeline)
-```
-
-## 5. (Optional) Register input processor
-
-Sometimes, there is a need to process inputs at the {class}`~vllm.LLMEngine` level before they are passed to the model executor.
-This is often due to the fact that unlike implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside model's {meth}`~torch.nn.Module.forward` call.
-You can register input processors via {meth}`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.
-
-```diff
-  from vllm.inputs import INPUT_REGISTRY
-  from vllm.model_executor.models.interfaces import SupportsMultiModal
-  from vllm.multimodal import MULTIMODAL_REGISTRY
-
-  @MULTIMODAL_REGISTRY.register_image_input_mapper()
-  @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
-  @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
-+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
-  class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
-```
-
-A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
-Here are some examples:
-
-- Insert static number of image tokens: [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py)
-- Insert dynamic number of image tokens: [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py)
-
-```{seealso}
-[Input Processing Pipeline](#input-processing-pipeline)
-```
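To see how the pieces introduced in Steps 2-5 above fit together, a minimal sketch (not code from this commit) might look as follows. The class names mirror the placeholders used in the guide, the import paths follow the references above and may need adjusting to the actual module layout, and the method bodies are elided or filled with illustrative values.

```python
# Sketch only: wiring together the classes from Steps 2-5.
# Class names mirror the placeholders in the guide; bodies are illustrative.
from collections.abc import Mapping
from typing import Optional

import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs


class YourProcessingInfo(BaseProcessingInfo):
    """Step 2: basic information about HF processing."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}  # any number of images per prompt

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        # Illustrative constant; see the LLaVA walkthrough in Step 2 for how
        # to derive this from the HF config.
        return {"image": 576}


class YourDummyInputsBuilder(BaseDummyInputsBuilder[YourProcessingInfo]):
    """Step 3: dummy inputs for HF processing and memory profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        ...  # see the LLaVA example in Step 3


class YourMultiModalProcessor(BaseMultiModalProcessor[YourProcessingInfo]):
    """Step 4: multi-modal fields and prompt replacements."""

    def _get_mm_fields_config(self, hf_inputs, hf_processor_mm_kwargs):
        ...  # see the LLaVA example in Step 4

    def _get_prompt_replacements(self, mm_items, hf_processor_mm_kwargs, out_mm_kwargs):
        ...  # see the LLaVA example in Step 4


# Step 5: register everything against the model class.
@MULTIMODAL_REGISTRY.register_processor(YourMultiModalProcessor,
                                        info=YourProcessingInfo,
                                        dummy_inputs=YourDummyInputsBuilder)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
    ...
```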
@@ -48,7 +48,7 @@ ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCaus
 ```{important}
 If your model is a multimodal model, ensure the model class implements the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
-Read more about that [here](#enabling-multimodal-inputs).
+Read more about that [here](#supports-multimodal).
 ```
 
 ```{note}
@@ -1,19 +0,0 @@
-(input-processing-pipeline)=
-
-# Input Processing Pipeline
-
-1. Input data is passed to {class}`~vllm.LLMEngine` (or {class}`~vllm.AsyncLLMEngine`).
-
-2. Tokenize the data if necessary.
-
-3. Process the inputs using {meth}`INPUT_REGISTRY.process_input <vllm.inputs.registry.InputRegistry.process_input>`.
-
-   - For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.
-
-4. Send the processed inputs to {class}`~vllm.executor.executor_base.ExecutorBase`.
-
-5. Distribute the inputs via {class}`~vllm.worker.worker_base.WorkerBase` to {class}`~vllm.worker.model_runner_base.ModelRunnerBase`.
-
-6. If the data contains multi-modal data, convert it into keyword arguments using {meth}`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.
-
-   - For example, convert a {class}`PIL.Image.Image` input to its pixel values for a vision model.
@@ -1,43 +0,0 @@
-(input-processing)=
-
-# Input Processing
-
-```{eval-rst}
-.. currentmodule:: vllm.inputs
-```
-
-Each model can override parts of vLLM's [input processing pipeline](#input-processing-pipeline) via
-{data}`~vllm.inputs.INPUT_REGISTRY` and {data}`~vllm.multimodal.MULTIMODAL_REGISTRY`.
-
-Currently, this mechanism is only utilized in [multi-modal](#multi-modality) models for preprocessing multi-modal input
-data in addition to input prompt, but it can be extended to text-only language models when needed.
-
-## Guides
-
-```{toctree}
-:maxdepth: 1
-
-input_processing_pipeline
-```
-
-## Module Contents
-
-### LLM Engine Inputs
-
-```{eval-rst}
-.. autoclass:: vllm.inputs.DecoderOnlyInputs
-    :members:
-    :show-inheritance:
-```
-
-### Registry
-
-```{eval-rst}
-.. autodata:: vllm.inputs.INPUT_REGISTRY
-```
-
-```{eval-rst}
-.. automodule:: vllm.inputs.registry
-    :members:
-    :show-inheritance:
-```
docs/source/design/mm_processing.md (new file, 64 lines)
@@ -0,0 +1,64 @@
+(mm-processing)=
+
+# Multi-Modal Data Processing
+
+To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefill) and [prefix caching](#automatic-prefix-caching), we use {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` to provide the correspondence between placeholder feature tokens (e.g. `<image>`) and multi-modal inputs (e.g. the raw input image) based on the outputs of the HF processor.
+
+Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`:
+
+## Prompt Replacement Detection
+
+One of the main responsibilities of the HF processor is to replace input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which is equal to the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
+
+In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt.
+
+## Tokenized Prompt Inputs
+
+To enable tokenization in a separate process, we support passing input token IDs alongside multi-modal data.
+
+### The problem
+
+Consider that HF processors follow these main steps:
+
+1. Tokenize the text
+2. Process multi-modal inputs
+3. Perform prompt replacement
+
+And we require that:
+
+- For text + multi-modal inputs, apply all steps 1--3.
+- For tokenized + multi-modal inputs, apply only steps 2--3.
+
+How can we achieve this without rewriting HF processors? We can try to call the HF processor several times on different inputs:
+
+- For text + multi-modal inputs, simply call the HF processor directly.
+- For tokenized + multi-modal inputs, call the processor only on the multi-modal inputs.
+
+While HF processors support text + multi-modal inputs natively, this is not so for tokenized + multi-modal inputs: an error is thrown if the number of input placeholder tokens does not correspond to the number of multi-modal inputs.
+
+Moreover, since the tokenized text has not passed through the HF processor, we have to apply Step 3 by ourselves to keep the output tokens and multi-modal data consistent with each other.
+
+(mm-dummy-text)=
+
+### Dummy text
+
+We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data.
+
+(mm-automatic-prompt-replacement)=
+
+### Automatic prompt replacement
+
+We address the second issue by implementing model-agnostic code in
+{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`.
+
+### Summary
+
+With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
+
+## Processor Output Caching
+
+Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238). To alleviate this problem, we cache the multi-modal outputs of the HF processor to avoid processing the same multi-modal input (e.g. image) again.
+
+When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache.
+
+Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of multi-modal inputs, so they can't be passed alongside the text prompt to the HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other.
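The caching flow described in the new file can be summarized with a small, self-contained sketch. This is not vLLM's actual implementation; the cache, keys, and batch-processing callable are stand-ins for the real machinery.

```python
# Sketch of the "check cache, process misses in one batch, merge" flow
# described above. All names are stand-ins, not vLLM APIs.
from typing import Any, Callable, Dict, List, Sequence


def process_with_cache(
    process_batch: Callable[[List[Any]], List[Any]],  # e.g. a wrapper around the HF processor
    cache: Dict[str, Any],
    keys: Sequence[str],   # one content-based key per multi-modal item
    items: Sequence[Any],  # the multi-modal items themselves (e.g. images)
) -> List[Any]:
    # 1. Determine which items are missing from the cache.
    missing = [(k, item) for k, item in zip(keys, items) if k not in cache]

    # 2. Process all cache misses in a single batch and store the results.
    if missing:
        outputs = process_batch([item for _, item in missing])
        for (k, _), out in zip(missing, outputs):
            cache[k] = out

    # 3. Merge cached and freshly processed outputs, preserving input order.
    return [cache[k] for k in keys]
```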
@@ -145,7 +145,7 @@ design/arch_overview
 design/huggingface_integration
 design/plugin_system
 design/kernel/paged_attention
-design/input_processing/model_inputs_index
+design/mm_processing
 design/automatic_prefix_caching
 design/multiprocessing
 ```
@@ -14,7 +14,7 @@ and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/ch
 To input multi-modal data, follow this schema in {class}`vllm.inputs.PromptType`:
 
 - `prompt`: The prompt should follow the format that is documented on HuggingFace.
-- `multi_modal_data`: This is a dictionary that follows the schema defined in {class}`vllm.multimodal.MultiModalDataDict`.
+- `multi_modal_data`: This is a dictionary that follows the schema defined in {class}`vllm.multimodal.inputs.MultiModalDataDict`.
 
 ### Image
 
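As a concrete illustration of this schema, a prompt with a single image can be passed through the offline `LLM` API roughly as follows; the model name, prompt template, and image path are examples only, not part of this commit.

```python
# Passing an image via the `multi_modal_data` field of the prompt.
# The model, prompt template, and image path are illustrative.
from PIL import Image

from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")
image = Image.open("example.jpg")  # placeholder path

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)
```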
@@ -2124,8 +2124,7 @@ class MultiModalConfig:
     limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
     """
-    The maximum number of multi-modal input instances allowed per prompt
-    for each :class:`~vllm.multimodal.MultiModalPlugin`.
+    The maximum number of input items allowed per prompt for each modality.
     """
 
     def compute_hash(self) -> str:
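On the user side, this per-modality limit is normally configured through the `limit_mm_per_prompt` engine argument, which populates `limit_per_prompt`. A minimal sketch with an illustrative model name:

```python
# Allow up to 4 images per prompt for a multi-modal model (model name is illustrative).
from vllm import LLM

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    limit_mm_per_prompt={"image": 4},
)
```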
@@ -11,9 +11,6 @@ INPUT_REGISTRY = InputRegistry()
 """
 The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
 to dispatch data processing according to the target model.
-
-See also:
-    :ref:`input-processing-pipeline`
 """
 
 __all__ = [
@@ -313,9 +313,6 @@ class InputRegistry:
 
         The model is identified by ``model_config``.
-
-        See also:
-            :ref:`enabling-multimodal-inputs`
 
         Note:
             This should be called after
             :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
@@ -384,10 +381,8 @@ class InputRegistry:
         Register an input processor to a model class.
 
         The provided function is invoked on each input to the model. This
-        happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.
-
-        See also:
-            :ref:`input-processing-pipeline`
+        happens before
+        :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`.
         """
 
         def wrapper(model_cls: N) -> N:
@@ -429,9 +424,6 @@ class InputRegistry:
         Apply an input processor to an instance of model inputs.
 
         The model is identified by ``model_config``.
-
-        See also:
-            :ref:`input-processing-pipeline`
         """
         # Avoid circular import
         from vllm.model_executor.model_loader import get_model_architecture
@@ -8,10 +8,10 @@ from .registry import MultiModalRegistry
 MULTIMODAL_REGISTRY = MultiModalRegistry()
 """
 The global :class:`~MultiModalRegistry` is used by model runners to
-dispatch data processing according to its modality and the target model.
+dispatch data processing according to the target model.
 
 See also:
-    :ref:`input-processing-pipeline`
+    :ref:`mm-processing`
 """
 
 __all__ = [
@@ -90,10 +90,6 @@ class MultiModalPlugin(ABC):
         invoked to transform the data into a dictionary of model inputs.
 
         If `None` is provided, then the default input mapper is used instead.
-
-        See also:
-            - :ref:`input-processing-pipeline`
-            - :ref:`enabling-multimodal-inputs`
         """
 
         def wrapper(model_cls: N) -> N:
@@ -126,10 +122,6 @@ class MultiModalPlugin(ABC):
 
         Raises:
             TypeError: If the data type is not supported.
-
-        See also:
-            - :ref:`input-processing-pipeline`
-            - :ref:`enabling-multimodal-inputs`
         """
 
         # Avoid circular import
@@ -186,9 +178,6 @@ class MultiModalPlugin(ABC):
         for a model class.
 
         If `None` is provided, then the default calculation is used instead.
-
-        See also:
-            :ref:`enabling-multimodal-inputs`
         """
 
         def wrapper(model_cls: N) -> N:
@@ -218,9 +207,6 @@ class MultiModalPlugin(ABC):
         If this registry is not applicable to the model, `0` is returned.
 
         The model is identified by ``model_config``.
-
-        See also:
-            :ref:`enabling-multimodal-inputs`
         """
         # Avoid circular import
         from vllm.model_executor.model_loader import get_model_architecture
@@ -493,7 +493,8 @@ A dictionary containing placeholder ranges for each modality.
 
 class MultiModalInputsV2(TypedDict):
     """
-    Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`,
+    Represents the outputs of
+    :class:`vllm.multimodal.processing.BaseMultiModalProcessor`,
     ready to be passed to vLLM internals.
     """
 
@@ -100,8 +100,7 @@ class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
 
 class MultiModalRegistry:
     """
-    A registry that dispatches data processing to the
-    :class:`~vllm.multimodal.MultiModalPlugin` for each modality.
+    A registry that dispatches data processing according to the model.
     """
 
     DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
@@ -367,8 +366,7 @@ class MultiModalRegistry:
         invoked to transform the data into a dictionary of model inputs.
 
         See also:
-            - :ref:`input-processing-pipeline`
-            - :ref:`enabling-multimodal-inputs`
+            :ref:`mm-processing`
         """
 
         def wrapper(model_cls: N) -> N:
@@ -398,6 +396,9 @@ class MultiModalRegistry:
     def has_processor(self, model_config: "ModelConfig") -> bool:
         """
         Test whether a multi-modal processor is defined for a specific model.
+
+        See also:
+            :ref:`mm-processing`
         """
         return self._get_model_cls(model_config) in self._processor_factories
 
@@ -408,6 +409,9 @@ class MultiModalRegistry:
     ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
         """
         Create a multi-modal processor for a specific model and tokenizer.
+
+        See also:
+            :ref:`mm-processing`
         """
         model_cls = self._get_model_cls(model_config)
         factories = self._processor_factories[model_cls]