[Bugfix] Fix ndarray video color from VideoAsset (#21064)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-07-19 17:17:16 +08:00 committed by GitHub
parent 1eaff27815
commit 18e519ec86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 130 additions and 28 deletions

View File

@ -1,14 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from pathlib import Path
import numpy as np
import numpy.typing as npt
import pytest
from PIL import Image
from vllm import envs
from vllm.assets.base import get_vllm_public_assets
from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list
from vllm.multimodal.image import ImageMediaIO
from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader,
VideoMediaIO)
from .utils import cosine_similarity, create_video_from_image, normalize_image
NUM_FRAMES = 10
FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
@ -59,30 +67,79 @@ class Assert10Frames1FPSVideoLoader(VideoLoader):
return FAKE_OUTPUT_2
def test_video_media_io_kwargs():
envs.VLLM_VIDEO_LOADER_BACKEND = "assert_10_frames_1_fps"
imageio = ImageMediaIO()
def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "assert_10_frames_1_fps")
imageio = ImageMediaIO()
# Verify that different args pass/fail assertions as expected.
videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0})
_ = videoio.load_bytes(b"test")
videoio = VideoMediaIO(
imageio, **{
"num_frames": 10,
"fps": 1.0,
"not_used": "not_used"
})
_ = videoio.load_bytes(b"test")
with pytest.raises(AssertionError, match="bad num_frames"):
videoio = VideoMediaIO(imageio, **{})
# Verify that different args pass/fail assertions as expected.
videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0})
_ = videoio.load_bytes(b"test")
with pytest.raises(AssertionError, match="bad num_frames"):
videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0})
videoio = VideoMediaIO(
imageio, **{
"num_frames": 10,
"fps": 1.0,
"not_used": "not_used"
})
_ = videoio.load_bytes(b"test")
with pytest.raises(AssertionError, match="bad fps"):
videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0})
_ = videoio.load_bytes(b"test")
with pytest.raises(AssertionError, match="bad num_frames"):
videoio = VideoMediaIO(imageio, **{})
_ = videoio.load_bytes(b"test")
with pytest.raises(AssertionError, match="bad num_frames"):
videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0})
_ = videoio.load_bytes(b"test")
with pytest.raises(AssertionError, match="bad fps"):
videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0})
_ = videoio.load_bytes(b"test")
@pytest.mark.parametrize("is_color", [True, False])
@pytest.mark.parametrize("fourcc, ext", [("mp4v", "mp4"), ("XVID", "avi")])
def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str):
"""
Test all functions that use OpenCV for video I/O return RGB format.
Both RGB and grayscale videos are tested.
"""
image_path = get_vllm_public_assets(filename="stop_sign.jpg",
s3_prefix="vision_model_images")
image = Image.open(image_path)
with tempfile.TemporaryDirectory() as tmpdir:
if not is_color:
image_path = f"{tmpdir}/test_grayscale_image.png"
image = image.convert("L")
image.save(image_path)
# Convert to gray RGB for comparison
image = image.convert("RGB")
video_path = f"{tmpdir}/test_RGB_video.{ext}"
create_video_from_image(
image_path,
video_path,
num_frames=2,
is_color=is_color,
fourcc=fourcc,
)
frames = video_to_ndarrays(video_path)
for frame in frames:
sim = cosine_similarity(normalize_image(np.array(frame)),
normalize_image(np.array(image)))
assert np.sum(np.isnan(sim)) / sim.size < 0.001
assert np.nanmean(sim) > 0.99
pil_frames = video_to_pil_images_list(video_path)
for frame in pil_frames:
sim = cosine_similarity(normalize_image(np.array(frame)),
normalize_image(np.array(image)))
assert np.sum(np.isnan(sim)) / sim.size < 0.001
assert np.nanmean(sim) > 0.99
io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path))
for frame in io_frames:
sim = cosine_similarity(normalize_image(np.array(frame)),
normalize_image(np.array(image)))
assert np.sum(np.isnan(sim)) / sim.size < 0.001
assert np.nanmean(sim) > 0.99

View File

@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import cv2
import numpy as np
import numpy.typing as npt
from PIL import Image
@ -31,3 +33,47 @@ def random_audio(
):
audio_len = rng.randint(min_len, max_len)
return rng.rand(audio_len), sr
def create_video_from_image(
image_path: str,
video_path: str,
num_frames: int = 10,
fps: float = 1.0,
is_color: bool = True,
fourcc: str = "mp4v",
):
image = cv2.imread(image_path)
if not is_color:
# Convert to grayscale if is_color is False
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
height, width = image.shape
else:
height, width, _ = image.shape
video_writer = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*fourcc),
fps,
(width, height),
isColor=is_color,
)
for _ in range(num_frames):
video_writer.write(image)
video_writer.release()
return video_path
def cosine_similarity(A: npt.NDArray,
B: npt.NDArray,
axis: int = -1) -> npt.NDArray:
"""Compute cosine similarity between two vectors."""
return (np.sum(A * B, axis=axis) /
(np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis)))
def normalize_image(image: npt.NDArray) -> npt.NDArray:
"""Normalize image to [0, 1] range."""
return image.astype(np.float32) / 255.0

View File

@ -59,7 +59,9 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
if idx in frame_indices: # only decompress needed
ret, frame = cap.retrieve()
if ret:
frames.append(frame)
# OpenCV uses BGR format, we need to convert it to RGB
# for PIL and transformers compatibility
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
frames = np.stack(frames)
if len(frames) < num_frames:
@ -71,10 +73,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
def video_to_pil_images_list(path: str,
num_frames: int = -1) -> list[Image.Image]:
frames = video_to_ndarrays(path, num_frames)
return [
Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
for frame in frames
]
return [Image.fromarray(frame) for frame in frames]
def video_get_metadata(path: str) -> dict[str, Any]: