[Bugfix] Fix ndarray video color from VideoAsset (#21064)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent 1eaff27815
commit 18e519ec86
@@ -1,14 +1,22 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import tempfile
+from pathlib import Path
+
 import numpy as np
 import numpy.typing as npt
 import pytest
+from PIL import Image
 
 from vllm import envs
+from vllm.assets.base import get_vllm_public_assets
+from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list
 from vllm.multimodal.image import ImageMediaIO
 from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader,
                                    VideoMediaIO)
 
+from .utils import cosine_similarity, create_video_from_image, normalize_image
+
 NUM_FRAMES = 10
 FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
 FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
@@ -59,30 +67,79 @@ class Assert10Frames1FPSVideoLoader(VideoLoader):
         return FAKE_OUTPUT_2
 
 
-def test_video_media_io_kwargs():
-    envs.VLLM_VIDEO_LOADER_BACKEND = "assert_10_frames_1_fps"
-    imageio = ImageMediaIO()
-
-    # Verify that different args pass/fail assertions as expected.
-    videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0})
-    _ = videoio.load_bytes(b"test")
-
-    videoio = VideoMediaIO(
-        imageio, **{
-            "num_frames": 10,
-            "fps": 1.0,
-            "not_used": "not_used"
-        })
-    _ = videoio.load_bytes(b"test")
-
-    with pytest.raises(AssertionError, match="bad num_frames"):
-        videoio = VideoMediaIO(imageio, **{})
-        _ = videoio.load_bytes(b"test")
-
-    with pytest.raises(AssertionError, match="bad num_frames"):
-        videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0})
-        _ = videoio.load_bytes(b"test")
-
-    with pytest.raises(AssertionError, match="bad fps"):
-        videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0})
-        _ = videoio.load_bytes(b"test")
+def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "assert_10_frames_1_fps")
+        imageio = ImageMediaIO()
+
+        # Verify that different args pass/fail assertions as expected.
+        videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0})
+        _ = videoio.load_bytes(b"test")
+
+        videoio = VideoMediaIO(
+            imageio, **{
+                "num_frames": 10,
+                "fps": 1.0,
+                "not_used": "not_used"
+            })
+        _ = videoio.load_bytes(b"test")
+
+        with pytest.raises(AssertionError, match="bad num_frames"):
+            videoio = VideoMediaIO(imageio, **{})
+            _ = videoio.load_bytes(b"test")
+
+        with pytest.raises(AssertionError, match="bad num_frames"):
+            videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0})
+            _ = videoio.load_bytes(b"test")
+
+        with pytest.raises(AssertionError, match="bad fps"):
+            videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0})
+            _ = videoio.load_bytes(b"test")
+
+
+@pytest.mark.parametrize("is_color", [True, False])
+@pytest.mark.parametrize("fourcc, ext", [("mp4v", "mp4"), ("XVID", "avi")])
+def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str):
+    """
+    Test all functions that use OpenCV for video I/O return RGB format.
+    Both RGB and grayscale videos are tested.
+    """
+    image_path = get_vllm_public_assets(filename="stop_sign.jpg",
+                                        s3_prefix="vision_model_images")
+    image = Image.open(image_path)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        if not is_color:
+            image_path = f"{tmpdir}/test_grayscale_image.png"
+            image = image.convert("L")
+            image.save(image_path)
+            # Convert to gray RGB for comparison
+            image = image.convert("RGB")
+        video_path = f"{tmpdir}/test_RGB_video.{ext}"
+        create_video_from_image(
+            image_path,
+            video_path,
+            num_frames=2,
+            is_color=is_color,
+            fourcc=fourcc,
+        )
+
+        frames = video_to_ndarrays(video_path)
+        for frame in frames:
+            sim = cosine_similarity(normalize_image(np.array(frame)),
+                                    normalize_image(np.array(image)))
+            assert np.sum(np.isnan(sim)) / sim.size < 0.001
+            assert np.nanmean(sim) > 0.99
+
+        pil_frames = video_to_pil_images_list(video_path)
+        for frame in pil_frames:
+            sim = cosine_similarity(normalize_image(np.array(frame)),
+                                    normalize_image(np.array(image)))
+            assert np.sum(np.isnan(sim)) / sim.size < 0.001
+            assert np.nanmean(sim) > 0.99
+
+        io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path))
+        for frame in io_frames:
+            sim = cosine_similarity(normalize_image(np.array(frame)),
+                                    normalize_image(np.array(image)))
+            assert np.sum(np.isnan(sim)) / sim.size < 0.001
+            assert np.nanmean(sim) > 0.99
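The remaining hunks below touch the shared test utilities module (the `.utils` import above) and the OpenCV-based helpers imported from vllm.assets.video. For context, here is a minimal sketch of the VideoMediaIO entry point the new test exercises, outside the pytest harness. It is not part of the diff: "sample.mp4" is a hypothetical local file, the num_frames/fps keyword arguments mirror the ones passed in the test, and the second value returned by load_file is unpacked the same way the test does.

# Minimal usage sketch (not part of the commit); "sample.mp4" is a hypothetical path.
import numpy as np
from pathlib import Path

from vllm.multimodal.image import ImageMediaIO
from vllm.multimodal.video import VideoMediaIO

videoio = VideoMediaIO(ImageMediaIO(), num_frames=10, fps=1.0)
frames, _ = videoio.load_file(Path("sample.mp4"))
for frame in frames:
    # The colorspace test above asserts these frames come back in RGB order.
    print(np.asarray(frame).shape)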
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import cv2
 import numpy as np
+import numpy.typing as npt
 from PIL import Image
 
 
@@ -31,3 +33,47 @@ def random_audio(
 ):
     audio_len = rng.randint(min_len, max_len)
     return rng.rand(audio_len), sr
+
+
+def create_video_from_image(
+    image_path: str,
+    video_path: str,
+    num_frames: int = 10,
+    fps: float = 1.0,
+    is_color: bool = True,
+    fourcc: str = "mp4v",
+):
+    image = cv2.imread(image_path)
+    if not is_color:
+        # Convert to grayscale if is_color is False
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        height, width = image.shape
+    else:
+        height, width, _ = image.shape
+
+    video_writer = cv2.VideoWriter(
+        video_path,
+        cv2.VideoWriter_fourcc(*fourcc),
+        fps,
+        (width, height),
+        isColor=is_color,
+    )
+
+    for _ in range(num_frames):
+        video_writer.write(image)
+
+    video_writer.release()
+    return video_path
+
+
+def cosine_similarity(A: npt.NDArray,
+                      B: npt.NDArray,
+                      axis: int = -1) -> npt.NDArray:
+    """Compute cosine similarity between two vectors."""
+    return (np.sum(A * B, axis=axis) /
+            (np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis)))
+
+
+def normalize_image(image: npt.NDArray) -> npt.NDArray:
+    """Normalize image to [0, 1] range."""
+    return image.astype(np.float32) / 255.0
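The helpers added above are combined in the new colorspace test roughly as follows. This is a small sketch with synthetic arrays standing in for decoded video frames; it assumes cosine_similarity and normalize_image from the utils module above are in scope (for example, imported from that module).

# Sketch of the comparison idiom used by test_opencv_video_io_colorspace,
# assuming cosine_similarity and normalize_image (defined above) are in scope.
import numpy as np

reference = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
decoded = reference.copy()  # stand-in for a frame read back from a video

sim = cosine_similarity(normalize_image(decoded), normalize_image(reference))
# sim has shape (64, 64): one cosine similarity per pixel, taken over the RGB
# channels; all-zero (black) pixels produce NaN from the 0/0 division.
assert np.sum(np.isnan(sim)) / sim.size < 0.001
assert np.nanmean(sim) > 0.99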
@@ -59,7 +59,9 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
         if idx in frame_indices:  # only decompress needed
             ret, frame = cap.retrieve()
             if ret:
-                frames.append(frame)
+                # OpenCV uses BGR format, we need to convert it to RGB
+                # for PIL and transformers compatibility
+                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
 
     frames = np.stack(frames)
     if len(frames) < num_frames:
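The one-line change above is the core of the fix: OpenCV decodes frames in BGR order, so the ndarray path now converts to RGB at decode time. A tiny self-contained illustration of what that conversion does, using a synthetic 1x1 frame:

# Illustration only: OpenCV stores pixels as (B, G, R).
import cv2
import numpy as np

bgr_frame = np.zeros((1, 1, 3), dtype=np.uint8)
bgr_frame[0, 0] = (0, 0, 255)  # pure red in BGR order

rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
assert tuple(rgb_frame[0, 0]) == (255, 0, 0)  # red in the (R, G, B) order PIL expects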
@@ -71,10 +73,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
 def video_to_pil_images_list(path: str,
                              num_frames: int = -1) -> list[Image.Image]:
     frames = video_to_ndarrays(path, num_frames)
-    return [
-        Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-        for frame in frames
-    ]
+    return [Image.fromarray(frame) for frame in frames]
 
 
 def video_get_metadata(path: str) -> dict[str, Any]:
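With the conversion done once inside video_to_ndarrays, video_to_pil_images_list no longer re-converts, and the two code paths agree. Previously the ndarray path returned BGR while the PIL path converted to RGB, so the raw arrays had swapped colors. A short hedged check, assuming a local video file at "sample.mp4" (hypothetical path):

# Hedged consistency check; "sample.mp4" is a hypothetical local file.
import numpy as np

from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list

frames = video_to_ndarrays("sample.mp4", num_frames=4)        # RGB ndarrays after this fix
pil_frames = video_to_pil_images_list("sample.mp4", num_frames=4)

# Both paths now return the same RGB pixels.
assert np.array_equal(np.array(pil_frames[0]), frames[0])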