mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 06:05:01 +08:00
[Multimodal] Make MediaConnector extensible. (#27759)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
This commit is contained in:
parent
611c86ea3c
commit
1fb4217a05
@ -43,11 +43,12 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, Processor
|
|||||||
# pydantic needs the TypedDict from typing_extensions
|
# pydantic needs the TypedDict from typing_extensions
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models import SupportsMultiModal
|
from vllm.model_executor.models import SupportsMultiModal
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
||||||
from vllm.multimodal.utils import MediaConnector
|
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
|
||||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||||
from vllm.transformers_utils.processor import cached_get_processor
|
from vllm.transformers_utils.processor import cached_get_processor
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
@ -806,7 +807,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
self._tracker = tracker
|
self._tracker = tracker
|
||||||
multimodal_config = self._tracker.model_config.multimodal_config
|
multimodal_config = self._tracker.model_config.multimodal_config
|
||||||
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
|
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
|
||||||
self._connector = MediaConnector(
|
|
||||||
|
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
|
||||||
|
envs.VLLM_MEDIA_CONNECTOR,
|
||||||
media_io_kwargs=media_io_kwargs,
|
media_io_kwargs=media_io_kwargs,
|
||||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||||
allowed_media_domains=tracker.allowed_media_domains,
|
allowed_media_domains=tracker.allowed_media_domains,
|
||||||
@ -891,7 +894,8 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
self._tracker = tracker
|
self._tracker = tracker
|
||||||
multimodal_config = self._tracker.model_config.multimodal_config
|
multimodal_config = self._tracker.model_config.multimodal_config
|
||||||
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
|
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
|
||||||
self._connector = MediaConnector(
|
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
|
||||||
|
envs.VLLM_MEDIA_CONNECTOR,
|
||||||
media_io_kwargs=media_io_kwargs,
|
media_io_kwargs=media_io_kwargs,
|
||||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||||
allowed_media_domains=tracker.allowed_media_domains,
|
allowed_media_domains=tracker.allowed_media_domains,
|
||||||
|
|||||||
@ -70,6 +70,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
|
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
|
||||||
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
||||||
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
||||||
|
VLLM_MEDIA_CONNECTOR: str = "http"
|
||||||
VLLM_MM_INPUT_CACHE_GIB: int = 4
|
VLLM_MM_INPUT_CACHE_GIB: int = 4
|
||||||
VLLM_TARGET_DEVICE: str = "cuda"
|
VLLM_TARGET_DEVICE: str = "cuda"
|
||||||
VLLM_MAIN_CUDA_VERSION: str = "12.8"
|
VLLM_MAIN_CUDA_VERSION: str = "12.8"
|
||||||
@ -738,6 +739,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv(
|
"VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv(
|
||||||
"VLLM_VIDEO_LOADER_BACKEND", "opencv"
|
"VLLM_VIDEO_LOADER_BACKEND", "opencv"
|
||||||
),
|
),
|
||||||
|
# Media connector implementation.
|
||||||
|
# - "http": Default connector that supports fetching media via HTTP.
|
||||||
|
#
|
||||||
|
# Custom implementations can be registered
|
||||||
|
# via `@MEDIA_CONNECTOR_REGISTRY.register("my_custom_media_connector")` and
|
||||||
|
# imported at runtime.
|
||||||
|
# If a non-existent connector name is used, an AssertionError will be raised.
|
||||||
|
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
|
||||||
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
|
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
|
||||||
# Default is 4 GiB per API process + 4 GiB per engine core process
|
# Default is 4 GiB per API process + 4 GiB per engine core process
|
||||||
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
|
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import vllm.envs as envs
|
|||||||
from vllm.connections import HTTPConnection, global_http_connection
|
from vllm.connections import HTTPConnection, global_http_connection
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.jsontree import json_map_leaves
|
from vllm.utils.jsontree import json_map_leaves
|
||||||
|
from vllm.utils.registry import ExtensionManager
|
||||||
|
|
||||||
from .audio import AudioMediaIO
|
from .audio import AudioMediaIO
|
||||||
from .base import MediaIO
|
from .base import MediaIO
|
||||||
@ -46,7 +47,10 @@ atexit.register(global_thread_pool.shutdown)
|
|||||||
|
|
||||||
_M = TypeVar("_M")
|
_M = TypeVar("_M")
|
||||||
|
|
||||||
|
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
|
||||||
|
|
||||||
|
|
||||||
|
@MEDIA_CONNECTOR_REGISTRY.register("http")
|
||||||
class MediaConnector:
|
class MediaConnector:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.registry import ExtensionManager
|
||||||
|
|
||||||
from .base import MediaIO
|
from .base import MediaIO
|
||||||
from .image import ImageMediaIO
|
from .image import ImageMediaIO
|
||||||
@ -63,25 +64,7 @@ class VideoLoader:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class VideoLoaderRegistry:
|
VIDEO_LOADER_REGISTRY = ExtensionManager()
|
||||||
def __init__(self) -> None:
|
|
||||||
self.name2class: dict[str, type] = {}
|
|
||||||
|
|
||||||
def register(self, name: str):
|
|
||||||
def wrap(cls_to_register):
|
|
||||||
self.name2class[name] = cls_to_register
|
|
||||||
return cls_to_register
|
|
||||||
|
|
||||||
return wrap
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load(cls_name: str) -> VideoLoader:
|
|
||||||
cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
|
|
||||||
assert cls is not None, f"VideoLoader class {cls_name} not found"
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
|
|
||||||
VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
@VIDEO_LOADER_REGISTRY.register("opencv")
|
@VIDEO_LOADER_REGISTRY.register("opencv")
|
||||||
|
|||||||
49
vllm/utils/registry.py
Normal file
49
vllm/utils/registry.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionManager:
    """
    A registry for managing pluggable extension classes.

    This class provides a simple mechanism to register and instantiate
    extension classes by name. It is commonly used to implement plugin
    systems where different implementations can be swapped at runtime.

    Registered classes are stored in the public ``name2class`` mapping.
    Registering a second class under an existing name replaces the
    earlier entry.

    Examples:
        Basic usage with a registry instance:

        >>> FOO_REGISTRY = ExtensionManager()
        >>> @FOO_REGISTRY.register("my_foo_impl")
        ... class MyFooImpl:
        ...     def __init__(self, value):
        ...         self.value = value
        >>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123)
        >>> foo_impl.value
        123
    """

    def __init__(self) -> None:
        """
        Initialize an empty extension registry.
        """
        # Maps the registered name to the class object to instantiate.
        self.name2class: dict[str, type] = {}

    def register(self, name: str):
        """
        Decorator to register a class with the given name.

        Args:
            name: Key under which the decorated class is stored.

        Returns:
            A decorator that records the class in ``name2class`` and
            returns the class unchanged, so the decorated name still
            refers to the original class.
        """

        def wrap(cls_to_register):
            self.name2class[name] = cls_to_register
            return cls_to_register

        return wrap

    def load(self, cls_name: str, *args, **kwargs) -> Any:
        """
        Instantiate and return a registered extension class by name.

        Args:
            cls_name: Name the class was registered under.
            *args: Positional arguments forwarded to the class constructor.
            **kwargs: Keyword arguments forwarded to the class constructor.

        Returns:
            A new instance of the registered class.

        Raises:
            AssertionError: If no class is registered under ``cls_name``.
        """
        cls = self.name2class.get(cls_name)
        # Raise explicitly instead of using an `assert` statement: asserts
        # are stripped under `python -O`, which would turn an unknown name
        # into an opaque "'NoneType' object is not callable" TypeError.
        # AssertionError is kept as the exception type for compatibility.
        if cls is None:
            raise AssertionError(f"Extension class {cls_name} not found")
        return cls(*args, **kwargs)
|
||||||
Loading…
x
Reference in New Issue
Block a user