GLM-V video frame sampling adjustment (#28941)

Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Yuxuan Zhang 2025-11-20 01:32:55 +08:00 committed by GitHub
parent a8b70304d6
commit 0c80efd94f


@@ -37,7 +37,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import BatchFeature
+from transformers import BatchFeature, Glm4vProcessor
 from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
 from transformers.models.glm4v.image_processing_glm4v import (
     Glm4vImageProcessor,
@@ -1028,7 +1028,7 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
         return max(max_frames_per_video, 1)

-    def _get_video_second_idx(
+    def _get_video_second_idx_glm4v(
         self, metadata: dict[str, Any], total_frames: int
     ) -> list[int]:
         video_processor = self.get_video_processor()
@@ -1079,6 +1079,83 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
             selected_timestamps.append(timestamps_list[idx])
         return selected_timestamps

+    def _get_video_second_idx_glm46v(
+        self, metadata: dict[str, Any], total_frames: int
+    ) -> list[int]:
+        video_processor = self.get_video_processor()
+        video_fps = metadata["fps"]
+        meta_frames = metadata.get("total_num_frames", total_frames)
+        max_frame_idx = meta_frames - 1
+        duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1)
+        do_sample_frames = metadata.get("do_sample_frames", True)
+
+        if not do_sample_frames:
+            frame_indices = metadata["frames_indices"]
+        else:
+            DYNAMIC_FPS_THRES = {30: 3, 300: 1, 2400: 0.5}
+            MAX_FRAME_COUNT_DYNAMIC = 640
+            MAX_DURATION = 2400
+
+            effective_duration = min(duration, MAX_DURATION)
+            if effective_duration <= 30:
+                target_fps = DYNAMIC_FPS_THRES[30]
+            elif effective_duration <= 300:
+                target_fps = DYNAMIC_FPS_THRES[300]
+            else:
+                target_fps = DYNAMIC_FPS_THRES[2400]
+
+            temporal_patch_size = getattr(video_processor, "temporal_patch_size", 1)
+            extract_t = int(effective_duration * target_fps * temporal_patch_size)
+            extract_t = min(extract_t, MAX_FRAME_COUNT_DYNAMIC)
+            duration_per_frame = 1 / video_fps
+            timestamps = [i * duration_per_frame for i in range(meta_frames)]
+            max_second = int(duration)
+
+            if meta_frames < extract_t:
+                frame_indices = np.linspace(
+                    0, meta_frames - 1, extract_t, dtype=int
+                ).tolist()
+            else:
+                frame_indices = []
+                current_second = 0.0
+                inv_fps = 1 / (temporal_patch_size * target_fps)
+                for frame_index in range(meta_frames):
+                    if timestamps[frame_index] >= current_second:
+                        current_second += inv_fps
+                        frame_indices.append(frame_index)
+                        if current_second >= max_second:
+                            break
+
+                if len(frame_indices) < extract_t:
+                    if len(frame_indices) == 0:
+                        start, end = 0, max(meta_frames - 1, 0)
+                    else:
+                        start, end = frame_indices[0], frame_indices[-1]
+                    frame_indices = np.linspace(
+                        start, end, extract_t, dtype=int
+                    ).tolist()
+                elif len(frame_indices) > extract_t:
+                    frame_indices = np.linspace(
+                        0, meta_frames - 1, extract_t, dtype=int
+                    ).tolist()
+
+        seen, uniq = set(), []
+        for idx in frame_indices:
+            if idx not in seen:
+                seen.add(idx)
+                uniq.append(idx)
+        if len(uniq) & 1:
+            uniq.append(uniq[-1])
+        frame_indices = uniq
+
+        full_second_idxs = [int(idx / video_fps) for idx in frame_indices]
+        timestamps_list = full_second_idxs[::2]
+        selected_timestamps = []
+        for idx in range(len(timestamps_list)):
+            selected_timestamps.append(timestamps_list[idx])
+        return selected_timestamps
+
     def _construct_video_placeholder(
         self,
         video_array: np.ndarray,
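As a reading aid (not part of the commit): the new GLM-4.6V helper above picks a target sampling rate from the clip duration, budgets a capped number of frames, and finally emits one integer-second timestamp per temporal patch. Below is a minimal standalone sketch of that policy with hypothetical names, assuming temporal_patch_size=2 and collapsing the per-second cadence walk into a plain np.linspace:

import numpy as np

def pick_timestamps(duration, video_fps, meta_frames, temporal_patch_size=2):
    # Duration-tiered target fps: <=30 s -> 3 fps, <=300 s -> 1 fps, else 0.5 fps.
    if duration <= 30:
        target_fps = 3
    elif duration <= 300:
        target_fps = 1
    else:
        target_fps = 0.5
    # Frame budget: duration * fps * temporal_patch_size, capped at 640 frames
    # and at 2400 s of effective duration.
    extract_t = min(int(min(duration, 2400) * target_fps * temporal_patch_size), 640)
    frame_indices = np.linspace(0, meta_frames - 1, extract_t, dtype=int).tolist()
    if len(frame_indices) % 2:
        frame_indices.append(frame_indices[-1])  # pad to an even count
    # One integer-second timestamp per temporal patch (every second frame).
    return [int(i / video_fps) for i in frame_indices][::2]

print(len(pick_timestamps(90, 30.0, 2700)))      # 90: a 90 s clip sampled at 1 fps
print(len(pick_timestamps(3600, 30.0, 108000)))  # 320: long clip hits the 640-frame cap

Because of the MAX_DURATION and MAX_FRAME_COUNT_DYNAMIC caps, footage beyond 2400 s contributes no additional frames.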
@@ -1097,9 +1174,18 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
         merge_length = image_processor.merge_size**2

         assert isinstance(grid_thw, torch.Tensor)
-        timestamps = self._get_video_second_idx(metadata, len(video_array))
+        timestamps = (
+            self._get_video_second_idx_glm4v(metadata, len(video_array))
+            if isinstance(hf_processor, Glm4vProcessor)
+            else self._get_video_second_idx_glm46v(metadata, len(video_array))
+        )
+        timestamp_format = (
+            "{}" if isinstance(hf_processor, Glm4vProcessor) else "{:.1f} seconds"
+        )
         frames_idx_token = [
-            tokenizer.encode(str(i), add_special_tokens=False) for i in timestamps
+            tokenizer.encode(timestamp_format.format(i), add_special_tokens=False)
+            for i in timestamps
         ]
         T, H, W = grid_thw
         num_tokens_per_frame = int(H * W) // merge_length
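One further note (illustrative, not part of the diff): the only other behavioral change is how each timestamp is rendered before tokenization. The GLM-4V path encodes the bare second index, while the GLM-4.6V path renders one decimal place plus a "seconds" suffix:

fmt_glm4v = "{}"               # Glm4vProcessor path: bare second index
fmt_glm46v = "{:.1f} seconds"  # GLM-4.6V path
for ts in (0, 3, 12):
    print(fmt_glm4v.format(ts), "|", fmt_glm46v.format(ts))
# 0 | 0.0 seconds
# 3 | 3.0 seconds
# 12 | 12.0 seconds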