[Misc] Small: Fix video loader return type annotations. (#20389)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
This commit is contained in:
Chenheli Hua 2025-07-02 20:10:39 -07:00 committed by GitHub
parent 2e25bb12a8
commit b616f6a53d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 11 deletions

View File

@@ -172,9 +172,10 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
"num_frames": num_frames, "num_frames": num_frames,
}}) }})
video_sync = connector.fetch_video(video_url) video_sync, metadata_sync = connector.fetch_video(video_url)
video_async = await connector.fetch_video_async(video_url) video_async, metadata_async = await connector.fetch_video_async(video_url)
assert np.array_equal(video_sync[0], video_async[0]) assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`. # Used for the next two tests related to `merge_and_sort_multimodal_metadata`.

View File

@@ -228,7 +228,7 @@ class MediaConnector:
video_url: str, video_url: str,
*, *,
image_mode: str = "RGB", image_mode: str = "RGB",
) -> npt.NDArray: ) -> tuple[npt.NDArray, dict[str, Any]]:
""" """
Load video from a HTTP or base64 data URL. Load video from a HTTP or base64 data URL.
""" """
@@ -248,7 +248,7 @@ class MediaConnector:
video_url: str, video_url: str,
*, *,
image_mode: str = "RGB", image_mode: str = "RGB",
) -> npt.NDArray: ) -> tuple[npt.NDArray, dict[str, Any]]:
""" """
Asynchronously load video from a HTTP or base64 data URL. Asynchronously load video from a HTTP or base64 data URL.

View File

@@ -6,6 +6,7 @@ from abc import abstractmethod
from functools import partial from functools import partial
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Any
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@@ -57,7 +58,7 @@ class VideoLoader:
def load_bytes(cls, def load_bytes(cls,
data: bytes, data: bytes,
num_frames: int = -1, num_frames: int = -1,
**kwargs) -> npt.NDArray: **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
raise NotImplementedError raise NotImplementedError
@@ -106,7 +107,7 @@ class OpenCVVideoBackend(VideoLoader):
def load_bytes(cls, def load_bytes(cls,
data: bytes, data: bytes,
num_frames: int = -1, num_frames: int = -1,
**kwargs) -> npt.NDArray: **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
import cv2 import cv2
backend = cls().get_cv2_video_api() backend = cls().get_cv2_video_api()
@@ -179,12 +180,13 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend) self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
def load_bytes(self, data: bytes) -> npt.NDArray: def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
return self.video_loader.load_bytes(data, return self.video_loader.load_bytes(data,
num_frames=self.num_frames, num_frames=self.num_frames,
**self.kwargs) **self.kwargs)
def load_base64(self, media_type: str, data: str) -> npt.NDArray: def load_base64(self, media_type: str,
data: str) -> tuple[npt.NDArray, dict[str, Any]]:
if media_type.lower() == "video/jpeg": if media_type.lower() == "video/jpeg":
load_frame = partial( load_frame = partial(
self.image_io.load_base64, self.image_io.load_base64,
@@ -194,11 +196,11 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
return np.stack([ return np.stack([
np.asarray(load_frame(frame_data)) np.asarray(load_frame(frame_data))
for frame_data in data.split(",") for frame_data in data.split(",")
]) ]), {}
return self.load_bytes(base64.b64decode(data)) return self.load_bytes(base64.b64decode(data))
def load_file(self, filepath: Path) -> npt.NDArray: def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
with filepath.open("rb") as f: with filepath.open("rb") as f:
data = f.read() data = f.read()