[Multimodal] Make MediaConnector extensible. (#27759)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
This commit is contained in:
Chenheli Hua 2025-11-04 10:28:01 -08:00 committed by GitHub
parent 611c86ea3c
commit 1fb4217a05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 71 additions and 22 deletions

View File

@ -43,11 +43,12 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, Processor
# pydantic needs the TypedDict from typing_extensions # pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict
from vllm import envs
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MediaConnector from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@ -806,7 +807,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector = MediaConnector(
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
@ -891,7 +894,8 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector = MediaConnector( self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,

View File

@ -70,6 +70,7 @@ if TYPE_CHECKING:
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MEDIA_CONNECTOR: str = "http"
VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.8" VLLM_MAIN_CUDA_VERSION: str = "12.8"
@ -738,6 +739,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv( "VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv(
"VLLM_VIDEO_LOADER_BACKEND", "opencv" "VLLM_VIDEO_LOADER_BACKEND", "opencv"
), ),
# Media connector implementation.
# - "http": Default connector that supports fetching media via HTTP.
#
# Custom implementations can be registered
# via `@MEDIA_CONNECTOR_REGISTRY.register("my_custom_media_connector")` and
# imported at runtime.
# If an unregistered connector name is given, an AssertionError is raised.
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
# Default is 4 GiB per API process + 4 GiB per engine core process # Default is 4 GiB per API process + 4 GiB per engine core process
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),

View File

@ -20,6 +20,7 @@ import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.jsontree import json_map_leaves from vllm.utils.jsontree import json_map_leaves
from vllm.utils.registry import ExtensionManager
from .audio import AudioMediaIO from .audio import AudioMediaIO
from .base import MediaIO from .base import MediaIO
@ -46,7 +47,10 @@ atexit.register(global_thread_pool.shutdown)
_M = TypeVar("_M") _M = TypeVar("_M")
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector: class MediaConnector:
def __init__( def __init__(
self, self,

View File

@ -14,6 +14,7 @@ from PIL import Image
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.registry import ExtensionManager
from .base import MediaIO from .base import MediaIO
from .image import ImageMediaIO from .image import ImageMediaIO
@ -63,25 +64,7 @@ class VideoLoader:
raise NotImplementedError raise NotImplementedError
class VideoLoaderRegistry: VIDEO_LOADER_REGISTRY = ExtensionManager()
def __init__(self) -> None:
self.name2class: dict[str, type] = {}
def register(self, name: str):
def wrap(cls_to_register):
self.name2class[name] = cls_to_register
return cls_to_register
return wrap
@staticmethod
def load(cls_name: str) -> VideoLoader:
cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
assert cls is not None, f"VideoLoader class {cls_name} not found"
return cls()
VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
@VIDEO_LOADER_REGISTRY.register("opencv") @VIDEO_LOADER_REGISTRY.register("opencv")

49
vllm/utils/registry.py Normal file
View File

@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
class ExtensionManager:
    """
    A registry for managing pluggable extension classes.

    This class provides a simple mechanism to register and instantiate
    extension classes by name. It is commonly used to implement plugin
    systems where different implementations can be swapped at runtime.

    Examples:
        Basic usage with a registry instance:

        >>> FOO_REGISTRY = ExtensionManager()
        >>> @FOO_REGISTRY.register("my_foo_impl")
        ... class MyFooImpl(Foo):
        ...     def __init__(self, value):
        ...         self.value = value
        >>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123)
    """

    def __init__(self) -> None:
        """
        Initialize an empty extension registry.
        """
        # Maps registered name -> extension class.
        self.name2class: dict[str, type] = {}

    def register(self, name: str):
        """
        Decorator to register a class under ``name``.

        The decorated class is returned unchanged. Registering the same
        name twice silently overwrites the earlier entry (last wins).
        """

        def wrap(cls_to_register):
            self.name2class[name] = cls_to_register
            return cls_to_register

        return wrap

    def load(self, cls_name: str, *args, **kwargs) -> Any:
        """
        Instantiate and return a registered extension class by name.

        Any extra positional/keyword arguments are forwarded to the
        class constructor.

        Raises:
            AssertionError: If ``cls_name`` was never registered.
        """
        cls = self.name2class.get(cls_name)
        # Raise explicitly rather than via `assert`: asserts are stripped
        # under `python -O`, which would otherwise turn a missing extension
        # into a confusing `NoneType is not callable` TypeError. Keep
        # AssertionError as the exception type since callers (and the
        # VLLM_MEDIA_CONNECTOR env-var docs) rely on it.
        if cls is None:
            raise AssertionError(f"Extension class {cls_name} not found")
        return cls(*args, **kwargs)