[Core] Registry for processing model inputs (#5214)

Co-authored-by: ywang96 <ywang@roblox.com>
Authored by Cyrus Leung on 2024-06-28 20:09:56 +08:00; committed by GitHub
parent 0d0e3a42ac
commit 5cbe8d155c
26 changed files with 784 additions and 398 deletions

View File

@ -0,0 +1,20 @@
.. _input_processing_pipeline:
Input Processing Pipeline
=========================
1. Input data is passed to :class:`~vllm.LLMEngine` (or :class:`~vllm.AsyncLLMEngine`).
2. Tokenize the data if necessary.
3. Process the inputs using :meth:`INPUT_REGISTRY.process_input <vllm.inputs.registry.InputRegistry.process_input>`.
- For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.
4. Send the processed inputs to :class:`~vllm.executor.executor_base.ExecutorBase`.
5. Distribute the inputs via :class:`~vllm.worker.worker_base.WorkerBase` to :class:`~vllm.worker.model_runner_base.ModelRunnerBase`.
6. If the inputs contain multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.
- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
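A minimal sketch of how steps 1 to 3 look from the caller's side (illustrative only; the token IDs are made up, and in practice :class:`~vllm.LLMEngine` performs these calls internally):

.. code-block:: python

    from vllm.inputs import INPUT_REGISTRY, LLMInputs

    # After steps 1-2, the request has been reduced to token IDs.
    llm_inputs = LLMInputs(prompt_token_ids=[1, 2, 3],
                           prompt="What is in this image?",
                           multi_modal_data=None)

    # Step 3: apply the model-specific input processor (a no-op for models
    # without a registered processor). ``model_config`` would come from the
    # engine, so the call is shown here but not executed.
    # llm_inputs = INPUT_REGISTRY.process_input(model_config, llm_inputs)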

View File

@ -0,0 +1,39 @@
.. _input_processing:
Input Processing
================
.. currentmodule:: vllm.inputs
vLLM provides a mechanism for defining input processors for each model so that the inputs are processed
in :class:`~vllm.LLMEngine` before they are passed to model executors.
Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
data in addition to the input prompt, but it can be extended to text-only language models when needed.
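For example, a model author can register a custom processor by decorating the model class (a hypothetical sketch; ``MyModel`` and the placeholder-token logic are illustrative only):

.. code-block:: python

    from torch import nn

    from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs

    def process_inputs_for_my_model(ctx: InputContext,
                                    llm_inputs: LLMInputs) -> LLMInputs:
        # Illustrative: prepend a placeholder token to reserve space for a
        # multi-modal embedding.
        new_token_ids = [0] + llm_inputs["prompt_token_ids"]
        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=llm_inputs.get("prompt"),
                         multi_modal_data=llm_inputs.get("multi_modal_data"))

    @INPUT_REGISTRY.register_input_processor(process_inputs_for_my_model)
    class MyModel(nn.Module):
        ...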
Guides
++++++
.. toctree::
:maxdepth: 1
input_processing_pipeline
Module Contents
+++++++++++++++
LLM Engine Inputs
-----------------
.. autoclass:: vllm.inputs.LLMInputs
:members:
:show-inheritance:
Registry
--------
.. autodata:: vllm.inputs.INPUT_REGISTRY
.. automodule:: vllm.inputs.registry
:members:
:show-inheritance:

View File

@ -12,10 +12,6 @@ By default, vLLM models do not support multi-modal inputs. To enable multi-modal
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type to support.
.. contents::
:local:
:backlinks: none
Module Contents
+++++++++++++++
@ -24,9 +20,7 @@ Module Contents
Registry
--------
.. data:: vllm.multimodal.MULTIMODAL_REGISTRY
The global :class:`MultiModalRegistry` which is used by model runners.
.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY
.. autoclass:: vllm.multimodal.MultiModalRegistry
:members:

View File

@ -120,6 +120,7 @@ Documentation
dev/offline_inference/offline_index
dev/engine/engine_index
dev/kernel/paged_attention
dev/input_processing/model_inputs_index
dev/multimodal/multimodal_index
dev/dockerfile/dockerfile

View File

@ -37,7 +37,7 @@ For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/
2. Rewrite the :code:`forward` methods
--------------------------------------
Next, you need to rewrite the :code:`forward` methods of your model by following these steps:
Next, you need to rewrite the :meth:`~torch.nn.Module.forward` method of your model by following these steps:
1. Remove any unnecessary code, such as the code only used for training.
2. Change the input parameters:
@ -75,7 +75,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
For the embedding layer, you can simply replace :class:`torch.nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
When it comes to the linear layers, we provide the following options to parallelize them:
* :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
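As a rough sketch of the embedding substitution described above (assuming the ``vllm.model_executor.layers.vocab_parallel_embedding`` import path; instantiating these layers additionally requires vLLM's distributed environment to be initialized, so this is illustrative only):

.. code-block:: python

    from torch import nn

    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead, VocabParallelEmbedding)

    class MyModelForCausalLM(nn.Module):  # hypothetical model class

        def __init__(self, vocab_size: int, hidden_size: int) -> None:
            super().__init__()
            # Before: nn.Embedding(vocab_size, hidden_size)
            self.embed_tokens = VocabParallelEmbedding(vocab_size, hidden_size)
            # Before: nn.Linear(hidden_size, vocab_size, bias=False)
            self.lm_head = ParallelLMHead(vocab_size, hidden_size)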

View File

@ -11,7 +11,7 @@ def run_phi3v():
model_path = "microsoft/Phi-3-vision-128k-instruct"
# Note: The model has 128k context length by default which may cause OOM
# If that's the case, override `max_model_len` with a smaller value via args
# In this example, we override max_model_len to 2048.
llm = LLM(
model=model_path,
trust_remote_code=True,
@ -19,6 +19,7 @@ def run_phi3v():
image_token_id=32044,
image_input_shape="1,3,1008,1344",
image_feature_size=1921,
max_model_len=2048,
)
image = Image.open("images/cherry_blossom.jpg")

View File

@ -25,14 +25,14 @@ def test_clip_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)
for asset in image_assets:
@ -40,10 +40,9 @@ def test_clip_image_processor(image_assets, dtype):
asset.pil_image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.process_input(
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)
assert hf_result.keys() == vllm_result.keys()
@ -74,14 +73,14 @@ def test_llava_next_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)
for asset in image_assets:
@ -89,10 +88,9 @@ def test_llava_next_image_processor(image_assets, dtype):
asset.pil_image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.process_input(
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)
assert hf_result.keys() == vllm_result.keys()
@ -119,26 +117,23 @@ def test_image_pixel_types(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
)
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
))
for asset in image_assets:
image_result = MULTIMODAL_REGISTRY.process_input(
image_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)
tensor_result = MULTIMODAL_REGISTRY.process_input(
tensor_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pixel_values),
model_config=model_config,
vlm_config=vlm_config,
)
assert image_result.keys() == tensor_result.keys()

View File

@ -109,6 +109,7 @@ class ModelConfig:
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
multimodal_config: Optional["VisionLanguageConfig"] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
@ -159,6 +160,8 @@ class ModelConfig:
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = multimodal_config
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_embedding_mode()

View File

@ -643,6 +643,36 @@ class EngineArgs:
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')
if self.image_processor is None:
self.image_processor = self.model
if self.disable_image_processor:
if self.image_processor != self.model:
warnings.warn(
"You've specified an image processor "
f"({self.image_processor}) but also disabled "
"it via `--disable-image-processor`.",
stacklevel=2)
self.image_processor = None
vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
image_processor=self.image_processor,
image_processor_revision=self.image_processor_revision,
)
else:
vision_language_config = None
device_config = DeviceConfig(device=self.device)
model_config = ModelConfig(
@ -666,7 +696,8 @@ class EngineArgs:
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name)
served_model_name=self.served_model_name,
multimodal_config=vision_language_config)
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
@ -742,37 +773,6 @@ class EngineArgs:
model_loader_extra_config=self.model_loader_extra_config,
)
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')
if self.image_processor is None:
self.image_processor = self.model
if self.disable_image_processor:
if self.image_processor != self.model:
warnings.warn(
"You've specified an image processor "
f"({self.image_processor}) but also disabled "
"it via `--disable-image-processor`.",
stacklevel=2)
self.image_processor = None
vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
image_processor=self.image_processor,
image_processor_revision=self.image_processor_revision,
)
else:
vision_language_config = None
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)

View File

@ -278,9 +278,11 @@ class _AsyncLLMEngine(LLMEngine):
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
return self.input_processor(llm_inputs)
async def add_request_async(
self,

View File

@ -20,7 +20,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@ -227,6 +227,9 @@ class LLMEngine:
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.input_processor = INPUT_REGISTRY.create_input_processor(
self.model_config)
self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
@ -511,9 +514,11 @@ class LLMEngine:
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
return self.input_processor(llm_inputs)
def add_request(
self,

vllm/inputs/__init__.py (new file, 19 lines)
View File

@ -0,0 +1,19 @@
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
PromptStrictInputs, TextPrompt, TextTokensPrompt,
TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
"""
The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
to dispatch data processing according to the target model.
See also:
:ref:`input_processing_pipeline`
"""
__all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
"TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs",
"LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry"
]

View File

@ -101,8 +101,7 @@ class TextTokensPrompt(TypedDict):
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt. If None, we use the
tokenizer to convert the prompts to token IDs."""
"""The token IDs of the prompt."""
multi_modal_data: NotRequired["MultiModalData"]
"""
@ -125,6 +124,21 @@ PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
"""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
prompt: NotRequired[Optional[str]]
"""
The original prompt text corresponding to the token IDs, if available.
"""
multi_modal_data: NotRequired[Optional["MultiModalData"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""

vllm/inputs/registry.py (new file, 207 lines)
View File

@ -0,0 +1,207 @@
import functools
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
TypeVar)
from torch import nn
from transformers import PretrainedConfig
from vllm.logger import init_logger
from .data import LLMInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MultiModalData
from vllm.sequence import SequenceData
logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig)
@dataclass(frozen=True)
class InputContext:
"""
Contains information about the model which may be used to
modify the inputs.
"""
model_config: "ModelConfig"
"""The configuration of the model."""
def get_multimodal_config(self) -> "VisionLanguageConfig":
"""
Get the multimodal configuration of the model.
Raises:
ValueError: If the model is not multimodal.
"""
multimodal_config = self.model_config.multimodal_config
if multimodal_config is None:
raise ValueError("No multimodal config found")
return multimodal_config
def get_hf_config(self, hf_config_type: Type[C]) -> C:
"""
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
TypeError: If the model is not of the specified type.
"""
hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type):
raise TypeError("Invalid type of HuggingFace config. "
f"Expected type: {hf_config_type}, but "
f"found type: {type(hf_config)}")
return hf_config
N = TypeVar("N", bound=Type[nn.Module])
DummyDataFactory = Callable[[InputContext, int],
Tuple["SequenceData", Optional["MultiModalData"]]]
"""
Create dummy data to be inputted into the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
"""
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
"""Preprocess the inputs to the model."""
class InputRegistry:
"""
A registry to dispatch data processing
according to the target model.
"""
def __init__(self) -> None:
self._dummy_factories_by_model_type: Dict[Type[nn.Module],
DummyDataFactory] = {}
self._input_processors_by_model_type: Dict[Type[nn.Module],
InputProcessor] = {}
def _default_dummy_data_factory(
self,
ctx: InputContext,
seq_len: int,
) -> Tuple["SequenceData", Optional["MultiModalData"]]:
"""
The default dummy data factory represents the longest possible text
that can be inputted to the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
"""
# Avoid circular import
from vllm.sequence import SequenceData
dummy_seq_data = SequenceData([0] * seq_len)
dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data
def register_dummy_data(self, factory: DummyDataFactory):
"""
Register a dummy data factory to a model class.
During memory profiling, the provided function is invoked to create
dummy data to be inputted into the model. The resulting memory usage
should be an upper bound of what the model would use at inference time.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type:
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def dummy_data_for_profiling(self, model_config: "ModelConfig",
seq_len: int):
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
TODO: Add guide [ref: PR #5276]
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
dummy_factory = self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
return dummy_factory(InputContext(model_config), seq_len)
def _default_input_processor(self, ctx: InputContext,
inputs: LLMInputs) -> LLMInputs:
"""The default input processor is a no-op."""
return inputs
def register_input_processor(self, processor: InputProcessor):
"""
Register an input processor to a model class.
The provided function is invoked on each input to the model. This
happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.
See also:
:ref:`input_processing_pipeline`
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors_by_model_type:
logger.warning(
"Model class %s already has input processor "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._input_processors_by_model_type[model_cls] = processor
return model_cls
return wrapper
def process_input(self, model_config: "ModelConfig",
inputs: LLMInputs) -> LLMInputs:
"""
Apply an input processor to an instance of model inputs.
The model is identified by ``model_config``.
See also:
:ref:`input_processing_pipeline`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
processor = self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)
return processor(InputContext(model_config), inputs)
def create_input_processor(self, model_config: "ModelConfig"):
"""
Create an input processor (see :meth:`process_input`) for a
specific model.
"""
return functools.partial(self.process_input, model_config)
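For reference, a minimal hypothetical usage sketch of this registry (the model class and dummy data factory below are illustrative only):

from torch import nn

from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.sequence import SequenceData


def dummy_data_for_my_model(ctx: InputContext, seq_len: int):
    # Illustrative: fill the whole sequence with placeholder tokens and
    # attach no multi-modal data.
    return SequenceData([0] * seq_len), None


@INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_model)
class MyProfiledModel(nn.Module):  # hypothetical model class
    ...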

View File

@ -1,22 +1,83 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SequenceData
def get_clip_num_patches(image_size: int, patch_size: int) -> int:
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
return (image_size // patch_size)**2
return image_size // patch_size
def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_clip_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
return get_clip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
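A quick sanity check of this arithmetic (illustrative only): the CLIP ViT-L/14 configuration at 336 px input used by LLaVA yields a 24 x 24 patch grid, i.e. 576 image feature tokens, matching the ``image_feature_size=576`` used in the image processor tests above.

from transformers import CLIPVisionConfig

# 336 // 14 == 24 patches per side, so 24 * 24 == 576 feature tokens.
clip_vit_l_336 = CLIPVisionConfig(image_size=336, patch_size=14)
assert get_clip_patch_grid_length(image_size=336, patch_size=14) == 24
assert get_clip_image_feature_size(clip_vit_l_336) == 576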
def dummy_seq_data_for_clip(
hf_config: CLIPVisionConfig,
seq_len: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
return SequenceData(token_ids)
def dummy_pixel_data_for_clip(
hf_config: CLIPVisionConfig,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return ImagePixelData(image)
def dummy_feature_data_for_clip(
hf_config: CLIPVisionConfig,
*,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
values = torch.zeros((1, image_feature_size, hf_config.hidden_size),
dtype=torch.float16)
return ImageFeatureData(values)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
@ -39,8 +100,8 @@ class CLIPVisionEmbeddings(nn.Module):
bias=False,
)
self.num_patches = get_clip_num_patches(self.image_size,
self.patch_size)
self.num_patches = get_clip_num_patches(image_size=self.image_size,
patch_size=self.patch_size)
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
@ -101,7 +162,7 @@ class CLIPEncoderLayer(nn.Module):
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

View File

@ -2,10 +2,11 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
import torch.nn as nn
from transformers import LlavaConfig
from transformers import CLIPVisionConfig, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
@ -16,10 +17,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.sequence import SamplerOutput
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
dummy_seq_data_for_clip)
from .interfaces import SupportsVision
_KEYS_TO_MODIFY_MAPPING = {
@ -83,9 +85,35 @@ class LlavaImageFeatureInputs(TypedDict):
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
@MULTIMODAL_REGISTRY.register_image_feature_input()
@MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
image_token_id=hf_config.image_token_index,
)
image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
mm_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
mm_data = dummy_pixel_data_for_clip(vision_config)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
mm_data = dummy_feature_data_for_clip(vision_config)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,

View File

@ -3,14 +3,14 @@ from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
import torch
import torch.nn as nn
from PIL import Image
from transformers import LlavaNextConfig
from transformers import CLIPVisionConfig, LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
@ -22,9 +22,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.sequence import SamplerOutput, SequenceData
from vllm.multimodal.image import ImagePixelData
from vllm.sequence import SamplerOutput
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
dummy_seq_data_for_clip, get_clip_patch_grid_length)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
@ -58,41 +60,118 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageFeatureInputs]
def _get_dummy_image_data(
seq_len: int,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
seq_data, fake_mm_data = get_dummy_image_data(seq_len, model_config,
vlm_config)
def _get_llava_next_num_unpadded_features(
height: int,
width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
# Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
config_input_type = vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
current_height = new_height
else:
new_width = (width * current_height) // height
current_width = new_width
if config_input_type == ImageInputType.PIXEL_VALUES:
_, c, h, w = vlm_config.image_input_shape
mode = {1: "L", 3: "RGB"}[c]
fake_mm_data = ImagePixelData(Image.new(mode, (w, h), color=0))
return seq_data, fake_mm_data
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
def _image_pixel_processor(
data: ImagePixelData,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Dict[str, torch.Tensor]:
def _get_llava_next_image_feature_size(
hf_config: LlavaNextConfig,
*,
input_height: int,
input_width: int,
) -> int:
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
num_patches = get_clip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = num_patches * num_patches
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
)
(
unpadded_feature_size,
newline_feature_size,
) = _get_llava_next_num_unpadded_features(input_height, input_width,
num_patches,
num_patch_height,
num_patch_width)
return unpadded_feature_size + newline_feature_size + base_feature_size
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
#TODO: change the logic for dummy data to support dynamic shape
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
image_feature_size = _get_llava_next_image_feature_size(
hf_config, input_height=dummy_height, input_width=dummy_width)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
mm_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
mm_data = dummy_pixel_data_for_clip(
vision_config,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
mm_data = dummy_feature_data_for_clip(
vision_config,
image_feature_size_override=image_feature_size,
)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def _pixel_mapper(ctx: InputContext,
data: ImagePixelData) -> Dict[str, torch.Tensor]:
image = data.image
if isinstance(image, torch.Tensor):
pixel_values = image.to(model_config.dtype)
pixel_values = image.to(ctx.model_config.dtype)
batch_size, _, _, h, w = pixel_values.shape
image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])
return {"pixel_values": pixel_values, "image_sizes": image_sizes}
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = vlm_config.image_input_shape
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != (image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
@ -101,11 +180,12 @@ def _image_pixel_processor(
data.image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
._default_input_processor(data, model_config, vlm_config)
._default_input_mapper(ctx, data)
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,

View File

@ -22,7 +22,8 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
@ -34,9 +35,10 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.multimodal.image import ImagePixelData
from vllm.sequence import SamplerOutput
from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision
logger = init_logger(__name__)
@ -107,7 +109,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
self.num_img_tokens = config.img_processor['num_img_tokens']
self.image_dim_out = image_dim_out
self.img_sizes = None
# global_gn and sub_gn for hd transform, serves as line separator
self.use_hd_transform = config.embd_layer.get('use_hd_transform',
@ -134,7 +135,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
self.img_projection = nn.Sequential(*layers)
self.vocab_size = config.vocab_size
self.img_features = None
self.layer_idx = config.img_processor.get('layer_idx', -2)
self.type_feature = config.img_processor.get('type_feature', 'patch')
@ -260,9 +260,44 @@ class Phi3VImagePixelInputs(TypedDict):
"""Shape: (batch_size, 2)"""
# FIXME(Isotr0py): Remove these after dynamic num_img_tokens is supported
# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def calc_padded_size(width, height, padding_unit=336):
def _get_phi3v_image_feature_size(
*,
input_height: int,
input_width: int,
) -> int:
h, w = input_height, input_width
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
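As a quick check of this formula (illustrative only): the ``1,3,1008,1344`` input shape from the Phi-3-vision offline example gives 1008 // 336 == 3 and 1344 // 336 == 4, hence (3 * 4 + 1) * 144 + 1 + (3 + 1) * 12 == 1921, matching the ``image_feature_size=1921`` passed to the ``LLM`` constructor in that example.

assert _get_phi3v_image_feature_size(input_height=1008,
                                     input_width=1344) == 1921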
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
#TODO: change the logic for dummy data to support dynamic shape
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
image_feature_size = _get_phi3v_image_feature_size(
input_height=dummy_height,
input_width=dummy_width,
)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_pixel_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
return seq_data, mm_data
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
top_padding = int((target_height - height) / 2)
bottom_padding = target_height - height - top_padding
@ -271,8 +306,8 @@ def calc_padded_size(width, height, padding_unit=336):
return padded_width, padded_height
# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def calc_hd_transform_size(width, height, hd_num=16):
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
transposed = False
if width < height:
width, height = height, width
@ -287,7 +322,8 @@ def calc_hd_transform_size(width, height, hd_num=16):
new_width = int(scale * 336)
new_height = int(new_width / ratio)
padded_width, padded_height = calc_padded_size(new_width, new_height)
padded_width, padded_height = _calc_padded_size(width=new_width,
height=new_height)
if transposed:
padded_width, padded_height = padded_height, padded_width
@ -295,17 +331,15 @@ def calc_hd_transform_size(width, height, hd_num=16):
return padded_width, padded_height
def _image_processor(
data: ImagePixelData,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Dict[str, torch.Tensor]:
def _image_processor(ctx: InputContext,
data: ImagePixelData) -> Dict[str, torch.Tensor]:
image = data.image
if isinstance(image, Image.Image):
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = vlm_config.image_input_shape
if (w, h) != calc_hd_transform_size(image.width, image.height):
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != _calc_hd_transform_size(width=image.width,
height=image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
@ -313,11 +347,11 @@ def _image_processor(
data.image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
._default_input_processor(data, model_config, vlm_config)
._default_input_mapper(ctx, data)
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_image_processor)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
def __init__(self,

View File

@ -1,5 +1,14 @@
from .base import MultiModalData, MultiModalPlugin
from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global :class:`~MultiModalRegistry` is used by model runners to
dispatch data processing according to its modality and the target model.
See also:
:ref:`input_processing_pipeline`
"""
__all__ = [
"MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",

View File

@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
TypeVar)
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
if TYPE_CHECKING:
@ -23,7 +24,7 @@ class MultiModalData:
Finally, register the new plugin to
:const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
This enables models to call :meth:`MultiModalRegistry.register_input` for
This enables models to call :meth:`MultiModalRegistry.map_input` for
the new modality.
"""
pass
@ -32,10 +33,9 @@ class MultiModalData:
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
Dict[str, "torch.Tensor"]]
MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
@ -50,16 +50,9 @@ class MultiModalPlugin(ABC, Generic[D]):
(i.e., the modality of the data).
"""
@classmethod
def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
return get_model_architecture(model_config)[0]
def __init__(self) -> None:
self._input_processors: Dict[Type["nn.Module"],
MultiModalInputProcessor[D]] = {}
self._input_mappers: Dict[Type["nn.Module"],
MultiModalInputMapper[D]] = {}
@abstractmethod
def get_data_type(self) -> Type[D]:
@ -70,57 +63,62 @@ class MultiModalPlugin(ABC, Generic[D]):
raise NotImplementedError
@abstractmethod
def _default_input_processor(
self, data: D, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
def _default_input_mapper(self, ctx: InputContext,
data: D) -> Dict[str, "torch.Tensor"]:
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
"""
raise NotImplementedError
def register_input_processor(self,
processor: Optional[
MultiModalInputProcessor[D]] = None):
def register_input_mapper(
self,
mapper: Optional[MultiModalInputMapper[D]] = None,
):
"""
Register an input processor to a model class.
Register an input mapper to a model class.
When the model receives input data that matches the modality served by
this plugin (see :meth:`get_data_type`), the provided input processor is
applied to preprocess the data. If `None` is provided, then the default
input processor is applied instead.
this plugin (see :meth:`get_data_type`), the provided function is
invoked to transform the data into a dictionary of model inputs.
If `None` is provided, then the default input mapper is used instead.
See also:
:ref:`input_processing_pipeline`
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors:
if model_cls in self._input_mappers:
logger.warning(
"Model class %s already has an input processor "
"Model class %s already has an input mapper "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._input_processors[model_cls] = processor \
or self._default_input_processor
self._input_mappers[model_cls] = mapper \
or self._default_input_mapper
return model_cls
return wrapper
def process_input(
self, data: D, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
def map_input(self, model_config: ModelConfig,
data: D) -> Dict[str, "torch.Tensor"]:
"""
Apply an input processor to a :class:`~MultiModalData` instance passed
to the model.
The model is identified by ``model_config``. ``vlm_config`` is
for compatibility purposes and may be merged into ``model_config``
in the near future.
"""
model_cls = self.get_model_cls(model_config)
Apply an input mapper to a :class:`~MultiModalData` instance passed
to the model, transforming the data into a dictionary of model inputs.
processor = self._input_processors.get(model_cls)
if processor is None:
raise KeyError(f"No input processor in {self} is registered for "
The model is identified by ``model_config``.
TODO: Add guide [ref: PR #5276]
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
mapper = self._input_mappers.get(model_cls)
if mapper is None:
raise KeyError(f"No input mapper in {self} is registered for "
f"model class {model_cls.__name__}.")
return processor(data, model_config, vlm_config)
return mapper(InputContext(model_config), data)

View File

@ -1,70 +1,28 @@
from typing import Dict, Tuple, Type, Union
from functools import lru_cache
from typing import Dict, Type, Union
import torch
from PIL import Image
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.sequence import SequenceData
from vllm.transformers_utils.image_processor import cached_get_image_processor
from vllm.transformers_utils.image_processor import get_image_processor
from .base import MultiModalData, MultiModalPlugin
logger = init_logger(__name__)
def _get_dummy_seq_data(seq_len: int,
vlm_config: VisionLanguageConfig) -> SequenceData:
# NOTE: We assume that <image> token is repeated `image_feature_size` times
# and then concatenated with the text prompt
# TODO: Enable other ways of inserting the image into the prompt
token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size
token_ids += [0] * (seq_len - vlm_config.image_feature_size)
return SequenceData(token_ids)
def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor:
if vlm_config.image_processor is None:
values_dtype = torch.float16
else:
values_dtype = torch.uint8
return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)
def get_dummy_image_data(
seq_len: int,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
"""Standard dummy data factory for image data (to be used in
:meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`)."""
seq_data = _get_dummy_seq_data(seq_len, vlm_config)
values = _get_dummy_values(vlm_config)
config_input_type = vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
fake_mm_data: MultiModalData
if config_input_type == ImageInputType.PIXEL_VALUES:
fake_mm_data = ImagePixelData(values)
elif config_input_type == ImageInputType.IMAGE_FEATURES:
fake_mm_data = ImageFeatureData(values)
else:
raise NotImplementedError
return seq_data, fake_mm_data
cached_get_image_processor = lru_cache(get_image_processor)
class ImagePixelData(MultiModalData):
"""
The pixel data of an image. Can be one of:
- :class:``PIL.Image``: An image object. Requires that a HuggingFace
- :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace
processor is available to the model.
- :class:``torch.Tensor``: The raw pixel data which is passed to the model
- :class:`torch.Tensor`: The raw pixel data which is passed to the model
without additional pre-processing.
"""
@ -89,8 +47,8 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
def get_data_type(self) -> Type[ImagePixelData]:
return ImagePixelData
def _get_hf_image_processor(self, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
def _get_hf_image_processor(self, model_config: ModelConfig):
vlm_config = model_config.multimodal_config
if vlm_config is None or vlm_config.image_processor is None:
return None
@ -100,14 +58,13 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
revision=vlm_config.image_processor_revision,
)
def _default_input_processor(
self, data: ImagePixelData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
def _default_input_mapper(self, ctx: InputContext,
data: ImagePixelData) -> Dict[str, torch.Tensor]:
model_config = ctx.model_config
image = data.image
if isinstance(image, Image.Image):
image_processor = self._get_hf_image_processor(
model_config, vlm_config)
image_processor = self._get_hf_image_processor(model_config)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available"
"to process the image object")
@ -147,9 +104,10 @@ class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
def get_data_type(self) -> Type[ImageFeatureData]:
return ImageFeatureData
def _default_input_processor(
self, data: ImageFeatureData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
def _default_input_mapper(
self, ctx: InputContext,
data: ImageFeatureData) -> Dict[str, torch.Tensor]:
model_config = ctx.model_config
image_features = data.image_features.to(model_config.dtype)
return {"image_features": image_features}

View File

@ -1,46 +1,35 @@
import functools
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence,
Tuple, Type, TypeVar)
from typing import Any, Optional, Sequence, Type, TypeVar
from vllm.config import ModelConfig, VisionLanguageConfig
from torch import nn
from vllm.config import ModelConfig
from vllm.logger import init_logger
from .base import MultiModalData, MultiModalPlugin
from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
ImagePixelPlugin)
if TYPE_CHECKING:
import torch
from torch import nn
from vllm.sequence import SequenceData
logger = init_logger(__name__)
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
Dict[str, "torch.Tensor"]]
MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig],
Tuple["SequenceData", MultiModalData]]
N = TypeVar("N", bound=Type[nn.Module])
class MultiModalRegistry:
"""
This registry is used by model runners to dispatch data processing
A registry to dispatch data processing
according to its modality and the target model.
"""
DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())
def __init__(self,
*,
plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS
) -> None:
def __init__(
self,
*,
plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS,
) -> None:
self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
self._dummy_factories_by_model_type: Dict[Type["nn.Module"],
MultiModalDummyFactory] = {}
def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
data_type = plugin.get_data_type()
@ -62,95 +51,53 @@ class MultiModalRegistry:
msg = f"Unknown multi-modal data type: {data_type}"
raise NotImplementedError(msg)
def register_dummy_data(self, factory: MultiModalDummyFactory):
def register_input_mapper(
self,
data_type: Type[D],
mapper: Optional[MultiModalInputMapper[D]] = None,
):
"""
Register a dummy data factory to a model class.
Register an input mapper for a specific modality to a model class.
During memory profiling, the provided function is invoked to create
dummy data to be inputted into the model. The modality and shape of
the dummy data should be an upper bound of what the model would receive
at inference time.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type:
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""Create dummy data for memory profiling."""
model_cls = MultiModalPlugin.get_model_cls(model_config)
dummy_factory = self._dummy_factories_by_model_type.get(model_cls)
if dummy_factory is None:
msg = f"No dummy data defined for model class: {model_cls}"
raise NotImplementedError(msg)
return dummy_factory(seq_len, model_config, vlm_config)
def register_input(
self,
data_type: Type[D],
processor: Optional[MultiModalInputProcessor[D]] = None):
"""
Register an input processor for a specific modality to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self._get_plugin_for_data_type(data_type) \
.register_input_processor(processor)
.register_input_mapper(mapper)
def register_image_pixel_input(
self,
processor: Optional[
MultiModalInputProcessor[ImagePixelData]] = None):
"""
Register an input processor for image pixel data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self.register_input(ImagePixelData, processor)
def register_image_feature_input(
def register_image_pixel_input_mapper(
self,
processor: Optional[
MultiModalInputProcessor[ImageFeatureData]] = None):
mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None,
):
"""
Register an input processor for image feature data to a model class.
Register an input mapper for image pixel data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self.register_input(ImageFeatureData, processor)
return self.register_input_mapper(ImagePixelData, mapper)
def process_input(self, data: MultiModalData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
def register_image_feature_input_mapper(
self,
mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None,
):
"""
Apply an input processor to a :class:`~MultiModalData` instance passed
Register an input mapper for image feature data to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self.register_input_mapper(ImageFeatureData, mapper)
def map_input(self, model_config: ModelConfig, data: MultiModalData):
"""
Apply an input mapper to a :class:`~MultiModalData` instance passed
to the model.
See :meth:`MultiModalPlugin.process_input` for more details.
See :meth:`MultiModalPlugin.map_input` for more details.
"""
return self._get_plugin_for_data_type(type(data)) \
.process_input(data, model_config, vlm_config)
.map_input(model_config, data)
def create_input_processor(self, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
def create_input_mapper(self, model_config: ModelConfig):
"""
Create an input processor (see :meth:`process_input`) for a
specific model.
Create an input mapper (see :meth:`map_input`) for a specific model.
"""
return functools.partial(self.process_input,
model_config=model_config,
vlm_config=vlm_config)
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""The global :class:`~MultiModalRegistry` which is used by model runners."""
return functools.partial(self.map_input, model_config)
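A minimal hypothetical usage sketch of this registry (the mapper and model class below are illustrative only):

from typing import Dict

import torch
from torch import nn

from vllm.inputs import InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData


def my_pixel_mapper(ctx: InputContext,
                    data: ImagePixelData) -> Dict[str, torch.Tensor]:
    # Illustrative: forward a pre-computed tensor as-is, cast to the model's
    # dtype, bypassing the HuggingFace image processor.
    assert isinstance(data.image, torch.Tensor)
    return {"pixel_values": data.image.to(ctx.model_config.dtype)}


@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(my_pixel_mapper)
class MyVisionModel(nn.Module):  # hypothetical model class
    ...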

View File

@ -8,12 +8,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal import MultiModalData
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@ -221,7 +221,7 @@ class Sequence:
def __init__(
self,
seq_id: int,
inputs: LLMInputs,
inputs: "LLMInputs",
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,

View File

@ -1,4 +1,3 @@
from functools import lru_cache
from typing import Optional
from transformers import AutoImageProcessor
@ -40,6 +39,3 @@ def get_image_processor(
raise e
return processor
cached_get_image_processor = lru_cache(get_image_processor)

View File

@ -110,15 +110,9 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.block_size,
)
# Create processor for multi-modal data
if self.vision_language_config is not None:
self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
.create_input_processor(
self.model_config,
self.vision_language_config,
)
else:
self.multi_modal_input_processor = None
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
# Lazy initialization.
self.model: nn.Module # Set after init_Model
@ -168,13 +162,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None:
# Process multi-modal data
if self.multi_modal_input_processor is None:
raise ValueError(
"Multi-modal inputs are only supported by "
"vision language models.")
mm_kwargs = self.multi_modal_input_processor(mm_data)
mm_kwargs = self.multi_modal_input_mapper(mm_data)
for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v)

View File

@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@ -25,7 +26,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import (
@ -191,15 +192,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.block_size,
) if num_attn_heads else None
# Create processor for multi-modal data
if self.vision_language_config is not None:
self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
.create_input_processor(
self.model_config,
self.vision_language_config,
)
else:
self.multi_modal_input_processor = None
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
# Lazy initialization
self.model: nn.Module # Set after load_model
@ -506,12 +501,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None:
# Process multi-modal data
if self.multi_modal_input_processor is None:
raise ValueError(
"Multi-modal inputs are only supported by "
"vision language models.")
mm_kwargs = self.multi_modal_input_processor(mm_data)
mm_kwargs = self.multi_modal_input_mapper(mm_data)
for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v)
@ -764,12 +754,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
if vlm_config is None:
seq_data = SequenceData([0] * seq_len)
dummy_multi_modal_data = None
else:
seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
.dummy_data_for_profiling(seq_len, model_config, vlm_config)
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
assert len(seq_data.prompt_token_ids) == seq_len
seq = SequenceGroupMetadata(
request_id=str(group_id),