vllm/vllm/multimodal/registry.py
Cyrus Leung 5cbe8d155c
[Core] Registry for processing model inputs (#5214)
Co-authored-by: ywang96 <ywang@roblox.com>
2024-06-28 12:09:56 +00:00

104 lines
3.3 KiB
Python

import functools
from typing import Any, Dict, Optional, Sequence, Type, TypeVar

from torch import nn

from vllm.config import ModelConfig
from vllm.logger import init_logger

from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
                    ImagePixelPlugin)
# Type variable bound to a concrete multi-modal data class; lets a data type
# be tied to the input mapper registered for it.
D = TypeVar("D", bound=MultiModalData)

# Type variable bound to a model class (a subclass of ``nn.Module``).
N = TypeVar("N", bound=Type[nn.Module])

# Module-level logger, named after this module.
logger = init_logger(__name__)
class MultiModalRegistry:
    """
    A registry to dispatch data processing
    according to its modality and the target model.

    Each registered :class:`MultiModalPlugin` handles the
    :class:`MultiModalData` subclass reported by its ``get_data_type()``.
    Lookups walk the MRO of the incoming data's class, so data whose class
    derives from a registered type is handled by that type's plugin unless a
    plugin for a more specific type is registered.
    """

    # NOTE(review): these plugin instances are created once, when the class
    # body runs, and are shared by every registry constructed with the
    # default ``plugins`` argument. If plugins keep per-model mapper state
    # (presumably they do -- TODO confirm in ``.base``), such registries
    # share that state.
    DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())

    def __init__(
        self,
        *,
        plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS,
    ) -> None:
        """
        Create a registry pre-populated with ``plugins``.

        Plugins are registered in order via :meth:`register_plugin`. If two
        of them report the same data type, the later one wins and a warning
        is logged.
        """
        self._plugins_by_data_type: Dict[Type[MultiModalData],
                                         MultiModalPlugin[Any]] = {}
        for plugin in plugins:
            self.register_plugin(plugin)

    def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
        """
        Register ``plugin`` for the data type it reports via
        ``get_data_type()``.

        A plugin already registered for that data type is replaced, and a
        warning is logged.
        """
        data_type = plugin.get_data_type()

        if data_type in self._plugins_by_data_type:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type,
                plugin)

        self._plugins_by_data_type[data_type] = plugin

    def _get_plugin_for_data_type(
        self,
        data_type: Type[MultiModalData],
    ) -> MultiModalPlugin[Any]:
        """
        Return the plugin registered for ``data_type`` or, failing that, for
        its nearest base class.

        The MRO is walked from most to least specific. A plugin registered for
        the exact type therefore takes precedence over one registered for a
        base class.

        Raises:
            NotImplementedError: If no class in the MRO of ``data_type`` has
                a registered plugin.
        """
        for typ in data_type.mro():
            plugin = self._plugins_by_data_type.get(typ)
            if plugin is not None:
                return plugin

        msg = f"Unknown multi-modal data type: {data_type}"
        raise NotImplementedError(msg)

    def register_input_mapper(
        self,
        data_type: Type[D],
        mapper: Optional[MultiModalInputMapper[D]] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        ``mapper`` is forwarded to the plugin responsible for ``data_type``.
        The plugin's return value is returned unchanged.
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.

        Raises:
            NotImplementedError: If no plugin handles ``data_type``.
        """
        plugin = self._get_plugin_for_data_type(data_type)
        return plugin.register_input_mapper(mapper)

    def register_image_pixel_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None,
    ):
        """
        Register an input mapper for image pixel data to a model class.

        Shorthand for :meth:`register_input_mapper` with ``ImagePixelData``.
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper(ImagePixelData, mapper)

    def register_image_feature_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None,
    ):
        """
        Register an input mapper for image feature data to a model class.

        Shorthand for :meth:`register_input_mapper` with
        ``ImageFeatureData``.
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper(ImageFeatureData, mapper)

    def map_input(self, model_config: ModelConfig, data: MultiModalData):
        """
        Apply an input mapper to a :class:`~MultiModalData` instance passed
        to the model.

        The plugin is chosen from ``type(data)``.
        See :meth:`MultiModalPlugin.map_input` for more details.

        Raises:
            NotImplementedError: If no plugin handles ``type(data)``.
        """
        plugin = self._get_plugin_for_data_type(type(data))
        return plugin.map_input(model_config, data)

    def create_input_mapper(self, model_config: ModelConfig):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.

        The result is a callable that takes only ``data`` and calls
        :meth:`map_input` with ``model_config`` already bound.
        """
        return functools.partial(self.map_input, model_config)