mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 09:15:01 +08:00
Support custom implementations of VideoLoader backends. (#18091)
This commit is contained in:
parent
e6b8e65d2d
commit
4f07a64075
41
tests/multimodal/test_video.py
Normal file
41
tests/multimodal/test_video.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader
|
||||||
|
|
||||||
|
# Fake "decoded video" payloads handed back by the test loaders below.
#
# The tests only compare these arrays for equality, so the frames are kept
# tiny: full 1280x720 float64 frames would allocate roughly 211 MiB per array
# at import time for no extra coverage.
NUM_FRAMES = 10
_FAKE_FRAME_SHAPE = (64, 36, 3)  # small frame dimensions with 3 channels

# Seeded generator so a failing comparison is reproducible across runs.
_RNG = np.random.default_rng(0)
FAKE_OUTPUT_1 = _RNG.random((NUM_FRAMES, *_FAKE_FRAME_SHAPE))
FAKE_OUTPUT_2 = _RNG.random((NUM_FRAMES, *_FAKE_FRAME_SHAPE))
|
||||||
|
|
||||||
|
|
||||||
|
@VIDEO_LOADER_REGISTRY.register("test_video_loader_1")
class TestVideoLoader1(VideoLoader):
    """Stub loader registered as ``test_video_loader_1``.

    It ignores its input and always hands back ``FAKE_OUTPUT_1`` so the
    registry tests can tell which backend was resolved.
    """

    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Return ``FAKE_OUTPUT_1`` regardless of ``data`` or ``num_frames``."""
        canned_frames = FAKE_OUTPUT_1
        return canned_frames
|
||||||
|
|
||||||
|
|
||||||
|
@VIDEO_LOADER_REGISTRY.register("test_video_loader_2")
class TestVideoLoader2(VideoLoader):
    """Stub loader registered as ``test_video_loader_2``.

    It ignores its input and always hands back ``FAKE_OUTPUT_2`` so the
    registry tests can tell which backend was resolved.
    """

    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Return ``FAKE_OUTPUT_2`` regardless of ``data`` or ``num_frames``."""
        canned_frames = FAKE_OUTPUT_2
        return canned_frames
|
||||||
|
|
||||||
|
|
||||||
|
def test_video_loader_registry():
    """Each registered stub loader resolves by name and returns its array."""
    expected_by_backend = (
        ("test_video_loader_1", FAKE_OUTPUT_1),
        ("test_video_loader_2", FAKE_OUTPUT_2),
    )
    # Checked in registration order: load the backend, run it on a dummy
    # payload, and compare against the array that stub is known to return.
    for backend_name, expected_frames in expected_by_backend:
        loader = VIDEO_LOADER_REGISTRY.load(backend_name)
        frames = loader.load_bytes(b"test")
        np.testing.assert_array_equal(frames, expected_frames)
|
||||||
|
|
||||||
|
|
||||||
|
def test_video_loader_type_doesnt_exist():
    """Loading a backend name that was never registered must fail."""
    unknown_backend = "non_existing_video_loader"
    expect_failure = pytest.raises(AssertionError)
    with expect_failure:
        VIDEO_LOADER_REGISTRY.load(unknown_backend)
|
||||||
11
vllm/envs.py
11
vllm/envs.py
@ -55,6 +55,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
||||||
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
||||||
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
||||||
|
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
||||||
VLLM_MM_INPUT_CACHE_GIB: int = 8
|
VLLM_MM_INPUT_CACHE_GIB: int = 8
|
||||||
VLLM_TARGET_DEVICE: str = "cuda"
|
VLLM_TARGET_DEVICE: str = "cuda"
|
||||||
MAX_JOBS: Optional[str] = None
|
MAX_JOBS: Optional[str] = None
|
||||||
@ -446,6 +447,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_AUDIO_FETCH_TIMEOUT":
|
"VLLM_AUDIO_FETCH_TIMEOUT":
|
||||||
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
||||||
|
|
||||||
|
# Backend for Video IO
|
||||||
|
# - "opencv": Default backend that uses OpenCV stream buffered backend.
|
||||||
|
#
|
||||||
|
# Custom backend implementations can be registered
|
||||||
|
# via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and
|
||||||
|
# imported at runtime.
|
||||||
|
# If a non-existent backend is specified, an AssertionError is raised.
|
||||||
|
"VLLM_VIDEO_LOADER_BACKEND":
|
||||||
|
lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"),
|
||||||
|
|
||||||
# Cache size (in GiB) for multimodal input cache
|
# Cache size (in GiB) for multimodal input cache
|
||||||
# Default is 4 GiB
|
# Default is 4 GiB
|
||||||
"VLLM_MM_INPUT_CACHE_GIB":
|
"VLLM_MM_INPUT_CACHE_GIB":
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
from abc import abstractmethod
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -9,6 +10,8 @@ import numpy as np
|
|||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
|
||||||
from .base import MediaIO
|
from .base import MediaIO
|
||||||
from .image import ImageMediaIO
|
from .image import ImageMediaIO
|
||||||
|
|
||||||
@ -48,10 +51,35 @@ def sample_frames_from_video(frames: npt.NDArray,
|
|||||||
class VideoLoader:
|
class VideoLoader:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_bytes(self, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
@abstractmethod
|
||||||
|
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class VideoLoaderRegistry:
    """Name-to-class registry for pluggable video loader backends.

    Classes are added with the ``register`` decorator and instantiated by
    name with ``load``.
    """

    def __init__(self) -> None:
        # Maps a backend name to the class registered under it.
        self.name2class: dict[str, type] = {}

    def register(self, name: str):
        """Return a class decorator that registers the class under ``name``.

        The decorated class is returned unchanged. Registering another class
        under an existing name replaces the earlier entry.
        """

        def wrap(cls_to_register):
            self.name2class[name] = cls_to_register
            return cls_to_register

        return wrap

    @staticmethod
    def load(cls_name: str) -> "VideoLoader":
        """Instantiate the loader class registered under ``cls_name``.

        NOTE: this is a static method that reads the module-level
        ``VIDEO_LOADER_REGISTRY`` singleton, not an arbitrary instance.

        Raises:
            AssertionError: if no class is registered under ``cls_name``.
        """
        cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
        if cls is None:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`; the exception type and message that
            # callers rely on stay the same.
            raise AssertionError(f"VideoLoader class {cls_name} not found")
        return cls()


VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
@VIDEO_LOADER_REGISTRY.register("opencv")
|
||||||
class OpenCVVideoBackend(VideoLoader):
|
class OpenCVVideoBackend(VideoLoader):
|
||||||
|
|
||||||
def get_cv2_video_api(self):
|
def get_cv2_video_api(self):
|
||||||
@ -122,7 +150,8 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
|
|||||||
|
|
||||||
self.image_io = image_io
|
self.image_io = image_io
|
||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
self.video_loader = OpenCVVideoBackend
|
video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
|
||||||
|
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
|
||||||
|
|
||||||
def load_bytes(self, data: bytes) -> npt.NDArray:
|
def load_bytes(self, data: bytes) -> npt.NDArray:
|
||||||
return self.video_loader.load_bytes(data, self.num_frames)
|
return self.video_loader.load_bytes(data, self.num_frames)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user