GLM-V video segmentation solution adjustment (#28941)
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
commit 0c80efd94f
parent a8b70304d6
@@ -37,7 +37,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import BatchFeature
+from transformers import BatchFeature, Glm4vProcessor
 from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
 from transformers.models.glm4v.image_processing_glm4v import (
     Glm4vImageProcessor,
@@ -1028,7 +1028,7 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
         return max(max_frames_per_video, 1)
 
-    def _get_video_second_idx(
+    def _get_video_second_idx_glm4v(
         self, metadata: dict[str, Any], total_frames: int
     ) -> list[int]:
         video_processor = self.get_video_processor()
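The rename above keeps the original GLM-4V sampling intact under a more specific name; its body is unchanged. Call sites (see the final hunk) then pick between the two helpers by processor type. A minimal sketch of that dispatch, where info and hf_processor are hypothetical stand-ins for a Glm4vProcessingInfo instance and the processor it obtains:

from transformers import Glm4vProcessor

def select_video_timestamps(info, hf_processor, metadata, num_frames):
    # GLM-4V processors keep the original integer-second sampling;
    # any other processor (e.g. GLM-4.6V) takes the new dynamic-FPS path.
    if isinstance(hf_processor, Glm4vProcessor):
        return info._get_video_second_idx_glm4v(metadata, num_frames)
    return info._get_video_second_idx_glm46v(metadata, num_frames)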
@@ -1079,6 +1079,83 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
             selected_timestamps.append(timestamps_list[idx])
         return selected_timestamps
 
+    def _get_video_second_idx_glm46v(
+        self, metadata: dict[str, Any], total_frames: int
+    ) -> list[int]:
+        video_processor = self.get_video_processor()
+
+        video_fps = metadata["fps"]
+        meta_frames = metadata.get("total_num_frames", total_frames)
+        max_frame_idx = meta_frames - 1
+        duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1)
+
+        do_sample_frames = metadata.get("do_sample_frames", True)
+        if not do_sample_frames:
+            frame_indices = metadata["frames_indices"]
+        else:
+            DYNAMIC_FPS_THRES = {30: 3, 300: 1, 2400: 0.5}
+            MAX_FRAME_COUNT_DYNAMIC = 640
+            MAX_DURATION = 2400
+
+            effective_duration = min(duration, MAX_DURATION)
+            if effective_duration <= 30:
+                target_fps = DYNAMIC_FPS_THRES[30]
+            elif effective_duration <= 300:
+                target_fps = DYNAMIC_FPS_THRES[300]
+            else:
+                target_fps = DYNAMIC_FPS_THRES[2400]
+
+            temporal_patch_size = getattr(video_processor, "temporal_patch_size", 1)
+            extract_t = int(effective_duration * target_fps * temporal_patch_size)
+            extract_t = min(extract_t, MAX_FRAME_COUNT_DYNAMIC)
+
+            duration_per_frame = 1 / video_fps
+            timestamps = [i * duration_per_frame for i in range(meta_frames)]
+            max_second = int(duration)
+
+            if meta_frames < extract_t:
+                frame_indices = np.linspace(
+                    0, meta_frames - 1, extract_t, dtype=int
+                ).tolist()
+            else:
+                frame_indices = []
+                current_second = 0.0
+                inv_fps = 1 / (temporal_patch_size * target_fps)
+                for frame_index in range(meta_frames):
+                    if timestamps[frame_index] >= current_second:
+                        current_second += inv_fps
+                        frame_indices.append(frame_index)
+                    if current_second >= max_second:
+                        break
+
+            if len(frame_indices) < extract_t:
+                if len(frame_indices) == 0:
+                    start, end = 0, max(meta_frames - 1, 0)
+                else:
+                    start, end = frame_indices[0], frame_indices[-1]
+                frame_indices = np.linspace(start, end, extract_t, dtype=int).tolist()
+            elif len(frame_indices) > extract_t:
+                frame_indices = np.linspace(
+                    0, meta_frames - 1, extract_t, dtype=int
+                ).tolist()
+
+        seen, uniq = set(), []
+        for idx in frame_indices:
+            if idx not in seen:
+                seen.add(idx)
+                uniq.append(idx)
+
+        if len(uniq) & 1:
+            uniq.append(uniq[-1])
+
+        frame_indices = uniq
+        full_second_idxs = [int(idx / video_fps) for idx in frame_indices]
+        timestamps_list = full_second_idxs[::2]
+        selected_timestamps = []
+        for idx in range(len(timestamps_list)):
+            selected_timestamps.append(timestamps_list[idx])
+        return selected_timestamps
+
     def _construct_video_placeholder(
         self,
         video_array: np.ndarray,
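The new GLM-4.6V path budgets frames by duration tier: clips up to 30 s are sampled at 3 fps, up to 300 s at 1 fps, and anything longer at 0.5 fps, with the duration clamped to 2400 s and the budget capped at 640 frames. The selected indices are then deduplicated, padded to an even count, and halved with [::2] so that each temporal patch pair contributes one timestamp. A minimal standalone sketch of the budget computation follows; temporal_patch_size=2 is assumed purely for illustration, since the real value comes from the video processor config:

# Sketch of the dynamic-FPS frame budget from _get_video_second_idx_glm46v.
DYNAMIC_FPS_THRES = {30: 3, 300: 1, 2400: 0.5}
MAX_FRAME_COUNT_DYNAMIC = 640
MAX_DURATION = 2400

def frame_budget(duration: float, temporal_patch_size: int = 2) -> int:
    # Clamp very long videos, then pick the sampling rate for this tier.
    effective = min(duration, MAX_DURATION)
    if effective <= 30:
        target_fps = DYNAMIC_FPS_THRES[30]
    elif effective <= 300:
        target_fps = DYNAMIC_FPS_THRES[300]
    else:
        target_fps = DYNAMIC_FPS_THRES[2400]
    # Budget raw frames (before temporal merging), capped at 640.
    return min(int(effective * target_fps * temporal_patch_size),
               MAX_FRAME_COUNT_DYNAMIC)

for d in (20, 120, 600, 3600):
    print(d, frame_budget(d))
# 20 120    -> short clips sample densely (3 fps)
# 120 240   -> medium clips drop to 1 fps
# 600 600   -> long clips drop to 0.5 fps
# 3600 640  -> duration is clamped to 2400 s and the 640-frame cap applies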
@@ -1097,9 +1174,18 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
         merge_length = image_processor.merge_size**2
 
         assert isinstance(grid_thw, torch.Tensor)
-        timestamps = self._get_video_second_idx(metadata, len(video_array))
+        timestamps = (
+            self._get_video_second_idx_glm4v(metadata, len(video_array))
+            if isinstance(hf_processor, Glm4vProcessor)
+            else self._get_video_second_idx_glm46v(metadata, len(video_array))
+        )
+
+        timestamp_format = (
+            "{}" if isinstance(hf_processor, Glm4vProcessor) else "{:.1f} seconds"
+        )
         frames_idx_token = [
-            tokenizer.encode(str(i), add_special_tokens=False) for i in timestamps
+            tokenizer.encode(timestamp_format.format(i), add_special_tokens=False)
+            for i in timestamps
         ]
         T, H, W = grid_thw
         num_tokens_per_frame = int(H * W) // merge_length
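The final hunk also switches the timestamp text per processor: the GLM-4V path keeps the bare integer seconds previously produced by str(i), while the GLM-4.6V path renders one decimal place plus a "seconds" suffix before tokenization. A small illustration of the two format strings:

timestamps = [0, 3, 7]

glm4v_format = "{}"               # original GLM-4V text, e.g. "3"
glm46v_format = "{:.1f} seconds"  # new GLM-4.6V text, e.g. "3.0 seconds"

print([glm4v_format.format(t) for t in timestamps])
# ['0', '3', '7']
print([glm46v_format.format(t) for t in timestamps])
# ['0.0 seconds', '3.0 seconds', '7.0 seconds']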