diff --git a/tests/multimodal/test_video.py b/tests/multimodal/test_video.py
new file mode 100644
index 0000000000000..e67624ecefcb6
--- /dev/null
+++ b/tests/multimodal/test_video.py
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: Apache-2.0
+import numpy as np
+import numpy.typing as npt
+import pytest
+
+from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader
+
+NUM_FRAMES = 10
+# Small, deterministic and mutually distinct arrays: the registry test only
+# needs to tell the two loaders apart. A (10, 1280, 720, 3) float64 array is
+# ~221 MB, which is far too heavy to build twice at import time.
+FAKE_OUTPUT_1 = np.zeros((NUM_FRAMES, 8, 8, 3))
+FAKE_OUTPUT_2 = np.ones((NUM_FRAMES, 8, 8, 3))
+
+
+@VIDEO_LOADER_REGISTRY.register("test_video_loader_1")
+class TestVideoLoader1(VideoLoader):
+
+    @classmethod
+    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
+        return FAKE_OUTPUT_1
+
+
+@VIDEO_LOADER_REGISTRY.register("test_video_loader_2")
+class TestVideoLoader2(VideoLoader):
+
+    @classmethod
+    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
+        return FAKE_OUTPUT_2
+
+
+def test_video_loader_registry():
+    """Each registered name resolves to its own loader class."""
+    custom_loader_1 = VIDEO_LOADER_REGISTRY.load("test_video_loader_1")
+    output_1 = custom_loader_1.load_bytes(b"test")
+    np.testing.assert_array_equal(output_1, FAKE_OUTPUT_1)
+
+    custom_loader_2 = VIDEO_LOADER_REGISTRY.load("test_video_loader_2")
+    output_2 = custom_loader_2.load_bytes(b"test")
+    np.testing.assert_array_equal(output_2, FAKE_OUTPUT_2)
+
+
+def test_video_loader_type_doesnt_exist():
+    """Loading a name that was never registered raises AssertionError."""
+    with pytest.raises(AssertionError):
+        VIDEO_LOADER_REGISTRY.load("non_existing_video_loader")
diff --git a/vllm/envs.py b/vllm/envs.py
index 9d585bf3578e1..fe3fa91fbe337 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
     VLLM_VIDEO_FETCH_TIMEOUT: int = 30
     VLLM_AUDIO_FETCH_TIMEOUT: int = 10
+    VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
     VLLM_MM_INPUT_CACHE_GIB: int = 8
     VLLM_TARGET_DEVICE: str = "cuda"
     MAX_JOBS: Optional[str] = None
@@ -446,6 +447,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_AUDIO_FETCH_TIMEOUT":
     lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
 
+    # Backend for Video IO
+    # - "opencv": Default backend that uses OpenCV stream buffered backend.
+    #
+    # Custom backend implementations can be registered
+    # via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and
+    # imported at runtime.
+    # If a non-existing backend is used, an AssertionError will be raised.
+    "VLLM_VIDEO_LOADER_BACKEND":
+    lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"),
+
     # Cache size (in GiB) for multimodal input cache
     # Default is 4 GiB
     "VLLM_MM_INPUT_CACHE_GIB":
diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py
index 72e9b65d763cd..3685fd4c34580 100644
--- a/vllm/multimodal/video.py
+++ b/vllm/multimodal/video.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import base64
+from abc import abstractmethod
 from functools import partial
 from io import BytesIO
 from pathlib import Path
@@ -9,6 +10,8 @@ import numpy as np
 import numpy.typing as npt
 from PIL import Image
 
+from vllm import envs
+
 from .base import MediaIO
 from .image import ImageMediaIO
 
@@ -48,10 +51,58 @@ def sample_frames_from_video(frames: npt.NDArray,
 class VideoLoader:
 
     @classmethod
-    def load_bytes(self, data: bytes, num_frames: int = -1) -> npt.NDArray:
+    @abstractmethod
+    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
+        """Load a video from raw bytes and return it as a NumPy array.
+
+        Subclasses are expected to override this. ``VideoLoader`` does not
+        derive from ``abc.ABC``, so ``@abstractmethod`` is not enforced at
+        instantiation; the ``NotImplementedError`` below is the real guard.
+        """
         raise NotImplementedError
 
 
+class VideoLoaderRegistry:
+    """Maps backend names to ``VideoLoader`` classes."""
+
+    def __init__(self) -> None:
+        self.name2class: dict[str, type[VideoLoader]] = {}
+
+    def register(self, name: str):
+        """Return a class decorator that registers the class under ``name``.
+
+        The decorated class is returned unchanged. Registering the same
+        name again replaces the earlier entry.
+        """
+
+        def wrap(cls_to_register):
+            self.name2class[name] = cls_to_register
+            return cls_to_register
+
+        return wrap
+
+    @staticmethod
+    def load(cls_name: str) -> VideoLoader:
+        """Instantiate the loader class registered under ``cls_name``.
+
+        The lookup goes through the module-level ``VIDEO_LOADER_REGISTRY``.
+
+        Raises:
+            AssertionError: If no loader is registered under ``cls_name``.
+        """
+        cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
+        if cls is None:
+            # Raised explicitly rather than via `assert`, which is stripped
+            # under `python -O` and would let `None()` fail with a
+            # confusing TypeError instead.
+            raise AssertionError(f"VideoLoader class {cls_name} not found")
+        return cls()
+
+
+VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
+
+
+@VIDEO_LOADER_REGISTRY.register("opencv")
 class OpenCVVideoBackend(VideoLoader):
 
     def get_cv2_video_api(self):
@@ -122,7 +173,9 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
         self.image_io = image_io
         self.num_frames = num_frames
 
-        self.video_loader = OpenCVVideoBackend
+        # Backend is selected by name via the VLLM_VIDEO_LOADER_BACKEND env var.
+        video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
+        self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
 
     def load_bytes(self, data: bytes) -> npt.NDArray:
         return self.video_loader.load_bytes(data, self.num_frames)