Add GLM4.1V model (Draft) (#19331)

Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Yuxuan Zhang 2025-07-01 20:48:26 +08:00 committed by GitHub
parent 650d5dbd04
commit ed70f3c64f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1946 additions and 16 deletions

View File

@ -553,6 +553,7 @@ Specified using `--task generate`.
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ |
@ -583,7 +584,7 @@ Specified using `--task generate`.
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
&nbsp;&nbsp;&nbsp;&nbsp;• For example, to use DeepSeek-VL2 series models:
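&nbsp;&nbsp;&nbsp;&nbsp;• Likewise, for the GLM-4V entry above, a minimal offline sketch (the override value mirrors the registry change later in this commit; treat it as illustrative rather than canonical):

from vllm import LLM

llm = LLM(
    model="THUDM/glm-4v-9b",
    trust_remote_code=True,
    hf_overrides={"architectures": ["GLM4VForCausalLM"]},
)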

View File

@ -248,6 +248,42 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
)
# GLM-4.1V
def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData:
    model_name = "THUDM/GLM-4.1V-9B-Thinking"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        mm_processor_kwargs={
            "size": {"shortest_edge": 12544, "longest_edge": 47040000},
            "fps": 1,
        },
        limit_mm_per_prompt={modality: 1},
        enforce_eager=True,
    )

    if modality == "image":
        placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
    elif modality == "video":
        placeholder = "<|begin_of_video|><|video|><|end_of_video|>"

    prompts = [
        (
            "[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n"
            f"{placeholder}"
            f"{question}<|assistant|>assistant\n"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
# H2OVL-Mississippi
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@ -1114,6 +1150,7 @@ model_example_map = {
"fuyu": run_fuyu,
"gemma3": run_gemma3,
"glm4v": run_glm4v,
"glm4_1v": run_glm4_1v,
"h2ovl_chat": run_h2ovl,
"idefics3": run_idefics3,
"internvl_chat": run_internvl,
@ -1172,10 +1209,11 @@ def get_multi_modal_input(args):
if args.modality == "video":
# Input video and question
video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays
metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata
vid_questions = ["Why is this video funny?"]
return {
"data": video,
"data": [(video, metadata)] if args.model_type == "glm4_1v" else video,
"questions": vid_questions,
}
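Taken together, a hedged, self-contained sketch of how these example helpers feed a video (with metadata) to GLM-4.1V offline; engine options are trimmed from run_glm4_1v above, and the asset import path is an assumption:

from vllm import LLM, SamplingParams
from vllm.assets.video import VideoAsset  # assumed import path for VideoAsset

asset = VideoAsset(name="baby_reading", num_frames=16)
video_with_meta = (asset.np_ndarrays, asset.metadata)  # GLM-4.1V expects (frames, metadata)

llm = LLM(
    model="THUDM/GLM-4.1V-9B-Thinking",
    max_model_len=4096,
    max_num_seqs=2,
    limit_mm_per_prompt={"video": 1},
    enforce_eager=True,
)
prompt = (
    "[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n"
    "<|begin_of_video|><|video|><|end_of_video|>"
    "Why is this video funny?<|assistant|>assistant\n"
)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"video": video_with_meta}},
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)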

View File

@ -50,7 +50,7 @@ async def client(server):
@pytest.fixture(scope="session")
def base64_encoded_video() -> dict[str, str]:
return {
video_url: encode_video_base64(fetch_video(video_url))
video_url: encode_video_base64(fetch_video(video_url)[0])
for video_url in TEST_VIDEO_URLS
}

View File

@ -309,6 +309,34 @@ VLM_TEST_SETTINGS = {
num_logprobs=10,
marks=[large_gpu_mark(min_gb=32)],
),
"glm4_1v": VLMTestInfo(
models=["THUDM/GLM-4.1V-9B-Thinking"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501
img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", # noqa: E501
video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", # noqa: E501
max_model_len=2048,
max_num_seqs=2,
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
auto_cls=AutoModelForImageTextToText,
),
"glm4_1v-video": VLMTestInfo(
models=["THUDM/GLM-4.1V-9B-Thinking"],
# GLM4.1V require include video metadata for input
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
patch_hf_runner=model_utils.glm4_1v_patch_hf_runner,
custom_test_opts=[CustomTestOptions(
inputs=custom_inputs.video_with_metadata_glm4_1v(),
limit_mm_per_prompt={"video": 1},
)],
# This is needed to run on machine with 24GB VRAM
vllm_runner_kwargs={"gpu_memory_utilization": 0.95},
),
"h2ovl": VLMTestInfo(
models = [
"h2oai/h2ovl-mississippi-800m",

View File

@ -129,3 +129,23 @@ def windows_attention_image_qwen2_5_vl():
wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5])
return build_single_image_inputs([image], [prompt], wrapped_sf)
def video_with_metadata_glm4_1v():
    video_array = VIDEO_ASSETS[0].np_ndarrays
    metadata = VIDEO_ASSETS[0].metadata
    question = "Describe the video."
    video_prompt = "<|begin_of_video|><|video|><|end_of_video|>"
    formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n"
    scales = [0.1, 0.2, 0.25]
    video_input = [[(rescale_video_size(video_array, scale), metadata)]
                   for scale in scales]
    prompts = [formatted_prompt] * len(video_input)

    return [
        PromptWithMultiModalInput(
            prompts=prompts,
            video_data=video_input,
        )
    ]

View File

@ -16,9 +16,11 @@ import torch
from PIL.Image import Image
from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
GenerationConfig, GenerationMixin)
from transformers.video_utils import VideoMetadata
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
from vllm.utils import is_list_of
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
from .types import RunnerOutput
@ -373,6 +375,28 @@ def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for GLM4.1V."""
    hf_processor = hf_model.processor

    def processor(*args, videos=None, **kwargs):
        if videos is not None and is_list_of(videos, tuple):
            # If videos is a list of tuples, we assume each tuple contains
            # (video_array, metadata) as in the case of GLM4.1V.
            video_metadata = [[VideoMetadata(**video[1])] for video in videos]
            videos = [[video[0]] for video in videos]
        else:
            video_metadata = None
        return hf_processor(*args,
                            videos=videos,
                            video_metadata=video_metadata,
                            **kwargs)

    hf_model.processor = processor

    return hf_model
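For reference, a self-contained sketch of the structures the patched processor call above receives; the frame array and metadata values are synthetic, not taken from this commit:

import numpy as np
from transformers.video_utils import VideoMetadata

frames = np.zeros((16, 224, 224, 3), dtype=np.uint8)  # synthetic video frames
meta = {"total_num_frames": 16, "fps": 2.0, "duration": 8.0, "video_backend": "opencv"}

# Mirrors what the patch builds from a list of (frames, metadata) tuples:
videos = [[frames]]
video_metadata = [[VideoMetadata(**meta)]]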
def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for H2OVL."""

View File

@ -24,6 +24,22 @@ from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
    """
    Patch the multimodal data for GLM4.1V model.
    """
    # Ensure video metadata is included
    if "video" in mm_data:
        video = mm_data["video"]
        mm_data["video"] = (video, {
            "total_num_frames": len(video),
            "fps": len(video),
            "duration": 1,
            "video_backend": "opencv"
        })
    return mm_data
def _test_processing_correctness(
model_id: str,
hit_rate: float,
@ -154,6 +170,11 @@ _IGNORE_MM_KEYS = {
"ultravox": {"audio_features"},
}
MM_DATA_PATCHES = {
# GLM4.1V requires video metadata to be included in the input
"glm4v": glm4_1v_patch_mm_data,
}
def _test_processing_correctness_one(
model_config: ModelConfig,
@ -166,6 +187,8 @@ def _test_processing_correctness_one(
):
model_type = model_config.hf_config.model_type
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
if model_type in MM_DATA_PATCHES:
mm_data = MM_DATA_PATCHES[model_type](mm_data)
if isinstance(prompt, str):
text_prompt = prompt
@ -245,6 +268,7 @@ def _test_processing_correctness_one(
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"THUDM/glm-4v-9b",
"THUDM/GLM-4.1V-9B-Thinking",
"ibm-granite/granite-speech-3.3-2b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",

View File

@ -338,6 +338,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501

View File

@ -172,7 +172,9 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
video_sync = connector.fetch_video(video_url, num_frames=num_frames)
video_async = await connector.fetch_video_async(video_url,
num_frames=num_frames)
assert np.array_equal(video_sync, video_async)
# Check that the video frames are equal and that the metadata matches
assert np.array_equal(video_sync[0], video_async[0])
assert video_sync[1] == video_async[1]
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.

View File

@ -3,7 +3,7 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import ClassVar, Literal, Optional
from typing import Any, ClassVar, Literal, Optional
import cv2
import numpy as np
@ -77,6 +77,24 @@ def video_to_pil_images_list(path: str,
]
def video_get_metadata(path: str) -> dict[str, Any]:
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    duration = total_frames / fps if fps > 0 else 0

    metadata = {
        "total_num_frames": total_frames,
        "fps": fps,
        "duration": duration,
        "video_backend": "opencv"
    }
    return metadata
VideoAssetName = Literal["baby_reading"]
@ -105,6 +123,12 @@ class VideoAsset:
ret = video_to_ndarrays(video_path, self.num_frames)
return ret
    @property
    def metadata(self) -> dict[str, Any]:
        video_path = download_video_asset(self.filename)
        ret = video_get_metadata(video_path)
        return ret
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.

View File

@ -515,6 +515,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if modality in ("image", "image_embeds"):
if model_type == "chatglm":
return "<|begin_of_image|><|endoftext|><|end_of_image|>"
if model_type == "glm4v":
return "<|begin_of_image|><|image|><|end_of_image|>"
if model_type in ("phi3_v", "phi4mm"):
return f"<|image_{current_count}|>"
if model_type in ("minicpmo", "minicpmv"):
@ -563,6 +565,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
elif modality == "video":
if model_type == "internvl_chat":
return "<video>"
if model_type == "glm4v":
return "<|begin_of_video|><|video|><|end_of_video|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":

View File

@ -23,6 +23,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import itertools
import math
from typing import Any, Optional, Union
@ -1118,6 +1119,15 @@ class MRotaryEmbedding(RotaryEmbedding):
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
elif "glm4v" in hf_config.model_type:
return cls._glm4v_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
@ -1129,6 +1139,115 @@ class MRotaryEmbedding(RotaryEmbedding):
seq_len=seq_len,
)
    @classmethod
    def _glm4v_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value for GLM4V."""
        image_token_id = hf_config.image_token_id
        video_start_token_id = hf_config.video_start_token_id
        video_end_token_id = hf_config.video_end_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        llm_pos_ids_list: list = []

        if not (image_grid_thw is None and video_grid_thw is None):
            if isinstance(image_grid_thw, torch.Tensor):
                image_grid_thw = image_grid_thw.tolist()

            input_token_type: list[str] = []
            video_check_flg = False
            for token in input_tokens:
                if token == video_start_token_id:
                    video_check_flg = True
                elif token == video_end_token_id:
                    video_check_flg = False

                if (token == image_token_id) and (video_check_flg is False):
                    input_token_type.append("image")
                elif (token == image_token_id) and (video_check_flg is True):
                    input_token_type.append("video")
                else:
                    input_token_type.append("text")

            input_type_group: list[tuple[str, int, int]] = []
            for key, group_iter in itertools.groupby(
                    enumerate(input_token_type), lambda x: x[1]):
                group_list = list(group_iter)
                start_index = group_list[0][0]
                end_index = group_list[-1][0] + 1
                input_type_group.append((key, start_index, end_index))

            video_frame_num = 1
            mm_data_idx = 0
            for modality_type, start_idx, end_idx in input_type_group:
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                    llm_pos_ids_list) > 0 else 0
                if modality_type == "image":
                    t, h, w = (
                        image_grid_thw[mm_data_idx][0],
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = \
                        t, h // spatial_merge_size, w // spatial_merge_size

                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                        llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                        llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)
                    mm_data_idx += 1

                elif modality_type == "video":
                    t, h, w = (
                        video_frame_num,
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = \
                        t, h // spatial_merge_size, w // spatial_merge_size

                    for t_idx in range(llm_grid_t):
                        t_index = torch.tensor(t_idx).view(-1, 1).expand(
                            -1, llm_grid_h * llm_grid_w).flatten()
                        h_index = torch.arange(llm_grid_h).view(
                            1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                        w_index = torch.arange(llm_grid_w).view(
                            1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                        llm_pos_ids_list.append(
                            torch.stack([t_index, h_index, w_index]) + st_idx)

                    mm_data_idx += 1
                    video_frame_num += 1

                else:
                    text_len = end_idx - start_idx
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) +
                        st_idx)
                    video_frame_num = 1

        else:
            text_len = len(input_tokens)
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1))

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        llm_positions = llm_positions[:, context_len:seq_len]
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        return llm_positions, mrope_position_delta
@classmethod
def _vl_get_input_positions_tensor(
cls,
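To make the image branch of `_glm4v_get_input_positions_tensor` above concrete, here is a standalone sketch with made-up sizes: three text tokens followed by one image whose `grid_thw` is (1, 4, 4) with `spatial_merge_size = 2`, i.e. 1 * (4 // 2) * (4 // 2) = 4 image placeholder tokens.

import torch

text_len = 3
llm_grid_t, llm_grid_h, llm_grid_w = 1, 2, 2

# Text tokens advance all three (t, h, w) axes together.
text_pos = torch.arange(text_len).view(1, -1).expand(3, -1)

# Image tokens start right after the text block and spread over h/w.
st_idx = text_pos.max() + 1
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
    -1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
    llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
    llm_grid_t, llm_grid_h, -1).flatten()
image_pos = torch.stack([t_index, h_index, w_index]) + st_idx

positions = torch.cat([text_pos, image_pos], dim=1)
print(positions)
# tensor([[0, 1, 2, 3, 3, 3, 3],
#         [0, 1, 2, 3, 3, 4, 4],
#         [0, 1, 2, 3, 4, 3, 4]])
mrope_position_delta = (positions.max() + 1 - positions.shape[1]).item()  # -2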

File diff suppressed because it is too large Load Diff

View File

@ -190,6 +190,7 @@ _MULTIMODAL_MODELS = {
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),

View File

@ -57,10 +57,12 @@ which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor",
tuple[HfVideoItem, dict[str, Any]]]
"""
A `transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace `VideoProcessor`.
A `transformers.video_utils.VideoInput` representing a single video item.
This can be passed to a HuggingFace `VideoProcessor`
with `transformers.video_utils.VideoMetadata`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;

View File

@ -224,8 +224,14 @@ class ImageEmbeddingItems(EmbeddingItems):
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
def __init__(self, data: Sequence[HfVideoItem]) -> None:
def __init__(
self,
data: Sequence[HfVideoItem],
metadata: Optional[Union[dict[str, Any],
list[Optional[dict[str, Any]]]]] = None,
) -> None:
super().__init__(data, "video")
self.metadata = metadata
def get_num_frames(self, item_idx: int) -> int:
return len(self.get(item_idx))
@ -320,6 +326,7 @@ class MultiModalDataParser:
*,
target_sr: Optional[float] = None,
audio_resample_method: Literal["librosa", "scipy"] = "librosa",
video_needs_metadata: bool = False,
) -> None:
super().__init__()
@ -327,6 +334,7 @@ class MultiModalDataParser:
target_sr=target_sr,
method=audio_resample_method,
)
self.video_needs_metadata = video_needs_metadata
def _is_embeddings(
self, data: object
@ -361,6 +369,21 @@ class MultiModalDataParser:
assert_never(audio)
    def _get_video_with_metadata(
        self,
        video: VideoItem,
    ) -> tuple[np.ndarray, Optional[dict[str, Any]]]:
        if isinstance(video, tuple):
            return video
        if isinstance(video, list):
            return np.array(video), None
        if isinstance(video, np.ndarray):
            return video, None
        if isinstance(video, torch.Tensor):
            return video.numpy(), None

        assert_never(video)
def _parse_audio_data(
self,
data: ModalityData[AudioItem],
@ -433,10 +456,25 @@ class MultiModalDataParser:
data_items = [data]
elif isinstance(data, (np.ndarray, torch.Tensor)):
data_items = [elem for elem in data]
elif isinstance(data, tuple) and len(data) == 2:
data_items = [data]
else:
data_items = data
return VideoProcessorItems(data_items)
new_videos = list[tuple[np.ndarray, Optional[dict[str, Any]]]]()
metadata_lst: list[Optional[dict[str, Any]]] = []
for data_item in data_items:
video, metadata = self._get_video_with_metadata(data_item)
if self.video_needs_metadata:
new_videos.append((video, metadata))
metadata_lst.append(metadata)
else:
new_videos.append(video)
if not self.video_needs_metadata:
metadata = None
return VideoProcessorItems(new_videos, metadata=metadata_lst)
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
return {

View File

@ -24,6 +24,7 @@ def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
dtype=frames.dtype)
# lazy import cv2 to avoid bothering users who only use text models
import cv2
for i, frame in enumerate(frames):
resized_frame = cv2.resize(frame, (new_width, new_height))
resized_frames[i] = resized_frame
@ -92,14 +93,16 @@ class OpenCVVideoBackend(VideoLoader):
continue
if not vr.isBackendBuiltIn(backend):
_, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
if (abi < 1 or (abi == 1 and api < 2)):
if abi < 1 or (abi == 1 and api < 2):
continue
api_pref = backend
break
return api_pref
@classmethod
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
def load_bytes(cls,
data: bytes,
num_frames: int = -1) -> tuple[npt.NDArray, dict]:
import cv2
backend = cls().get_cv2_video_api()
@ -108,6 +111,9 @@ class OpenCVVideoBackend(VideoLoader):
raise ValueError("Could not open video stream")
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
original_fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames_num / original_fps if original_fps > 0 else 0
full_read = num_frames == -1 or total_frames_num < num_frames
if full_read:
num_frames = total_frames_num
@ -125,18 +131,27 @@ class OpenCVVideoBackend(VideoLoader):
i = 0
for idx in range(total_frames_num):
ok = cap.grab() # next img
ok = cap.grab()
if not ok:
break
if idx in frame_idx: # only decompress needed
if idx in frame_idx:
ret, frame = cap.retrieve()
if ret:
frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
i += 1
# we expect all frames loaded
assert i == num_frames, (f"Expected reading {num_frames} frames, "
f"but only loaded {i} frames from video.")
return frames
# Use the transformers.video_utils.VideoMetadata format
metadata = {
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": duration,
"video_backend": "opencv"
}
return frames, metadata
class VideoMediaIO(MediaIO[npt.NDArray]):
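Finally, a hedged usage sketch of the new `load_bytes` return value, which now carries both the sampled frames and the metadata dict built above; the import path and file name are assumptions:

from vllm.multimodal.video import OpenCVVideoBackend  # assumed module path

with open("baby_reading.mp4", "rb") as f:
    frames, meta = OpenCVVideoBackend.load_bytes(f.read(), num_frames=16)

print(frames.shape)  # (num_frames, H, W, 3)
print(meta["fps"], meta["duration"], meta["total_num_frames"])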