mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-28 23:17:14 +08:00
104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
import functools
|
|
from typing import Any, Optional, Sequence, Type, TypeVar
|
|
|
|
from torch import nn
|
|
|
|
from vllm.config import ModelConfig
|
|
from vllm.logger import init_logger
|
|
|
|
from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
|
|
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
|
|
ImagePixelPlugin)
|
|
|
|
# Logger for this module (used to warn when a plugin registration is replaced).
logger = init_logger(__name__)

# Type variables shared by the registry's generic signatures.
#   D -- a concrete MultiModalData subclass; ties an input mapper's generic
#        parameter to the data type it is registered for.
#   N -- a model class (a subclass of ``torch.nn.Module``). Not referenced
#        elsewhere in this module; presumably kept for typing decorators that
#        return the decorated model class -- TODO(review): confirm or remove.
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type[nn.Module])
|
|
|
|
|
|
class MultiModalRegistry:
    """
    A registry to dispatch data processing
    according to its modality and the target model.

    Each registered :class:`MultiModalPlugin` handles one
    :class:`MultiModalData` subclass (as reported by its ``get_data_type``).
    Lookups walk the MRO of the data's type, so a plugin registered for a base
    class also serves that class's subclasses.
    """

    # The plugin types a registry uses when ``plugins`` is not given.
    # Kept for backward compatibility. ``__init__`` builds *fresh* instances of
    # these types for every registry rather than reusing the objects below,
    # because registrations are stored on the plugin objects themselves.
    DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())

    def __init__(
        self,
        *,
        plugins: Optional[Sequence[MultiModalPlugin[Any]]] = None,
    ) -> None:
        """
        Create a registry dispatching to ``plugins``.

        Args:
            plugins: Plugins keyed by their ``get_data_type()``; if several
                report the same data type, the last one wins. When omitted
                (or ``None``), new instances of the :attr:`DEFAULT_PLUGINS`
                types are created for this registry.
        """
        if plugins is None:
            # Input mappers are registered on the plugin objects (see
            # ``register_input_mapper``), so sharing the class-level
            # ``DEFAULT_PLUGINS`` instances across registries would leak
            # registrations between them. Those plugin types are constructed
            # without arguments, so instantiate them afresh here.
            plugins = tuple(
                type(default_plugin)()
                for default_plugin in MultiModalRegistry.DEFAULT_PLUGINS)

        self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}

    def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
        """
        Register ``plugin`` for the data type it reports via
        ``get_data_type()``.

        Any plugin already registered for that exact data type is replaced,
        and a warning is logged.
        """
        data_type = plugin.get_data_type()

        if data_type in self._plugins_by_data_type:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type,
                plugin)

        self._plugins_by_data_type[data_type] = plugin

    def _get_plugin_for_data_type(
        self,
        data_type: Type[MultiModalData],
    ) -> MultiModalPlugin[Any]:
        """
        Return the plugin registered for ``data_type``, or for the nearest
        base class in its MRO.

        Raises:
            NotImplementedError: If neither ``data_type`` nor any of its base
                classes has a registered plugin.
        """
        for typ in data_type.mro():
            plugin = self._plugins_by_data_type.get(typ)
            if plugin is not None:
                return plugin

        msg = f"Unknown multi-modal data type: {data_type}"
        raise NotImplementedError(msg)

    def register_input_mapper(
        self,
        data_type: Type[D],
        mapper: Optional[MultiModalInputMapper[D]] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        Args:
            data_type: The :class:`MultiModalData` subclass the mapper handles;
                selects the plugin to delegate to.
            mapper: The mapper, forwarded unchanged to the plugin.

        Returns:
            Whatever the selected plugin's ``register_input_mapper`` returns
            (presumably a decorator for the model class -- confirm in the
            plugin base class).

        Raises:
            NotImplementedError: If no plugin handles ``data_type``.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        plugin = self._get_plugin_for_data_type(data_type)
        return plugin.register_input_mapper(mapper)

    def register_image_pixel_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None,
    ):
        """
        Register an input mapper for image pixel data to a model class.

        Shorthand for ``register_input_mapper(ImagePixelData, mapper)``.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper(ImagePixelData, mapper)

    def register_image_feature_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None,
    ):
        """
        Register an input mapper for image feature data to a model class.

        Shorthand for ``register_input_mapper(ImageFeatureData, mapper)``.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper(ImageFeatureData, mapper)

    def map_input(self, model_config: ModelConfig, data: MultiModalData):
        """
        Apply an input mapper to a :class:`~MultiModalData` instance passed
        to the model.

        The plugin is chosen by the runtime type of ``data``.

        Raises:
            NotImplementedError: If no plugin handles ``type(data)``.

        See :meth:`MultiModalPlugin.map_input` for more details.
        """
        plugin = self._get_plugin_for_data_type(type(data))
        return plugin.map_input(model_config, data)

    def create_input_mapper(self, model_config: ModelConfig):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.

        Returns:
            A callable taking only ``data``, with ``model_config`` bound as
            the first argument of :meth:`map_input`.
        """
        return functools.partial(self.map_input, model_config)
|