[Core] Registry for processing model inputs (#5214)
Co-authored-by: ywang96 <ywang@roblox.com>
This commit is contained in:
parent 0d0e3a42ac / commit 5cbe8d155c
@@ -0,0 +1,20 @@
.. _input_processing_pipeline:

Input Processing Pipeline
=========================

1. Input data is passed to :class:`~vllm.LLMEngine` (or :class:`~vllm.AsyncLLMEngine`).

2. Tokenize the data if necessary.

3. Process the inputs using :meth:`INPUT_REGISTRY.process_input <vllm.inputs.registry.InputRegistry.process_input>`.

   - For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.

4. Send the processed inputs to :class:`~vllm.executor.executor_base.ExecutorBase`.

5. Distribute the inputs via :class:`~vllm.worker.worker_base.WorkerBase` to :class:`~vllm.worker.model_runner_base.ModelRunnerBase`.

6. If the data contains multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.

   - For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
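For illustration (not part of this commit's diff), here is a minimal sketch of step 3: registering a per-model input processor with the registry added by this PR. The model class, token id, and feature size below are assumptions, not values from the diff.

.. code-block:: python

    import torch.nn as nn

    from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs


    def my_input_processor(ctx: InputContext, llm_inputs: LLMInputs) -> LLMInputs:
        """Prepend placeholder tokens so enough KV cache is reserved (step 3)."""
        image_token_id = 32000        # assumption: the model's <image> token id
        image_feature_size = 576      # assumption: number of image embeddings
        new_token_ids = [image_token_id] * image_feature_size
        new_token_ids += llm_inputs["prompt_token_ids"]
        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=llm_inputs.get("prompt"),
                         multi_modal_data=llm_inputs.get("multi_modal_data"))


    @INPUT_REGISTRY.register_input_processor(my_input_processor)
    class MyVisionModel(nn.Module):  # hypothetical model class
        ...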
39  docs/source/dev/input_processing/model_inputs_index.rst  Normal file
@@ -0,0 +1,39 @@
.. _input_processing:

Input Processing
================

.. currentmodule:: vllm.inputs

vLLM provides a mechanism for defining input processors for each model so that the inputs are processed
in :class:`~vllm.LLMEngine` before they are passed to model executors.

Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
data in addition to input prompt, but it can be extended to text-only language models when needed.

Guides
++++++

.. toctree::
   :maxdepth: 1

   input_processing_pipeline

Module Contents
+++++++++++++++

LLM Engine Inputs
-----------------

.. autoclass:: vllm.inputs.LLMInputs
    :members:
    :show-inheritance:

Registry
--------

.. autodata:: vllm.inputs.INPUT_REGISTRY

.. automodule:: vllm.inputs.registry
    :members:
    :show-inheritance:
@@ -12,10 +12,6 @@ By default, vLLM models do not support multi-modal inputs. To enable multi-modal
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type to support.

.. contents::
    :local:
    :backlinks: none

Module Contents
+++++++++++++++

@@ -24,9 +20,7 @@ Module Contents
Registry
--------

.. data:: vllm.multimodal.MULTIMODAL_REGISTRY

    The global :class:`MultiModalRegistry` which is used by model runners.
.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY

.. autoclass:: vllm.multimodal.MultiModalRegistry
    :members:
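For illustration (not part of the diff), here is how enabling multi-modal support looks after this commit, which moves dummy-data registration to ``INPUT_REGISTRY`` and renames input processors to input mappers. The names mirror the LLaVA changes later in this diff; the model class itself is hypothetical.

.. code-block:: python

    import torch.nn as nn

    from vllm.inputs import INPUT_REGISTRY, InputContext
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.sequence import SequenceData


    def dummy_data_for_my_model(ctx: InputContext, seq_len: int):
        # Upper-bound sequence for memory profiling; no multi-modal data here.
        return SequenceData([0] * seq_len), None


    @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
    @INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_model)
    class MyVisionModel(nn.Module):  # hypothetical model class
        ...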
@@ -120,6 +120,7 @@ Documentation
    dev/offline_inference/offline_index
    dev/engine/engine_index
    dev/kernel/paged_attention
    dev/input_processing/model_inputs_index
    dev/multimodal/multimodal_index
    dev/dockerfile/dockerfile
@@ -37,7 +37,7 @@ For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/
2. Rewrite the :code:`forward` methods
--------------------------------------

Next, you need to rewrite the :code:`forward` methods of your model by following these steps:
Next, you need to rewrite the :meth:`~torch.nn.Module.forward` method of your model by following these steps:

1. Remove any unnecessary code, such as the code only used for training.
2. Change the input parameters:
@@ -75,7 +75,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following

If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
For the embedding layer, you can simply replace :class:`torch.nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
When it comes to the linear layers, we provide the following options to parallelize them:

* :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
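A minimal sketch of the substitution described above (not part of the diff). The config attribute names ``hidden_size`` and ``vocab_size`` are assumptions about the model being ported, and the import path reflects vLLM's layer modules as of this commit.

.. code-block:: python

    import torch.nn as nn

    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead, VocabParallelEmbedding)


    class MyModelForCausalLM(nn.Module):  # hypothetical model class

        def __init__(self, config):
            super().__init__()
            # Before: nn.Embedding(config.vocab_size, config.hidden_size)
            self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                       config.hidden_size)
            # Before: nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)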
@@ -11,7 +11,7 @@ def run_phi3v():
    model_path = "microsoft/Phi-3-vision-128k-instruct"

    # Note: The model has 128k context length by default which may cause OOM
    # If that's the case, override `max_model_len` with a smaller value via args
    # In this example, we override max_model_len to 2048.
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
@@ -19,6 +19,7 @@ def run_phi3v():
        image_token_id=32044,
        image_input_shape="1,3,1008,1344",
        image_feature_size=1921,
        max_model_len=2048,
    )

    image = Image.open("images/cherry_blossom.jpg")
@ -25,14 +25,14 @@ def test_clip_image_processor(image_assets, dtype):
|
||||
seed=0,
|
||||
dtype=dtype,
|
||||
revision=None,
|
||||
)
|
||||
vlm_config = VisionLanguageConfig(
|
||||
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
|
||||
image_token_id=32000,
|
||||
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
|
||||
image_feature_size=576,
|
||||
image_processor=MODEL_NAME,
|
||||
image_processor_revision=None,
|
||||
multimodal_config=VisionLanguageConfig(
|
||||
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
|
||||
image_token_id=32000,
|
||||
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
|
||||
image_feature_size=576,
|
||||
image_processor=MODEL_NAME,
|
||||
image_processor_revision=None,
|
||||
),
|
||||
)
|
||||
|
||||
for asset in image_assets:
|
||||
@ -40,10 +40,9 @@ def test_clip_image_processor(image_assets, dtype):
|
||||
asset.pil_image,
|
||||
return_tensors="pt",
|
||||
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
|
||||
vllm_result = MULTIMODAL_REGISTRY.process_input(
|
||||
vllm_result = MULTIMODAL_REGISTRY.map_input(
|
||||
model_config,
|
||||
ImagePixelData(asset.pil_image),
|
||||
model_config=model_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
|
||||
assert hf_result.keys() == vllm_result.keys()
|
||||
@ -74,14 +73,14 @@ def test_llava_next_image_processor(image_assets, dtype):
|
||||
seed=0,
|
||||
dtype=dtype,
|
||||
revision=None,
|
||||
)
|
||||
vlm_config = VisionLanguageConfig(
|
||||
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
|
||||
image_token_id=64000,
|
||||
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
|
||||
image_feature_size=2928,
|
||||
image_processor=MODEL_NAME,
|
||||
image_processor_revision=None,
|
||||
multimodal_config=VisionLanguageConfig(
|
||||
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
|
||||
image_token_id=64000,
|
||||
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
|
||||
image_feature_size=2928,
|
||||
image_processor=MODEL_NAME,
|
||||
image_processor_revision=None,
|
||||
),
|
||||
)
|
||||
|
||||
for asset in image_assets:
|
||||
@ -89,10 +88,9 @@ def test_llava_next_image_processor(image_assets, dtype):
|
||||
asset.pil_image,
|
||||
return_tensors="pt",
|
||||
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
|
||||
vllm_result = MULTIMODAL_REGISTRY.process_input(
|
||||
vllm_result = MULTIMODAL_REGISTRY.map_input(
|
||||
model_config,
|
||||
ImagePixelData(asset.pil_image),
|
||||
model_config=model_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
|
||||
assert hf_result.keys() == vllm_result.keys()
|
||||
@ -119,26 +117,23 @@ def test_image_pixel_types(image_assets, dtype):
|
||||
seed=0,
|
||||
dtype=dtype,
|
||||
revision=None,
|
||||
)
|
||||
vlm_config = VisionLanguageConfig(
|
||||
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
|
||||
image_token_id=32000,
|
||||
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
|
||||
image_feature_size=576,
|
||||
image_processor=MODEL_NAME,
|
||||
image_processor_revision=None,
|
||||
)
|
||||
multimodal_config=VisionLanguageConfig(
|
||||
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
|
||||
image_token_id=32000,
|
||||
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
|
||||
image_feature_size=576,
|
||||
image_processor=MODEL_NAME,
|
||||
image_processor_revision=None,
|
||||
))
|
||||
|
||||
for asset in image_assets:
|
||||
image_result = MULTIMODAL_REGISTRY.process_input(
|
||||
image_result = MULTIMODAL_REGISTRY.map_input(
|
||||
model_config,
|
||||
ImagePixelData(asset.pil_image),
|
||||
model_config=model_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
tensor_result = MULTIMODAL_REGISTRY.process_input(
|
||||
tensor_result = MULTIMODAL_REGISTRY.map_input(
|
||||
model_config,
|
||||
ImagePixelData(asset.pixel_values),
|
||||
model_config=model_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
|
||||
assert image_result.keys() == tensor_result.keys()
|
||||
@@ -109,6 +109,7 @@ class ModelConfig:
        disable_sliding_window: bool = False,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
        multimodal_config: Optional["VisionLanguageConfig"] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
@@ -159,6 +160,8 @@ class ModelConfig:
            sliding_window_len=self.get_hf_config_sliding_window())
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        self.multimodal_config = multimodal_config

        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()
            self._verify_embedding_mode()
@ -643,6 +643,36 @@ class EngineArgs:
|
||||
raise ValueError(
|
||||
"BitsAndBytes load format and QLoRA adapter only support "
|
||||
f"'bitsandbytes' quantization, but got {self.quantization}")
|
||||
if self.image_input_type:
|
||||
if (not self.image_token_id or not self.image_input_shape
|
||||
or not self.image_feature_size):
|
||||
raise ValueError(
|
||||
'Specify `image_token_id`, `image_input_shape` and '
|
||||
'`image_feature_size` together with `image_input_type`.')
|
||||
|
||||
if self.image_processor is None:
|
||||
self.image_processor = self.model
|
||||
if self.disable_image_processor:
|
||||
if self.image_processor != self.model:
|
||||
warnings.warn(
|
||||
"You've specified an image processor "
|
||||
f"({self.image_processor}) but also disabled "
|
||||
"it via `--disable-image-processor`.",
|
||||
stacklevel=2)
|
||||
|
||||
self.image_processor = None
|
||||
|
||||
vision_language_config = VisionLanguageConfig(
|
||||
image_input_type=VisionLanguageConfig.
|
||||
get_image_input_enum_type(self.image_input_type),
|
||||
image_token_id=self.image_token_id,
|
||||
image_input_shape=str_to_int_tuple(self.image_input_shape),
|
||||
image_feature_size=self.image_feature_size,
|
||||
image_processor=self.image_processor,
|
||||
image_processor_revision=self.image_processor_revision,
|
||||
)
|
||||
else:
|
||||
vision_language_config = None
|
||||
|
||||
device_config = DeviceConfig(device=self.device)
|
||||
model_config = ModelConfig(
|
||||
@ -666,7 +696,8 @@ class EngineArgs:
|
||||
max_logprobs=self.max_logprobs,
|
||||
disable_sliding_window=self.disable_sliding_window,
|
||||
skip_tokenizer_init=self.skip_tokenizer_init,
|
||||
served_model_name=self.served_model_name)
|
||||
served_model_name=self.served_model_name,
|
||||
multimodal_config=vision_language_config)
|
||||
cache_config = CacheConfig(
|
||||
block_size=self.block_size,
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
@ -742,37 +773,6 @@ class EngineArgs:
|
||||
model_loader_extra_config=self.model_loader_extra_config,
|
||||
)
|
||||
|
||||
if self.image_input_type:
|
||||
if (not self.image_token_id or not self.image_input_shape
|
||||
or not self.image_feature_size):
|
||||
raise ValueError(
|
||||
'Specify `image_token_id`, `image_input_shape` and '
|
||||
'`image_feature_size` together with `image_input_type`.')
|
||||
|
||||
if self.image_processor is None:
|
||||
self.image_processor = self.model
|
||||
if self.disable_image_processor:
|
||||
if self.image_processor != self.model:
|
||||
warnings.warn(
|
||||
"You've specified an image processor "
|
||||
f"({self.image_processor}) but also disabled "
|
||||
"it via `--disable-image-processor`.",
|
||||
stacklevel=2)
|
||||
|
||||
self.image_processor = None
|
||||
|
||||
vision_language_config = VisionLanguageConfig(
|
||||
image_input_type=VisionLanguageConfig.
|
||||
get_image_input_enum_type(self.image_input_type),
|
||||
image_token_id=self.image_token_id,
|
||||
image_input_shape=str_to_int_tuple(self.image_input_shape),
|
||||
image_feature_size=self.image_feature_size,
|
||||
image_processor=self.image_processor,
|
||||
image_processor_revision=self.image_processor_revision,
|
||||
)
|
||||
else:
|
||||
vision_language_config = None
|
||||
|
||||
decoding_config = DecodingConfig(
|
||||
guided_decoding_backend=self.guided_decoding_backend)
|
||||
|
||||
|
||||
@@ -278,9 +278,11 @@ class _AsyncLLMEngine(LLMEngine):
        else:
            prompt_token_ids = inputs["prompt_token_ids"]

        return LLMInputs(prompt_token_ids=prompt_token_ids,
                         prompt=inputs.get("prompt"),
                         multi_modal_data=inputs.get("multi_modal_data"))
        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
                               prompt=inputs.get("prompt"),
                               multi_modal_data=inputs.get("multi_modal_data"))

        return self.input_processor(llm_inputs)

    async def add_request_async(
        self,
@@ -20,7 +20,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@@ -227,6 +227,9 @@ class LLMEngine:
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_processor = INPUT_REGISTRY.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
@@ -511,9 +514,11 @@ class LLMEngine:
        else:
            prompt_token_ids = inputs["prompt_token_ids"]

        return LLMInputs(prompt_token_ids=prompt_token_ids,
                         prompt=inputs.get("prompt"),
                         multi_modal_data=inputs.get("multi_modal_data"))
        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
                               prompt=inputs.get("prompt"),
                               multi_modal_data=inputs.get("multi_modal_data"))

        return self.input_processor(llm_inputs)

    def add_request(
        self,
19  vllm/inputs/__init__.py  Normal file
@@ -0,0 +1,19 @@
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
                   PromptStrictInputs, TextPrompt, TextTokensPrompt,
                   TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry

INPUT_REGISTRY = InputRegistry()
"""
The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
to dispatch data processing according to the target model.

See also:
    :ref:`input_processing_pipeline`
"""

__all__ = [
    "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
    "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs",
    "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry"
]
@@ -101,8 +101,7 @@ class TextTokensPrompt(TypedDict):
    """The prompt text."""

    prompt_token_ids: List[int]
    """The token IDs of the prompt. If None, we use the
    tokenizer to convert the prompts to token IDs."""
    """The token IDs of the prompt."""

    multi_modal_data: NotRequired["MultiModalData"]
    """
@@ -125,6 +124,21 @@ PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]


class LLMInputs(TypedDict):
    """
    The inputs in :class:`~vllm.LLMEngine` before they are
    passed to the model executor.
    """

    prompt_token_ids: List[int]
    """The token IDs of the prompt."""

    prompt: NotRequired[Optional[str]]
    """
    The original prompt text corresponding to the token IDs, if available.
    """

    multi_modal_data: NotRequired[Optional["MultiModalData"]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """
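For illustration (not part of the diff), the shape of this TypedDict in use; the token IDs below are placeholders, not real tokenizer output.

# a minimal LLMInputs instance built by hand
from vllm.inputs import LLMInputs

llm_inputs = LLMInputs(prompt_token_ids=[101, 2023, 102],
                       prompt="This is a prompt")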
207  vllm/inputs/registry.py  Normal file
@@ -0,0 +1,207 @@
import functools
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
                    TypeVar)

from torch import nn
from transformers import PretrainedConfig

from vllm.logger import init_logger

from .data import LLMInputs

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VisionLanguageConfig
    from vllm.multimodal import MultiModalData
    from vllm.sequence import SequenceData

logger = init_logger(__name__)

C = TypeVar("C", bound=PretrainedConfig)


@dataclass(frozen=True)
class InputContext:
    """
    Contains information about the model which may be used to
    modify the inputs.
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_multimodal_config(self) -> "VisionLanguageConfig":
        """
        Get the multimodal configuration of the model.

        Raises:
            ValueError: If the model is not multimodal.
        """

        multimodal_config = self.model_config.multimodal_config
        if multimodal_config is None:
            raise ValueError("No multimodal config found")

        return multimodal_config

    def get_hf_config(self, hf_config_type: Type[C]) -> C:
        """
        Get the HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Raises:
            ValueError: If the model is not of the specified type.
        """

        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, hf_config_type):
            raise TypeError("Invalid type of HuggingFace config. "
                            f"Expected type: {hf_config_type}, but "
                            f"found type: {type(hf_config)}")

        return hf_config


N = TypeVar("N", bound=Type[nn.Module])

DummyDataFactory = Callable[[InputContext, int],
                            Tuple["SequenceData", Optional["MultiModalData"]]]
"""
Create dummy data to be inputted into the model.

Note:
    :data:`InputProcessor` is not applied to the dummy data.
"""

InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
"""Preprocess the inputs to the model."""


class InputRegistry:
    """
    A registry to dispatch data processing
    according to the target model.
    """

    def __init__(self) -> None:
        self._dummy_factories_by_model_type: Dict[Type[nn.Module],
                                                  DummyDataFactory] = {}
        self._input_processors_by_model_type: Dict[Type[nn.Module],
                                                   InputProcessor] = {}

    def _default_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
    ) -> Tuple["SequenceData", Optional["MultiModalData"]]:
        """
        The default dummy data factory represents the longest possible text
        that can be inputted to the model.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        dummy_seq_data = SequenceData([0] * seq_len)
        dummy_multi_modal_data = None

        return dummy_seq_data, dummy_multi_modal_data

    def register_dummy_data(self, factory: DummyDataFactory):
        """
        Register a dummy data factory to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The resulting memory usage
        should be an upper bound of what the model would use at inference time.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._dummy_factories_by_model_type:
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def dummy_data_for_profiling(self, model_config: "ModelConfig",
                                 seq_len: int):
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        TODO: Add guide [ref: PR #5276]
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        dummy_factory = self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

        return dummy_factory(InputContext(model_config), seq_len)

    def _default_input_processor(self, ctx: InputContext,
                                 inputs: LLMInputs) -> LLMInputs:
        """The default input processor is a no-op."""
        return inputs

    def register_input_processor(self, processor: InputProcessor):
        """
        Register an input processor to a model class.

        The provided function is invoked on each input to the model. This
        happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.

        See also:
            :ref:`input_processing_pipeline`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_processors_by_model_type:
                logger.warning(
                    "Model class %s already has input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors_by_model_type[model_cls] = processor

            return model_cls

        return wrapper

    def process_input(self, model_config: "ModelConfig",
                      inputs: LLMInputs) -> LLMInputs:
        """
        Apply an input processor to an instance of model inputs.

        The model is identified by ``model_config``.

        See also:
            :ref:`input_processing_pipeline`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        processor = self._input_processors_by_model_type \
            .get(model_cls, self._default_input_processor)

        return processor(InputContext(model_config), inputs)

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.
        """
        return functools.partial(self.process_input, model_config)
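To make the registry's two entry points concrete, here is a small usage sketch (not part of the diff); it assumes a ModelConfig and an LLMInputs instance already exist and are passed in.

# Illustrative sketch of how the engine and model runner drive InputRegistry.
from vllm.config import ModelConfig
from vllm.inputs import INPUT_REGISTRY, LLMInputs


def profile_and_process(model_config: ModelConfig, seq_len: int,
                        llm_inputs: LLMInputs):
    # What the model runner does during memory profiling:
    seq_data, mm_data = INPUT_REGISTRY.dummy_data_for_profiling(
        model_config, seq_len)

    # What LLMEngine does once at startup: bind the registry lookup to the
    # target model...
    process = INPUT_REGISTRY.create_input_processor(model_config)

    # ...and then apply the resulting processor to each request's inputs.
    return seq_data, mm_data, process(llm_inputs)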
@ -1,22 +1,83 @@
|
||||
"""Minimal implementation of CLIPVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import CLIPVisionConfig
|
||||
from transformers.models.clip.modeling_clip import CLIPAttention
|
||||
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
def get_clip_num_patches(image_size: int, patch_size: int) -> int:
|
||||
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
assert image_size % patch_size == 0
|
||||
return (image_size // patch_size)**2
|
||||
return image_size // patch_size
|
||||
|
||||
|
||||
def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
|
||||
grid_length = get_clip_patch_grid_length(image_size=image_size,
|
||||
patch_size=patch_size)
|
||||
return grid_length * grid_length
|
||||
|
||||
|
||||
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
|
||||
return get_clip_num_patches(image_size=hf_config.image_size,
|
||||
patch_size=hf_config.patch_size)
|
||||
|
||||
|
||||
def dummy_seq_data_for_clip(
|
||||
hf_config: CLIPVisionConfig,
|
||||
seq_len: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_clip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = [image_token_id] * image_feature_size
|
||||
token_ids += [0] * (seq_len - image_feature_size)
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
def dummy_pixel_data_for_clip(
|
||||
hf_config: CLIPVisionConfig,
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
):
|
||||
width = height = hf_config.image_size
|
||||
if image_width_override is not None:
|
||||
width = image_width_override
|
||||
if image_height_override is not None:
|
||||
height = image_height_override
|
||||
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
return ImagePixelData(image)
|
||||
|
||||
|
||||
def dummy_feature_data_for_clip(
|
||||
hf_config: CLIPVisionConfig,
|
||||
*,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_clip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
values = torch.zeros((1, image_feature_size, hf_config.hidden_size),
|
||||
dtype=torch.float16)
|
||||
return ImageFeatureData(values)
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
|
||||
@ -39,8 +100,8 @@ class CLIPVisionEmbeddings(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.num_patches = get_clip_num_patches(self.image_size,
|
||||
self.patch_size)
|
||||
self.num_patches = get_clip_num_patches(image_size=self.image_size,
|
||||
patch_size=self.patch_size)
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = nn.Embedding(self.num_positions,
|
||||
self.embed_dim)
|
||||
@ -101,7 +162,7 @@ class CLIPEncoderLayer(nn.Module):
|
||||
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
@ -2,10 +2,11 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import LlavaConfig
|
||||
from transformers import CLIPVisionConfig, LlavaConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -16,10 +17,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import get_dummy_image_data
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
|
||||
dummy_seq_data_for_clip)
|
||||
from .interfaces import SupportsVision
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
@ -83,9 +85,35 @@ class LlavaImageFeatureInputs(TypedDict):
|
||||
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_feature_input()
|
||||
@MULTIMODAL_REGISTRY.register_image_pixel_input()
|
||||
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
|
||||
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
|
||||
multimodal_config = ctx.get_multimodal_config()
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
)
|
||||
|
||||
image_input_type = multimodal_config.image_input_type
|
||||
ImageInputType = VisionLanguageConfig.ImageInputType
|
||||
mm_data: MultiModalData
|
||||
if image_input_type == ImageInputType.PIXEL_VALUES:
|
||||
mm_data = dummy_pixel_data_for_clip(vision_config)
|
||||
elif image_input_type == ImageInputType.IMAGE_FEATURES:
|
||||
mm_data = dummy_feature_data_for_clip(vision_config)
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@ -3,14 +3,14 @@ from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import LlavaNextConfig
|
||||
from transformers import CLIPVisionConfig, LlavaNextConfig
|
||||
from transformers.models.llava_next.modeling_llava_next import (
|
||||
get_anyres_image_grid_shape, unpad_image)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
|
||||
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -22,9 +22,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
|
||||
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
|
||||
from vllm.sequence import SamplerOutput, SequenceData
|
||||
from vllm.multimodal.image import ImagePixelData
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
|
||||
dummy_seq_data_for_clip, get_clip_patch_grid_length)
|
||||
from .interfaces import SupportsVision
|
||||
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
|
||||
|
||||
@ -58,41 +60,118 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
|
||||
LlavaNextImageFeatureInputs]
|
||||
|
||||
|
||||
def _get_dummy_image_data(
|
||||
seq_len: int,
|
||||
model_config: ModelConfig,
|
||||
vlm_config: VisionLanguageConfig,
|
||||
) -> Tuple[SequenceData, MultiModalData]:
|
||||
seq_data, fake_mm_data = get_dummy_image_data(seq_len, model_config,
|
||||
vlm_config)
|
||||
def _get_llava_next_num_unpadded_features(
|
||||
height: int,
|
||||
width: int,
|
||||
npatches: int,
|
||||
num_patch_height: int,
|
||||
num_patch_width: int,
|
||||
) -> Tuple[int, int]:
|
||||
# Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
|
||||
config_input_type = vlm_config.image_input_type
|
||||
ImageInputType = VisionLanguageConfig.ImageInputType
|
||||
aspect_ratio: float = width / height
|
||||
current_aspect_ratio: float = current_width / current_height
|
||||
if aspect_ratio > current_aspect_ratio:
|
||||
new_height = (height * current_width) // width
|
||||
current_height = new_height
|
||||
else:
|
||||
new_width = (width * current_height) // height
|
||||
current_width = new_width
|
||||
|
||||
if config_input_type == ImageInputType.PIXEL_VALUES:
|
||||
_, c, h, w = vlm_config.image_input_shape
|
||||
mode = {1: "L", 3: "RGB"}[c]
|
||||
fake_mm_data = ImagePixelData(Image.new(mode, (w, h), color=0))
|
||||
|
||||
return seq_data, fake_mm_data
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
|
||||
def _image_pixel_processor(
|
||||
data: ImagePixelData,
|
||||
model_config: ModelConfig,
|
||||
vlm_config: VisionLanguageConfig,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
def _get_llava_next_image_feature_size(
|
||||
hf_config: LlavaNextConfig,
|
||||
*,
|
||||
input_height: int,
|
||||
input_width: int,
|
||||
) -> int:
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
num_patches = get_clip_patch_grid_length(
|
||||
image_size=vision_config.image_size,
|
||||
patch_size=vision_config.patch_size,
|
||||
)
|
||||
base_feature_size = num_patches * num_patches
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_size=(input_height, input_width),
|
||||
grid_pinpoints=hf_config.image_grid_pinpoints,
|
||||
patch_size=vision_config.image_size,
|
||||
)
|
||||
|
||||
(
|
||||
unpadded_feature_size,
|
||||
newline_feature_size,
|
||||
) = _get_llava_next_num_unpadded_features(input_height, input_width,
|
||||
num_patches,
|
||||
num_patch_height,
|
||||
num_patch_width)
|
||||
|
||||
return unpadded_feature_size + newline_feature_size + base_feature_size
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
|
||||
multimodal_config = ctx.get_multimodal_config()
|
||||
hf_config = ctx.get_hf_config(LlavaNextConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
#TODO: change the logic for dummy data to support dynamic shape
|
||||
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
|
||||
image_feature_size = _get_llava_next_image_feature_size(
|
||||
hf_config, input_height=dummy_height, input_width=dummy_width)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
image_input_type = multimodal_config.image_input_type
|
||||
ImageInputType = VisionLanguageConfig.ImageInputType
|
||||
mm_data: MultiModalData
|
||||
if image_input_type == ImageInputType.PIXEL_VALUES:
|
||||
mm_data = dummy_pixel_data_for_clip(
|
||||
vision_config,
|
||||
image_width_override=dummy_width,
|
||||
image_height_override=dummy_height,
|
||||
)
|
||||
elif image_input_type == ImageInputType.IMAGE_FEATURES:
|
||||
mm_data = dummy_feature_data_for_clip(
|
||||
vision_config,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _pixel_mapper(ctx: InputContext,
|
||||
data: ImagePixelData) -> Dict[str, torch.Tensor]:
|
||||
image = data.image
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
pixel_values = image.to(model_config.dtype)
|
||||
pixel_values = image.to(ctx.model_config.dtype)
|
||||
batch_size, _, _, h, w = pixel_values.shape
|
||||
image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])
|
||||
|
||||
return {"pixel_values": pixel_values, "image_sizes": image_sizes}
|
||||
|
||||
# Temporary patch before dynamic number of image tokens is supported
|
||||
_, _, h, w = vlm_config.image_input_shape
|
||||
_, _, h, w = ctx.get_multimodal_config().image_input_shape
|
||||
if (w, h) != (image.width, image.height):
|
||||
logger.warning(
|
||||
"Dynamic image shape is currently not supported. "
|
||||
@ -101,11 +180,12 @@ def _image_pixel_processor(
|
||||
data.image = image.resize((w, h))
|
||||
|
||||
return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
|
||||
._default_input_processor(data, model_config, vlm_config)
|
||||
._default_input_mapper(ctx, data)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
|
||||
@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
|
||||
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
|
||||
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@ -22,7 +22,8 @@ from PIL import Image
|
||||
from transformers import CLIPVisionConfig, PretrainedConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
|
||||
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -34,9 +35,10 @@ from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
|
||||
from vllm.multimodal.image import ImagePixelData
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip
|
||||
from .interfaces import SupportsVision
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -107,7 +109,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
self.num_img_tokens = config.img_processor['num_img_tokens']
|
||||
|
||||
self.image_dim_out = image_dim_out
|
||||
self.img_sizes = None
|
||||
|
||||
# global_gn and sub_gn for hd transform, serves as line separator
|
||||
self.use_hd_transform = config.embd_layer.get('use_hd_transform',
|
||||
@ -134,7 +135,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
self.img_projection = nn.Sequential(*layers)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.img_features = None
|
||||
|
||||
self.layer_idx = config.img_processor.get('layer_idx', -2)
|
||||
self.type_feature = config.img_processor.get('type_feature', 'patch')
|
||||
@ -260,9 +260,44 @@ class Phi3VImagePixelInputs(TypedDict):
|
||||
"""Shape: (batch_size, 2)"""
|
||||
|
||||
|
||||
# FIXME(Isotr0py): Remove these after dynamic num_img_tokens is supported
|
||||
# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
|
||||
def calc_padded_size(width, height, padding_unit=336):
|
||||
def _get_phi3v_image_feature_size(
|
||||
*,
|
||||
input_height: int,
|
||||
input_width: int,
|
||||
) -> int:
|
||||
h, w = input_height, input_width
|
||||
|
||||
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
|
||||
return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
|
||||
|
||||
|
||||
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
|
||||
multimodal_config = ctx.get_multimodal_config()
|
||||
|
||||
#TODO: change the logic for dummy data to support dynamic shape
|
||||
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
|
||||
image_feature_size = _get_phi3v_image_feature_size(
|
||||
input_height=dummy_height,
|
||||
input_width=dummy_width,
|
||||
)
|
||||
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
seq_len,
|
||||
image_token_id=32044,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
mm_data = dummy_pixel_data_for_clip(
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
image_width_override=dummy_width,
|
||||
image_height_override=dummy_height,
|
||||
)
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
|
||||
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
|
||||
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
|
||||
target_height = int(np.ceil(height / padding_unit) * padding_unit)
|
||||
top_padding = int((target_height - height) / 2)
|
||||
bottom_padding = target_height - height - top_padding
|
||||
@ -271,8 +306,8 @@ def calc_padded_size(width, height, padding_unit=336):
|
||||
return padded_width, padded_height
|
||||
|
||||
|
||||
# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
|
||||
def calc_hd_transform_size(width, height, hd_num=16):
|
||||
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
|
||||
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
|
||||
transposed = False
|
||||
if width < height:
|
||||
width, height = height, width
|
||||
@ -287,7 +322,8 @@ def calc_hd_transform_size(width, height, hd_num=16):
|
||||
new_width = int(scale * 336)
|
||||
new_height = int(new_width / ratio)
|
||||
|
||||
padded_width, padded_height = calc_padded_size(new_width, new_height)
|
||||
padded_width, padded_height = _calc_padded_size(width=new_width,
|
||||
height=new_height)
|
||||
|
||||
if transposed:
|
||||
padded_width, padded_height = padded_height, padded_width
|
||||
@ -295,17 +331,15 @@ def calc_hd_transform_size(width, height, hd_num=16):
|
||||
return padded_width, padded_height
|
||||
|
||||
|
||||
def _image_processor(
|
||||
data: ImagePixelData,
|
||||
model_config: ModelConfig,
|
||||
vlm_config: VisionLanguageConfig,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
def _image_processor(ctx: InputContext,
|
||||
data: ImagePixelData) -> Dict[str, torch.Tensor]:
|
||||
image = data.image
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
# Temporary patch before dynamic number of image tokens is supported
|
||||
_, _, h, w = vlm_config.image_input_shape
|
||||
if (w, h) != calc_hd_transform_size(image.width, image.height):
|
||||
_, _, h, w = ctx.get_multimodal_config().image_input_shape
|
||||
if (w, h) != _calc_hd_transform_size(width=image.width,
|
||||
height=image.height):
|
||||
logger.warning(
|
||||
"Dynamic image shape is currently not supported. "
|
||||
"Resizing input image to (%d, %d).", w, h)
|
||||
@ -313,11 +347,11 @@ def _image_processor(
|
||||
data.image = image.resize((w, h))
|
||||
|
||||
return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
|
||||
._default_input_processor(data, model_config, vlm_config)
|
||||
._default_input_mapper(ctx, data)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor)
|
||||
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
|
||||
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_image_processor)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
|
||||
class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -1,5 +1,14 @@
from .base import MultiModalData, MultiModalPlugin
from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry
from .registry import MultiModalRegistry

MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global :class:`~MultiModalRegistry` is used by model runners to
dispatch data processing according to its modality and the target model.

See also:
    :ref:`input_processing_pipeline`
"""

__all__ = [
    "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",
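For contrast with the input processor, a sketch (not part of the diff) of the renamed mapper entry point that model runners call; the helper function and its arguments are assumptions.

# Illustrative: turning an image into the kwargs passed to the model's forward().
from PIL import Image

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData


def image_to_model_kwargs(model_config: ModelConfig, image: Image.Image):
    """Map a PIL image to a dict of tensors (e.g. {"pixel_values": ...}) that
    the model runner passes as keyword arguments to the model's forward()."""
    return MULTIMODAL_REGISTRY.map_input(model_config, ImagePixelData(image))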
@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
|
||||
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
|
||||
TypeVar)
|
||||
|
||||
from vllm.config import ModelConfig, VisionLanguageConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -23,7 +24,7 @@ class MultiModalData:
|
||||
|
||||
Finally, register the new plugin to
|
||||
:const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
|
||||
This enables models to call :meth:`MultiModalRegistry.register_input` for
|
||||
This enables models to call :meth:`MultiModalRegistry.map_input` for
|
||||
the new modality.
|
||||
"""
|
||||
pass
|
||||
@ -32,10 +33,9 @@ class MultiModalData:
|
||||
D = TypeVar("D", bound=MultiModalData)
|
||||
N = TypeVar("N", bound=Type["nn.Module"])
|
||||
|
||||
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
|
||||
Dict[str, "torch.Tensor"]]
|
||||
MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]]
|
||||
"""Return a dictionary to be passed as keyword arguments to
|
||||
:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
|
||||
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
|
||||
and processors in HuggingFace Transformers."""
|
||||
|
||||
|
||||
@ -50,16 +50,9 @@ class MultiModalPlugin(ABC, Generic[D]):
|
||||
(i.e., the modality of the data).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
return get_model_architecture(model_config)[0]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._input_processors: Dict[Type["nn.Module"],
|
||||
MultiModalInputProcessor[D]] = {}
|
||||
self._input_mappers: Dict[Type["nn.Module"],
|
||||
MultiModalInputMapper[D]] = {}
|
||||
|
||||
@abstractmethod
|
||||
def get_data_type(self) -> Type[D]:
|
||||
@ -70,57 +63,62 @@ class MultiModalPlugin(ABC, Generic[D]):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _default_input_processor(
|
||||
self, data: D, model_config: ModelConfig,
|
||||
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
|
||||
def _default_input_mapper(self, ctx: InputContext,
|
||||
data: D) -> Dict[str, "torch.Tensor"]:
|
||||
"""Return a dictionary to be passed as keyword arguments to
|
||||
:meth:`torch.nn.Module.forward`. This is similar in concept to
|
||||
:meth:`~torch.nn.Module.forward`. This is similar in concept to
|
||||
tokenizers and processors in HuggingFace Transformers.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def register_input_processor(self,
|
||||
processor: Optional[
|
||||
MultiModalInputProcessor[D]] = None):
|
||||
def register_input_mapper(
|
||||
self,
|
||||
mapper: Optional[MultiModalInputMapper[D]] = None,
|
||||
):
|
||||
"""
|
||||
Register an input processor to a model class.
|
||||
Register an input mapper to a model class.
|
||||
|
||||
When the model receives input data that matches the modality served by
|
||||
this plugin (see :meth:`get_data_type`), the provided input processor is
|
||||
applied to preprocess the data. If `None` is provided, then the default
|
||||
input processor is applied instead.
|
||||
this plugin (see :meth:`get_data_type`), the provided function is
|
||||
invoked to transform the data into a dictionary of model inputs.
|
||||
If `None` is provided, then the default input mapper is used instead.
|
||||
|
||||
See also:
|
||||
:ref:`input_processing_pipeline`
|
||||
"""
|
||||
|
||||
def wrapper(model_cls: N) -> N:
|
||||
if model_cls in self._input_processors:
|
||||
if model_cls in self._input_mappers:
|
||||
logger.warning(
|
||||
"Model class %s already has an input processor "
|
||||
"Model class %s already has an input mapper "
|
||||
"registered to %s. It is overwritten by the new one.",
|
||||
model_cls, self)
|
||||
|
||||
self._input_processors[model_cls] = processor \
|
||||
or self._default_input_processor
|
||||
self._input_mappers[model_cls] = mapper \
|
||||
or self._default_input_mapper
|
||||
|
||||
return model_cls
|
||||
|
||||
return wrapper
|
||||
|
||||
def process_input(
|
||||
self, data: D, model_config: ModelConfig,
|
||||
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
|
||||
def map_input(self, model_config: ModelConfig,
|
||||
data: D) -> Dict[str, "torch.Tensor"]:
|
||||
"""
|
||||
Apply an input processor to a :class:`~MultiModalData` instance passed
|
||||
to the model.
|
||||
|
||||
The model is identified by ``model_config``. ``vlm_config`` is
|
||||
for compatibility purposes and may be merged into ``model_config``
|
||||
in the near future.
|
||||
"""
|
||||
model_cls = self.get_model_cls(model_config)
|
||||
Apply an input mapper to a :class:`~MultiModalData` instance passed
|
||||
to the model, transforming the data into a dictionary of model inputs.
|
||||
|
||||
processor = self._input_processors.get(model_cls)
|
||||
if processor is None:
|
||||
raise KeyError(f"No input processor in {self} is registered for "
|
||||
The model is identified by ``model_config``.
|
||||
|
||||
TODO: Add guide [ref: PR #5276]
|
||||
"""
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
|
||||
mapper = self._input_mappers.get(model_cls)
|
||||
if mapper is None:
|
||||
raise KeyError(f"No input mapper in {self} is registered for "
|
||||
f"model class {model_cls.__name__}.")
|
||||
|
||||
return processor(data, model_config, vlm_config)
|
||||
return mapper(InputContext(model_config), data)
|
||||
|
||||
@ -1,70 +1,28 @@
|
||||
from typing import Dict, Tuple, Type, Union
|
||||
from functools import lru_cache
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from vllm.config import ModelConfig, VisionLanguageConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import SequenceData
|
||||
from vllm.transformers_utils.image_processor import cached_get_image_processor
|
||||
from vllm.transformers_utils.image_processor import get_image_processor
|
||||
|
||||
from .base import MultiModalData, MultiModalPlugin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_dummy_seq_data(seq_len: int,
|
||||
vlm_config: VisionLanguageConfig) -> SequenceData:
|
||||
# NOTE: We assume that <image> token is repeated `image_feature_size` times
|
||||
# and then concatenated with the text prompt
|
||||
# TODO: Enable other ways of inserting the image into the prompt
|
||||
|
||||
token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size
|
||||
token_ids += [0] * (seq_len - vlm_config.image_feature_size)
|
||||
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor:
|
||||
if vlm_config.image_processor is None:
|
||||
values_dtype = torch.float16
|
||||
else:
|
||||
values_dtype = torch.uint8
|
||||
|
||||
return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)
|
||||
|
||||
|
||||
def get_dummy_image_data(
|
||||
seq_len: int,
|
||||
model_config: ModelConfig,
|
||||
vlm_config: VisionLanguageConfig,
|
||||
) -> Tuple[SequenceData, MultiModalData]:
|
||||
"""Standard dummy data factory for image data (to be used in
|
||||
:meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`)."""
|
||||
seq_data = _get_dummy_seq_data(seq_len, vlm_config)
|
||||
values = _get_dummy_values(vlm_config)
|
||||
|
||||
config_input_type = vlm_config.image_input_type
|
||||
ImageInputType = VisionLanguageConfig.ImageInputType
|
||||
|
||||
fake_mm_data: MultiModalData
|
||||
if config_input_type == ImageInputType.PIXEL_VALUES:
|
||||
fake_mm_data = ImagePixelData(values)
|
||||
elif config_input_type == ImageInputType.IMAGE_FEATURES:
|
||||
fake_mm_data = ImageFeatureData(values)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return seq_data, fake_mm_data
|
||||
cached_get_image_processor = lru_cache(get_image_processor)
|
||||
|
||||
|
||||
class ImagePixelData(MultiModalData):
|
||||
"""
|
||||
The pixel data of an image. Can be one of:
|
||||
|
||||
- :class:``PIL.Image``: An image object. Requires that a HuggingFace
|
||||
- :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace
|
||||
processor is available to the model.
|
||||
- :class:``torch.Tensor``: The raw pixel data which is passed to the model
|
||||
- :class:`torch.Tensor`: The raw pixel data which is passed to the model
|
||||
without additional pre-processing.
|
||||
"""
|
||||

@ -89,8 +47,8 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
    def get_data_type(self) -> Type[ImagePixelData]:
        return ImagePixelData

    def _get_hf_image_processor(self, model_config: ModelConfig,
                                vlm_config: VisionLanguageConfig):
    def _get_hf_image_processor(self, model_config: ModelConfig):
        vlm_config = model_config.multimodal_config
        if vlm_config is None or vlm_config.image_processor is None:
            return None

@ -100,14 +58,13 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
            revision=vlm_config.image_processor_revision,
        )

    def _default_input_processor(
            self, data: ImagePixelData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
    def _default_input_mapper(self, ctx: InputContext,
                              data: ImagePixelData) -> Dict[str, torch.Tensor]:
        model_config = ctx.model_config
        image = data.image

        if isinstance(image, Image.Image):
            image_processor = self._get_hf_image_processor(
                model_config, vlm_config)
            image_processor = self._get_hf_image_processor(model_config)
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available"
                                   "to process the image object")
@ -147,9 +104,10 @@ class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
    def get_data_type(self) -> Type[ImageFeatureData]:
        return ImageFeatureData

    def _default_input_processor(
            self, data: ImageFeatureData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
    def _default_input_mapper(
            self, ctx: InputContext,
            data: ImageFeatureData) -> Dict[str, torch.Tensor]:
        model_config = ctx.model_config
        image_features = data.image_features.to(model_config.dtype)

        return {"image_features": image_features}
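A minimal sketch of a custom mapper written against the new (ctx, data) signature, assuming InputContext can be imported from vllm.inputs.registry as referenced above; the function name is illustrative:

from typing import Dict

import torch

from vllm.inputs.registry import InputContext
from vllm.multimodal.image import ImageFeatureData


def custom_feature_mapper(ctx: InputContext,
                          data: ImageFeatureData) -> Dict[str, torch.Tensor]:
    # Mirror the default mapper: cast the features to the model dtype.
    return {"image_features": data.image_features.to(ctx.model_config.dtype)}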

@ -1,46 +1,35 @@
import functools
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence,
                    Tuple, Type, TypeVar)
from typing import Any, Optional, Sequence, Type, TypeVar

from vllm.config import ModelConfig, VisionLanguageConfig
from torch import nn

from vllm.config import ModelConfig
from vllm.logger import init_logger

from .base import MultiModalData, MultiModalPlugin
from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
                    ImagePixelPlugin)

if TYPE_CHECKING:
    import torch
    from torch import nn

    from vllm.sequence import SequenceData

logger = init_logger(__name__)

D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])

MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
                                    Dict[str, "torch.Tensor"]]
MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig],
                                  Tuple["SequenceData", MultiModalData]]
N = TypeVar("N", bound=Type[nn.Module])


class MultiModalRegistry:
    """
    This registry is used by model runners to dispatch data processing
    A registry to dispatch data processing
    according to its modality and the target model.
    """

    DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())

    def __init__(self,
                 *,
                 plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS
                 ) -> None:
    def __init__(
            self,
            *,
            plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS,
    ) -> None:
        self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
        self._dummy_factories_by_model_type: Dict[Type["nn.Module"],
                                                  MultiModalDummyFactory] = {}

    def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
        data_type = plugin.get_data_type()
@ -62,95 +51,53 @@ class MultiModalRegistry:
        msg = f"Unknown multi-modal data type: {data_type}"
        raise NotImplementedError(msg)

    def register_dummy_data(self, factory: MultiModalDummyFactory):
    def register_input_mapper(
            self,
            data_type: Type[D],
            mapper: Optional[MultiModalInputMapper[D]] = None,
    ):
        """
        Register a dummy data factory to a model class.
        Register an input mapper for a specific modality to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The modality and shape of
        the dummy data should be an upper bound of what the model would receive
        at inference time.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._dummy_factories_by_model_type:
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig,
                                 vlm_config: VisionLanguageConfig):
        """Create dummy data for memory profiling."""
        model_cls = MultiModalPlugin.get_model_cls(model_config)
        dummy_factory = self._dummy_factories_by_model_type.get(model_cls)
        if dummy_factory is None:
            msg = f"No dummy data defined for model class: {model_cls}"
            raise NotImplementedError(msg)

        return dummy_factory(seq_len, model_config, vlm_config)

    def register_input(
            self,
            data_type: Type[D],
            processor: Optional[MultiModalInputProcessor[D]] = None):
        """
        Register an input processor for a specific modality to a model class.

        See :meth:`MultiModalPlugin.register_input_processor` for more details.
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self._get_plugin_for_data_type(data_type) \
            .register_input_processor(processor)
            .register_input_mapper(mapper)

    def register_image_pixel_input(
            self,
            processor: Optional[
                MultiModalInputProcessor[ImagePixelData]] = None):
        """
        Register an input processor for image pixel data to a model class.

        See :meth:`MultiModalPlugin.register_input_processor` for more details.
        """
        return self.register_input(ImagePixelData, processor)

    def register_image_feature_input(
    def register_image_pixel_input_mapper(
            self,
            processor: Optional[
                MultiModalInputProcessor[ImageFeatureData]] = None):
            mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None,
    ):
        """
        Register an input processor for image feature data to a model class.
        Register an input mapper for image pixel data to a model class.

        See :meth:`MultiModalPlugin.register_input_processor` for more details.
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input(ImageFeatureData, processor)
        return self.register_input_mapper(ImagePixelData, mapper)
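A hypothetical registration sketch for the renamed decorator; the model class and mapper below are illustrative stand-ins, not code from this commit:

from typing import Dict

import torch
from torch import nn

from vllm.inputs.registry import InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData


def passthrough_pixel_mapper(ctx: InputContext,
                             data: ImagePixelData) -> Dict[str, torch.Tensor]:
    # Assumes data.image is already a pre-processed tensor for this sketch.
    return {"pixel_values": data.image}


@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(passthrough_pixel_mapper)
class MyVisionLanguageModel(nn.Module):
    ...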

    def process_input(self, data: MultiModalData, model_config: ModelConfig,
                      vlm_config: VisionLanguageConfig):
    def register_image_feature_input_mapper(
            self,
            mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None,
    ):
        """
        Apply an input processor to a :class:`~MultiModalData` instance passed
        Register an input mapper for image feature data to a model class.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper(ImageFeatureData, mapper)

    def map_input(self, model_config: ModelConfig, data: MultiModalData):
        """
        Apply an input mapper to a :class:`~MultiModalData` instance passed
        to the model.

        See :meth:`MultiModalPlugin.process_input` for more details.
        See :meth:`MultiModalPlugin.map_input` for more details.
        """
        return self._get_plugin_for_data_type(type(data)) \
            .process_input(data, model_config, vlm_config)
            .map_input(model_config, data)

    def create_input_processor(self, model_config: ModelConfig,
                               vlm_config: VisionLanguageConfig):
    def create_input_mapper(self, model_config: ModelConfig):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.
        Create an input mapper (see :meth:`map_input`) for a specific model.
        """
        return functools.partial(self.process_input,
                                 model_config=model_config,
                                 vlm_config=vlm_config)


MULTIMODAL_REGISTRY = MultiModalRegistry()
"""The global :class:`~MultiModalRegistry` which is used by model runners."""
        return functools.partial(self.map_input, model_config)
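The rewritten create_input_mapper binds model_config positionally rather than passing model_config and vlm_config as keyword arguments, so the returned callable takes only the data. A self-contained illustration of that binding, using a stand-in class rather than the real registry:

import functools


class _ToyRegistry:

    def map_input(self, model_config, data):
        return {"model_config": model_config, "data": data}

    def create_input_mapper(self, model_config):
        # Same pattern as above: pre-bind model_config positionally.
        return functools.partial(self.map_input, model_config)


registry = _ToyRegistry()
mapper = registry.create_input_mapper("some-model-config")
assert mapper("some-data") == registry.map_input("some-model-config", "some-data")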

@ -8,12 +8,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

if TYPE_CHECKING:
    from vllm.inputs import LLMInputs
    from vllm.multimodal import MultiModalData
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

@ -221,7 +221,7 @@ class Sequence:
    def __init__(
        self,
        seq_id: int,
        inputs: LLMInputs,
        inputs: "LLMInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
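Since the LLMInputs import now lives under TYPE_CHECKING, the annotation has to become the string "LLMInputs": the name no longer exists at runtime and is resolved only by type checkers. A small self-contained sketch of the same pattern, with Decimal standing in for LLMInputs:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from decimal import Decimal  # imported for type checking only


def format_price(value: "Decimal") -> str:
    # The quoted annotation keeps this module importable at runtime.
    return f"{value} USD"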

@ -1,4 +1,3 @@
from functools import lru_cache
from typing import Optional

from transformers import AutoImageProcessor
@ -40,6 +39,3 @@ def get_image_processor(
        raise e

    return processor


cached_get_image_processor = lru_cache(get_image_processor)

@ -110,15 +110,9 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
            self.block_size,
        )

        # Create processor for multi-modal data
        if self.vision_language_config is not None:
            self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
                .create_input_processor(
                    self.model_config,
                    self.vision_language_config,
                )
        else:
            self.multi_modal_input_processor = None
        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after init_model
@ -168,13 +162,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data is not None:
                # Process multi-modal data
                if self.multi_modal_input_processor is None:
                    raise ValueError(
                        "Multi-modal inputs are only supported by "
                        "vision language models.")

                mm_kwargs = self.multi_modal_input_processor(mm_data)
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                for k, v in mm_kwargs.items():
                    multi_modal_kwargs_list[k].append(v)
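For context, a conceptual sketch of how the per-sequence mapper outputs collected above could be batched; the stacking step is an assumption about surrounding code that this hunk does not show:

from collections import defaultdict

import torch

multi_modal_kwargs_list = defaultdict(list)
for mm_kwargs in ({"pixel_values": torch.zeros(3, 336, 336)},
                  {"pixel_values": torch.zeros(3, 336, 336)}):
    for k, v in mm_kwargs.items():
        multi_modal_kwargs_list[k].append(v)

# Per-key tensors stacked into a batch dimension.
batched = {k: torch.stack(vs) for k, vs in multi_modal_kwargs_list.items()}
assert batched["pixel_values"].shape == (2, 3, 336, 336)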


@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@ -25,7 +26,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import (
@ -191,15 +192,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            self.block_size,
        ) if num_attn_heads else None

        # Create processor for multi-modal data
        if self.vision_language_config is not None:
            self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
                .create_input_processor(
                    self.model_config,
                    self.vision_language_config,
                )
        else:
            self.multi_modal_input_processor = None
        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
@ -506,12 +501,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            mm_data = seq_group_metadata.multi_modal_data
            if mm_data is not None:
                # Process multi-modal data
                if self.multi_modal_input_processor is None:
                    raise ValueError(
                        "Multi-modal inputs are only supported by "
                        "vision language models.")

                mm_kwargs = self.multi_modal_input_processor(mm_data)
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                for k, v in mm_kwargs.items():
                    multi_modal_kwargs_list[k].append(v)

@ -764,12 +754,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))

            if vlm_config is None:
                seq_data = SequenceData([0] * seq_len)
                dummy_multi_modal_data = None
            else:
                seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
                    .dummy_data_for_profiling(seq_len, model_config, vlm_config)
            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)
            assert len(seq_data.prompt_token_ids) == seq_len
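The assertion above guards the profiling path: however the dummy data is produced, the dummy prompt must fill the whole sequence. A minimal check of that invariant using SequenceData directly, with an illustrative length:

from vllm.sequence import SequenceData

seq_len = 16
seq_data = SequenceData([0] * seq_len)
assert len(seq_data.prompt_token_ids) == seq_len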

            seq = SequenceGroupMetadata(
                request_id=str(group_id),