mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:15:51 +08:00
Support custom implementations of VideoLoader backends. (#18091)
This commit is contained in:
parent
e6b8e65d2d
commit
4f07a64075
41
tests/multimodal/test_video.py
Normal file
41
tests/multimodal/test_video.py
Normal file
@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader
|
||||
|
||||
NUM_FRAMES = 10
|
||||
FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
|
||||
FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("test_video_loader_1")
|
||||
class TestVideoLoader1(VideoLoader):
|
||||
|
||||
@classmethod
|
||||
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||
return FAKE_OUTPUT_1
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("test_video_loader_2")
|
||||
class TestVideoLoader2(VideoLoader):
|
||||
|
||||
@classmethod
|
||||
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||
return FAKE_OUTPUT_2
|
||||
|
||||
|
||||
def test_video_loader_registry():
|
||||
custom_loader_1 = VIDEO_LOADER_REGISTRY.load("test_video_loader_1")
|
||||
output_1 = custom_loader_1.load_bytes(b"test")
|
||||
np.testing.assert_array_equal(output_1, FAKE_OUTPUT_1)
|
||||
|
||||
custom_loader_2 = VIDEO_LOADER_REGISTRY.load("test_video_loader_2")
|
||||
output_2 = custom_loader_2.load_bytes(b"test")
|
||||
np.testing.assert_array_equal(output_2, FAKE_OUTPUT_2)
|
||||
|
||||
|
||||
def test_video_loader_type_doesnt_exist():
|
||||
with pytest.raises(AssertionError):
|
||||
VIDEO_LOADER_REGISTRY.load("non_existing_video_loader")
|
||||
11
vllm/envs.py
11
vllm/envs.py
@ -55,6 +55,7 @@ if TYPE_CHECKING:
|
||||
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
||||
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
||||
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
||||
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
||||
VLLM_MM_INPUT_CACHE_GIB: int = 8
|
||||
VLLM_TARGET_DEVICE: str = "cuda"
|
||||
MAX_JOBS: Optional[str] = None
|
||||
@ -446,6 +447,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_AUDIO_FETCH_TIMEOUT":
|
||||
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
||||
|
||||
# Backend for Video IO
|
||||
# - "opencv": Default backend that uses OpenCV stream buffered backend.
|
||||
#
|
||||
# Custom backend implementations can be registered
|
||||
# via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and
|
||||
# imported at runtime.
|
||||
# If a non-existing backend is used, an AssertionError will be thrown.
|
||||
"VLLM_VIDEO_LOADER_BACKEND":
|
||||
lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"),
|
||||
|
||||
# Cache size (in GiB) for multimodal input cache
|
||||
# Default is 4 GiB
|
||||
"VLLM_MM_INPUT_CACHE_GIB":
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import base64
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
@ -9,6 +10,8 @@ import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm import envs
|
||||
|
||||
from .base import MediaIO
|
||||
from .image import ImageMediaIO
|
||||
|
||||
@ -48,10 +51,35 @@ def sample_frames_from_video(frames: npt.NDArray,
|
||||
class VideoLoader:
|
||||
|
||||
@classmethod
|
||||
def load_bytes(self, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||
@abstractmethod
|
||||
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class VideoLoaderRegistry:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.name2class: dict[str, type] = {}
|
||||
|
||||
def register(self, name: str):
|
||||
|
||||
def wrap(cls_to_register):
|
||||
self.name2class[name] = cls_to_register
|
||||
return cls_to_register
|
||||
|
||||
return wrap
|
||||
|
||||
@staticmethod
|
||||
def load(cls_name: str) -> VideoLoader:
|
||||
cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
|
||||
assert cls is not None, f"VideoLoader class {cls_name} not found"
|
||||
return cls()
|
||||
|
||||
|
||||
VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv")
|
||||
class OpenCVVideoBackend(VideoLoader):
|
||||
|
||||
def get_cv2_video_api(self):
|
||||
@ -122,7 +150,8 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
|
||||
|
||||
self.image_io = image_io
|
||||
self.num_frames = num_frames
|
||||
self.video_loader = OpenCVVideoBackend
|
||||
video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
|
||||
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
|
||||
|
||||
def load_bytes(self, data: bytes) -> npt.NDArray:
|
||||
return self.video_loader.load_bytes(data, self.num_frames)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user