From 18e519ec8640ef66b70bb1b3ceb23e0bb883de0b Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sat, 19 Jul 2025 17:17:16 +0800
Subject: [PATCH] [Bugfix] Fix ndarray video color from VideoAsset (#21064)

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 tests/multimodal/test_video.py | 103 +++++++++++++++++++++++++--------
 tests/multimodal/utils.py      |  46 +++++++++++++++
 vllm/assets/video.py           |   9 ++-
 3 files changed, 130 insertions(+), 28 deletions(-)

diff --git a/tests/multimodal/test_video.py b/tests/multimodal/test_video.py
index 897c9c33461ac..05b7b84be7f34 100644
--- a/tests/multimodal/test_video.py
+++ b/tests/multimodal/test_video.py
@@ -1,14 +1,22 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import tempfile
+from pathlib import Path
+
 import numpy as np
 import numpy.typing as npt
 import pytest
+from PIL import Image
 
-from vllm import envs
+from vllm.assets.base import get_vllm_public_assets
+from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list
 from vllm.multimodal.image import ImageMediaIO
 from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader,
                                    VideoMediaIO)
 
+from .utils import cosine_similarity, create_video_from_image, normalize_image
+
 NUM_FRAMES = 10
 FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
 FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
@@ -59,30 +67,79 @@ class Assert10Frames1FPSVideoLoader(VideoLoader):
         return FAKE_OUTPUT_2
 
 
-def test_video_media_io_kwargs():
-    envs.VLLM_VIDEO_LOADER_BACKEND = "assert_10_frames_1_fps"
-    imageio = ImageMediaIO()
+def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "assert_10_frames_1_fps")
+        imageio = ImageMediaIO()
 
-    # Verify that different args pass/fail assertions as expected.
-    videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0})
-    _ = videoio.load_bytes(b"test")
-
-    videoio = VideoMediaIO(
-        imageio, **{
-            "num_frames": 10,
-            "fps": 1.0,
-            "not_used": "not_used"
-        })
-    _ = videoio.load_bytes(b"test")
-
-    with pytest.raises(AssertionError, match="bad num_frames"):
-        videoio = VideoMediaIO(imageio, **{})
+        # Verify that different args pass/fail assertions as expected.
+        videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0})
         _ = videoio.load_bytes(b"test")
 
-    with pytest.raises(AssertionError, match="bad num_frames"):
-        videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0})
+        videoio = VideoMediaIO(
+            imageio, **{
+                "num_frames": 10,
+                "fps": 1.0,
+                "not_used": "not_used"
+            })
         _ = videoio.load_bytes(b"test")
 
-    with pytest.raises(AssertionError, match="bad fps"):
-        videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0})
-        _ = videoio.load_bytes(b"test")
+        with pytest.raises(AssertionError, match="bad num_frames"):
+            videoio = VideoMediaIO(imageio, **{})
+            _ = videoio.load_bytes(b"test")
+
+        with pytest.raises(AssertionError, match="bad num_frames"):
+            videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0})
+            _ = videoio.load_bytes(b"test")
+
+        with pytest.raises(AssertionError, match="bad fps"):
+            videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0})
+            _ = videoio.load_bytes(b"test")
+
+
+@pytest.mark.parametrize("is_color", [True, False])
+@pytest.mark.parametrize("fourcc, ext", [("mp4v", "mp4"), ("XVID", "avi")])
+def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str):
+    """
+    Test that all functions that use OpenCV for video I/O return RGB format.
+    Both RGB and grayscale videos are tested.
+    """
+    image_path = get_vllm_public_assets(filename="stop_sign.jpg",
+                                        s3_prefix="vision_model_images")
+    image = Image.open(image_path)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        if not is_color:
+            image_path = f"{tmpdir}/test_grayscale_image.png"
+            image = image.convert("L")
+            image.save(image_path)
+            # Convert to gray RGB for comparison
+            image = image.convert("RGB")
+        video_path = f"{tmpdir}/test_RGB_video.{ext}"
+        create_video_from_image(
+            image_path,
+            video_path,
+            num_frames=2,
+            is_color=is_color,
+            fourcc=fourcc,
+        )
+
+        frames = video_to_ndarrays(video_path)
+        for frame in frames:
+            sim = cosine_similarity(normalize_image(np.array(frame)),
+                                    normalize_image(np.array(image)))
+            assert np.sum(np.isnan(sim)) / sim.size < 0.001
+            assert np.nanmean(sim) > 0.99
+
+        pil_frames = video_to_pil_images_list(video_path)
+        for frame in pil_frames:
+            sim = cosine_similarity(normalize_image(np.array(frame)),
+                                    normalize_image(np.array(image)))
+            assert np.sum(np.isnan(sim)) / sim.size < 0.001
+            assert np.nanmean(sim) > 0.99
+
+        io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path))
+        for frame in io_frames:
+            sim = cosine_similarity(normalize_image(np.array(frame)),
+                                    normalize_image(np.array(image)))
+            assert np.sum(np.isnan(sim)) / sim.size < 0.001
+            assert np.nanmean(sim) > 0.99
diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py
index 23346509a06fd..9a58292f9f4a5 100644
--- a/tests/multimodal/utils.py
+++ b/tests/multimodal/utils.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import cv2
 import numpy as np
+import numpy.typing as npt
 from PIL import Image
 
 
@@ -31,3 +33,47 @@ def random_audio(
 ):
     audio_len = rng.randint(min_len, max_len)
     return rng.rand(audio_len), sr
+
+
+def create_video_from_image(
+    image_path: str,
+    video_path: str,
+    num_frames: int = 10,
+    fps: float = 1.0,
+    is_color: bool = True,
+    fourcc: str = "mp4v",
+):
+    image = cv2.imread(image_path)
+    if not is_color:
+        # Convert to grayscale if is_color is False
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        height, width = image.shape
+    else:
+        height, width, _ = image.shape
+
+    video_writer = cv2.VideoWriter(
+        video_path,
+        cv2.VideoWriter_fourcc(*fourcc),
+        fps,
+        (width, height),
+        isColor=is_color,
+    )
+
+    for _ in range(num_frames):
+        video_writer.write(image)
+
+    video_writer.release()
+    return video_path
+
+
+def cosine_similarity(A: npt.NDArray,
+                      B: npt.NDArray,
+                      axis: int = -1) -> npt.NDArray:
+    """Compute cosine similarity between two arrays along the given axis."""
+    return (np.sum(A * B, axis=axis) /
+            (np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis)))
+
+
+def normalize_image(image: npt.NDArray) -> npt.NDArray:
+    """Normalize image to [0, 1] range."""
+    return image.astype(np.float32) / 255.0
\ No newline at end of file
diff --git a/vllm/assets/video.py b/vllm/assets/video.py
index 16412121cf0a8..8ab0e9760be87 100644
--- a/vllm/assets/video.py
+++ b/vllm/assets/video.py
@@ -59,7 +59,9 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
         if idx in frame_indices:  # only decompress needed
             ret, frame = cap.retrieve()
             if ret:
-                frames.append(frame)
+                # OpenCV uses BGR format, so we need to convert it to RGB
+                # for PIL and transformers compatibility
+                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
 
     frames = np.stack(frames)
     if len(frames) < num_frames:
@@ -71,10 +73,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
 def video_to_pil_images_list(path: str,
                              num_frames: int = -1) -> list[Image.Image]:
     frames = video_to_ndarrays(path, num_frames)
-    return [
-        Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-        for frame in frames
-    ]
+    return [Image.fromarray(frame) for frame in frames]
 
 
 def video_get_metadata(path: str) -> dict[str, Any]:
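
What the fix means for callers: frames returned by video_to_ndarrays() now come back in RGB
order, so video_to_pil_images_list() (and any other consumer) no longer needs a cv2.cvtColor()
round-trip. Below is a minimal usage sketch of the post-fix behaviour; "sample.mp4" and the
output filename are placeholders, not part of this patch:

    import numpy as np
    from PIL import Image

    from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list

    # Decode a few frames; with this patch they are returned as RGB ndarrays.
    frames = video_to_ndarrays("sample.mp4", num_frames=4)
    assert frames.ndim == 4 and frames.shape[-1] == 3  # (N, H, W, 3)

    # Image.fromarray() expects RGB, so the saved frame now has correct colors.
    Image.fromarray(frames[0]).save("first_frame.png")

    # video_to_pil_images_list() simply wraps the same RGB arrays in PIL images.
    pil_frames = video_to_pil_images_list("sample.mp4", num_frames=4)
    assert np.array_equal(np.asarray(pil_frames[0]), frames[0])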