Support custom implementations of VideoLoader backends. (#18091)

This commit is contained in:
Chenheli Hua 2025-05-14 22:26:49 -07:00 committed by GitHub
parent e6b8e65d2d
commit 4f07a64075
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 83 additions and 2 deletions

View File

@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import numpy.typing as npt
import pytest
from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader
# Number of frames in each canned array handed back by the stub loaders.
NUM_FRAMES = 10

# Both fixtures share one shape; they are drawn separately so that the two
# stub loaders return distinguishable data.
# NOTE(review): axes are presumably (frames, dim, dim, channels) -- confirm.
_FAKE_OUTPUT_SHAPE = (NUM_FRAMES, 1280, 720, 3)
FAKE_OUTPUT_1 = np.random.rand(*_FAKE_OUTPUT_SHAPE)
FAKE_OUTPUT_2 = np.random.rand(*_FAKE_OUTPUT_SHAPE)
@VIDEO_LOADER_REGISTRY.register("test_video_loader_1")
class TestVideoLoader1(VideoLoader):
    """Stub backend registered as "test_video_loader_1"."""

    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Ignore both arguments and return the canned `FAKE_OUTPUT_1`."""
        return FAKE_OUTPUT_1
@VIDEO_LOADER_REGISTRY.register("test_video_loader_2")
class TestVideoLoader2(VideoLoader):
    """Stub backend registered as "test_video_loader_2"."""

    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Ignore both arguments and return the canned `FAKE_OUTPUT_2`."""
        return FAKE_OUTPUT_2
def test_video_loader_registry():
    """Each registered name must resolve to its own stub loader."""
    expected_by_backend = (
        ("test_video_loader_1", FAKE_OUTPUT_1),
        ("test_video_loader_2", FAKE_OUTPUT_2),
    )
    for backend_name, expected_frames in expected_by_backend:
        loader = VIDEO_LOADER_REGISTRY.load(backend_name)
        frames = loader.load_bytes(b"test")
        # The stubs return their fixture unchanged, so equality is exact.
        np.testing.assert_array_equal(frames, expected_frames)
def test_video_loader_type_doesnt_exist():
    """Loading a name that was never registered raises AssertionError."""
    unregistered_name = "non_existing_video_loader"
    with pytest.raises(AssertionError):
        VIDEO_LOADER_REGISTRY.load(unregistered_name)

View File

@ -55,6 +55,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MM_INPUT_CACHE_GIB: int = 8 VLLM_MM_INPUT_CACHE_GIB: int = 8
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None MAX_JOBS: Optional[str] = None
@ -446,6 +447,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT": "VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
# Backend for Video IO
# - "opencv": Default backend, which uses OpenCV's stream-buffered backend.
#
# Custom backend implementations can be registered
# via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and
# imported at runtime.
# If the selected backend has not been registered, an AssertionError is raised.
"VLLM_VIDEO_LOADER_BACKEND":
lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"),
# Cache size (in GiB) for multimodal input cache # Cache size (in GiB) for multimodal input cache
# Default is 4 GiB # Default is 4 GiB
"VLLM_MM_INPUT_CACHE_GIB": "VLLM_MM_INPUT_CACHE_GIB":

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import base64 import base64
from abc import abstractmethod
from functools import partial from functools import partial
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
@ -9,6 +10,8 @@ import numpy as np
import numpy.typing as npt import numpy.typing as npt
from PIL import Image from PIL import Image
from vllm import envs
from .base import MediaIO from .base import MediaIO
from .image import ImageMediaIO from .image import ImageMediaIO
@ -48,10 +51,35 @@ def sample_frames_from_video(frames: npt.NDArray,
class VideoLoader: class VideoLoader:
@classmethod @classmethod
def load_bytes(self, data: bytes, num_frames: int = -1) -> npt.NDArray: @abstractmethod
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
raise NotImplementedError raise NotImplementedError
class VideoLoaderRegistry:
    """Name-to-class registry for pluggable video loader backends.

    Backend classes add themselves with the `register` decorator and are
    later instantiated by name through `load`.
    """

    def __init__(self) -> None:
        # Maps a backend name to the class registered under that name.
        self.name2class: dict[str, type] = {}

    def register(self, name: str):
        """Return a class decorator that registers the class as `name`.

        Registering another class under an already-used name replaces the
        earlier entry. The decorated class itself is returned unchanged.
        """

        def wrap(cls_to_register):
            self.name2class[name] = cls_to_register
            return cls_to_register

        return wrap

    @staticmethod
    def load(cls_name: str) -> "VideoLoader":
        """Instantiate the backend registered under `cls_name`.

        NOTE: this is a static method and always consults the module-level
        `VIDEO_LOADER_REGISTRY`, not the instance it happens to be called on.

        Args:
            cls_name: Name the backend class was registered with.

        Returns:
            A new instance of the registered class, built with no arguments.

        Raises:
            AssertionError: If nothing is registered under `cls_name`.
        """
        loader_cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
        if loader_cls is None:
            # Raise explicitly rather than via `assert`: assert statements
            # are removed under `python -O`, which would turn an unknown
            # backend name into an opaque "'NoneType' object is not callable"
            # error. AssertionError is kept so callers see the same type.
            raise AssertionError(f"VideoLoader class {cls_name} not found")
        return loader_cls()


# Module-level registry instance that `VideoLoaderRegistry.load` reads from.
VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
@VIDEO_LOADER_REGISTRY.register("opencv")
class OpenCVVideoBackend(VideoLoader): class OpenCVVideoBackend(VideoLoader):
def get_cv2_video_api(self): def get_cv2_video_api(self):
@ -122,7 +150,8 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
self.image_io = image_io self.image_io = image_io
self.num_frames = num_frames self.num_frames = num_frames
self.video_loader = OpenCVVideoBackend video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
def load_bytes(self, data: bytes) -> npt.NDArray: def load_bytes(self, data: bytes) -> npt.NDArray:
return self.video_loader.load_bytes(data, self.num_frames) return self.video_loader.load_bytes(data, self.num_frames)