mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 04:05:01 +08:00
[Multimodal] Make MediaConnector extensible. (#27759)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
This commit is contained in:
parent
611c86ea3c
commit
1fb4217a05
@ -43,11 +43,12 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, Processor
|
||||
# pydantic needs the TypedDict from typing_extensions
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsMultiModal
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.multimodal.utils import MediaConnector
|
||||
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
@ -806,7 +807,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._tracker = tracker
|
||||
multimodal_config = self._tracker.model_config.multimodal_config
|
||||
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
|
||||
self._connector = MediaConnector(
|
||||
|
||||
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
|
||||
envs.VLLM_MEDIA_CONNECTOR,
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
@ -891,7 +894,8 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._tracker = tracker
|
||||
multimodal_config = self._tracker.model_config.multimodal_config
|
||||
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
|
||||
self._connector = MediaConnector(
|
||||
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
|
||||
envs.VLLM_MEDIA_CONNECTOR,
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
|
||||
@ -70,6 +70,7 @@ if TYPE_CHECKING:
|
||||
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
|
||||
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
||||
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
||||
VLLM_MEDIA_CONNECTOR: str = "http"
|
||||
VLLM_MM_INPUT_CACHE_GIB: int = 4
|
||||
VLLM_TARGET_DEVICE: str = "cuda"
|
||||
VLLM_MAIN_CUDA_VERSION: str = "12.8"
|
||||
@ -738,6 +739,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv(
|
||||
"VLLM_VIDEO_LOADER_BACKEND", "opencv"
|
||||
),
|
||||
# Media connector implementation.
|
||||
# - "http": Default connector that supports fetching media via HTTP.
|
||||
#
|
||||
# Custom implementations can be registered
|
||||
# via `@MEDIA_CONNECTOR_REGISTRY.register("my_custom_media_connector")` and
|
||||
# imported at runtime.
|
||||
# If a non-existing backend is used, an AssertionError will be thrown.
|
||||
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
|
||||
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
|
||||
# Default is 4 GiB per API process + 4 GiB per engine core process
|
||||
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
|
||||
|
||||
@ -20,6 +20,7 @@ import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
from .audio import AudioMediaIO
|
||||
from .base import MediaIO
|
||||
@ -46,7 +47,10 @@ atexit.register(global_thread_pool.shutdown)
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
|
||||
|
||||
|
||||
@MEDIA_CONNECTOR_REGISTRY.register("http")
|
||||
class MediaConnector:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -14,6 +14,7 @@ from PIL import Image
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
from .base import MediaIO
|
||||
from .image import ImageMediaIO
|
||||
@ -63,25 +64,7 @@ class VideoLoader:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class VideoLoaderRegistry:
|
||||
def __init__(self) -> None:
|
||||
self.name2class: dict[str, type] = {}
|
||||
|
||||
def register(self, name: str):
|
||||
def wrap(cls_to_register):
|
||||
self.name2class[name] = cls_to_register
|
||||
return cls_to_register
|
||||
|
||||
return wrap
|
||||
|
||||
@staticmethod
|
||||
def load(cls_name: str) -> VideoLoader:
|
||||
cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
|
||||
assert cls is not None, f"VideoLoader class {cls_name} not found"
|
||||
return cls()
|
||||
|
||||
|
||||
VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
|
||||
VIDEO_LOADER_REGISTRY = ExtensionManager()
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv")
|
||||
|
||||
49
vllm/utils/registry.py
Normal file
49
vllm/utils/registry.py
Normal file
@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ExtensionManager:
|
||||
"""
|
||||
A registry for managing pluggable extension classes.
|
||||
|
||||
This class provides a simple mechanism to register and instantiate
|
||||
extension classes by name. It is commonly used to implement plugin
|
||||
systems where different implementations can be swapped at runtime.
|
||||
|
||||
Examples:
|
||||
Basic usage with a registry instance:
|
||||
|
||||
>>> FOO_REGISTRY = ExtensionManager()
|
||||
>>> @FOO_REGISTRY.register("my_foo_impl")
|
||||
... class MyFooImpl(Foo):
|
||||
... def __init__(self, value):
|
||||
... self.value = value
|
||||
>>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize an empty extension registry.
|
||||
"""
|
||||
self.name2class: dict[str, type] = {}
|
||||
|
||||
def register(self, name: str):
|
||||
"""
|
||||
Decorator to register a class with the given name.
|
||||
"""
|
||||
|
||||
def wrap(cls_to_register):
|
||||
self.name2class[name] = cls_to_register
|
||||
return cls_to_register
|
||||
|
||||
return wrap
|
||||
|
||||
def load(self, cls_name: str, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Instantiate and return a registered extension class by name.
|
||||
"""
|
||||
cls = self.name2class.get(cls_name)
|
||||
assert cls is not None, f"Extension class {cls_name} not found"
|
||||
return cls(*args, **kwargs)
|
||||
Loading…
x
Reference in New Issue
Block a user