mirror of
https://git.datalinker.icu/ali-vilab/TeaCache
synced 2025-12-09 21:04:25 +08:00
206 lines
7.7 KiB
Python
206 lines
7.7 KiB
Python
import argparse
|
|
import os
|
|
|
|
import imageio
|
|
import torch
|
|
import torchvision.transforms.functional as F
|
|
import tqdm
|
|
from calculate_lpips import calculate_lpips
|
|
from calculate_psnr import calculate_psnr
|
|
from calculate_ssim import calculate_ssim
|
|
|
|
|
|
def load_video(video_path):
|
|
"""
|
|
Load a video from the given path and convert it to a PyTorch tensor.
|
|
"""
|
|
# Read the video using imageio
|
|
reader = imageio.get_reader(video_path, "ffmpeg")
|
|
|
|
# Extract frames and convert to a list of tensors
|
|
frames = []
|
|
for frame in reader:
|
|
# Convert the frame to a tensor and permute the dimensions to match (C, H, W)
|
|
frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
|
|
frames.append(frame_tensor)
|
|
|
|
# Stack the list of tensors into a single tensor with shape (T, C, H, W)
|
|
video_tensor = torch.stack(frames)
|
|
|
|
return video_tensor
|
|
|
|
|
|
def resize_video(video, target_height, target_width):
|
|
resized_frames = []
|
|
for frame in video:
|
|
resized_frame = F.resize(frame, [target_height, target_width])
|
|
resized_frames.append(resized_frame)
|
|
return torch.stack(resized_frames)
|
|
|
|
|
|
def resize_gt_video(gt_video, gen_video):
|
|
gen_video_shape = gen_video.shape
|
|
T_gen, _, H_gen, W_gen = gen_video_shape
|
|
T_eval, _, H_eval, W_eval = gt_video.shape
|
|
|
|
if T_eval < T_gen:
|
|
raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
|
|
|
|
if H_eval < H_gen or W_eval < W_gen:
|
|
# Resize the video maintaining the aspect ratio
|
|
resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
|
|
resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
|
|
gt_video = resize_video(gt_video, resize_height, resize_width)
|
|
# Recalculate the dimensions
|
|
T_eval, _, H_eval, W_eval = gt_video.shape
|
|
|
|
# Center crop
|
|
start_h = (H_eval - H_gen) // 2
|
|
start_w = (W_eval - W_gen) // 2
|
|
cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
|
|
|
|
return cropped_video
|
|
|
|
|
|
def get_video_ids(gt_video_dirs, gen_video_dirs):
|
|
video_ids = []
|
|
for f in os.listdir(gt_video_dirs[0]):
|
|
if f.endswith(f".mp4"):
|
|
video_ids.append(f.replace(f".mp4", ""))
|
|
video_ids.sort()
|
|
|
|
for video_dir in gt_video_dirs + gen_video_dirs:
|
|
tmp_video_ids = []
|
|
for f in os.listdir(video_dir):
|
|
if f.endswith(f".mp4"):
|
|
tmp_video_ids.append(f.replace(f".mp4", ""))
|
|
tmp_video_ids.sort()
|
|
if tmp_video_ids != video_ids:
|
|
raise ValueError(f"Video IDs in {video_dir} are different.")
|
|
return video_ids
|
|
|
|
|
|
def get_videos(video_ids, gt_video_dirs, gen_video_dirs):
|
|
gt_videos = {}
|
|
generated_videos = {}
|
|
|
|
for gt_video_dir in gt_video_dirs:
|
|
tmp_gt_videos_tensor = []
|
|
for video_id in video_ids:
|
|
gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4"))
|
|
tmp_gt_videos_tensor.append(gt_video)
|
|
gt_videos[gt_video_dir] = tmp_gt_videos_tensor
|
|
|
|
for generated_video_dir in gen_video_dirs:
|
|
tmp_generated_videos_tensor = []
|
|
for video_id in video_ids:
|
|
generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4"))
|
|
tmp_generated_videos_tensor.append(generated_video)
|
|
generated_videos[generated_video_dir] = tmp_generated_videos_tensor
|
|
|
|
return gt_videos, generated_videos
|
|
|
|
|
|
def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs):
|
|
out_str = ""
|
|
|
|
for gt_video_dir in gt_video_dirs:
|
|
for generated_video_dir in gen_video_dirs:
|
|
if gt_video_dir == generated_video_dir:
|
|
continue
|
|
lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len(
|
|
lpips_results[gt_video_dir][generated_video_dir]
|
|
)
|
|
psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len(
|
|
psnr_results[gt_video_dir][generated_video_dir]
|
|
)
|
|
ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len(
|
|
ssim_results[gt_video_dir][generated_video_dir]
|
|
)
|
|
out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}"
|
|
|
|
return out_str
|
|
|
|
|
|
def main(args):
|
|
device = "cuda"
|
|
gt_video_dirs = args.gt_video_dirs
|
|
gen_video_dirs = args.gen_video_dirs
|
|
|
|
video_ids = get_video_ids(gt_video_dirs, gen_video_dirs)
|
|
print(f"Find {len(video_ids)} videos")
|
|
|
|
prompt_interval = 1
|
|
batch_size = 8
|
|
calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
|
|
|
|
lpips_results = {}
|
|
psnr_results = {}
|
|
ssim_results = {}
|
|
for gt_video_dir in gt_video_dirs:
|
|
lpips_results[gt_video_dir] = {}
|
|
psnr_results[gt_video_dir] = {}
|
|
ssim_results[gt_video_dir] = {}
|
|
for generated_video_dir in gen_video_dirs:
|
|
lpips_results[gt_video_dir][generated_video_dir] = []
|
|
psnr_results[gt_video_dir][generated_video_dir] = []
|
|
ssim_results[gt_video_dir][generated_video_dir] = []
|
|
|
|
total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
|
|
|
|
for idx in tqdm.tqdm(range(total_len)):
|
|
video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size]
|
|
gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs)
|
|
|
|
for gt_video_dir, gt_videos_tensor in gt_videos.items():
|
|
for generated_video_dir, generated_videos_tensor in generated_videos.items():
|
|
if gt_video_dir == generated_video_dir:
|
|
continue
|
|
|
|
if not isinstance(gt_videos_tensor, torch.Tensor):
|
|
for i in range(len(gt_videos_tensor)):
|
|
gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0])
|
|
gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
|
|
|
|
generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
|
|
|
|
if calculate_lpips_flag:
|
|
result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
|
|
result = result["value"].values()
|
|
result = float(sum(result) / len(result))
|
|
lpips_results[gt_video_dir][generated_video_dir].append(result)
|
|
|
|
if calculate_psnr_flag:
|
|
result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
|
|
result = result["value"].values()
|
|
result = float(sum(result) / len(result))
|
|
psnr_results[gt_video_dir][generated_video_dir].append(result)
|
|
|
|
if calculate_ssim_flag:
|
|
result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
|
|
result = result["value"].values()
|
|
result = float(sum(result) / len(result))
|
|
ssim_results[gt_video_dir][generated_video_dir].append(result)
|
|
|
|
if (idx + 1) % prompt_interval == 0:
|
|
out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
|
|
print(f"Processed {idx + 1} / {total_len} videos. {out_str}")
|
|
|
|
out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
|
|
|
|
# save
|
|
with open(f"./batch_eval.txt", "w+") as f:
|
|
f.write(out_str)
|
|
|
|
print(f"Processed all videos. {out_str}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--gt_video_dirs", type=str, nargs="+")
|
|
parser.add_argument("--gen_video_dirs", type=str, nargs="+")
|
|
|
|
args = parser.parse_args()
|
|
|
|
main(args)
|