# Mirror of https://git.datalinker.icu/ali-vilab/TeaCache
# (synced 2025-12-09 21:04:25 +08:00; 808 lines, 22 KiB, Python)
# Adapted from OpenSora

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# OpenSora: https://github.com/hpcaitech/Open-Sora
# --------------------------------------------------------

import numbers
import os
import re
from urllib.parse import urlparse

import numpy as np
import requests
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torchvision.io import write_video
from torchvision.utils import save_image

# FPS value used in place of the caller's fps when a sample has a single frame.
IMG_FPS = 120
# Lower-case file extensions (with leading dot) that are loaded as videos.
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

# Case-insensitive pattern for an absolute http(s)/ftp(s) URL: scheme, then a
# domain name, "localhost" or dotted IPv4 address, an optional port, and an
# optional path or query.
regex = re.compile(
    r"^(?:http|ftp)s?://"  # http://, https://, ftp:// or ftps://
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"  # domain...
    r"localhost|"  # localhost...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"  # ...or ip
    r"(?::\d+)?"  # optional port
    r"(?:/?|[/?]\S+)$",
    re.IGNORECASE,
)

# Maps a human-readable "H:W" aspect ratio to the two-decimal string key used
# by the ASPECT_RATIO_<resolution> tables (the value is H / W).
ASPECT_RATIO_MAP = {
    "3:8": "0.38",
    "9:21": "0.43",
    "12:25": "0.48",
    "1:2": "0.50",
    "9:17": "0.53",
    "27:50": "0.54",
    "9:16": "0.56",
    "5:8": "0.62",
    "2:3": "0.67",
    "3:4": "0.75",
    "1:1": "1.00",
    "4:3": "1.33",
    "3:2": "1.50",
    "16:9": "1.78",
    "17:9": "1.89",
    "2:1": "2.00",
    "50:27": "2.08",
}


# computed from above code
# S = 8294400 (pixel area of the base entry, 2160 * 3840)
# Keys are H / W as two-decimal strings; values are (height, width).
ASPECT_RATIO_4K = {
    "0.38": (1764, 4704),
    "0.43": (1886, 4400),
    "0.48": (1996, 4158),
    "0.50": (2036, 4072),
    "0.53": (2096, 3960),
    "0.54": (2118, 3918),
    # NOTE: "0.62" precedes "0.56" here, unlike the other tables; the original
    # insertion order is kept on purpose.
    "0.62": (2276, 3642),
    "0.56": (2160, 3840),  # base
    "0.67": (2352, 3528),
    "0.75": (2494, 3326),
    "1.00": (2880, 2880),
    "1.33": (3326, 2494),
    "1.50": (3528, 2352),
    "1.78": (3840, 2160),
    "1.89": (3958, 2096),
    "2.00": (4072, 2036),
    "2.08": (4156, 1994),
}

# S = 3686400 (pixel area of the base entry, 1440 * 2560)
# Keys are H / W as two-decimal strings; values are (height, width).
ASPECT_RATIO_2K = {
    "0.38": (1176, 3136),
    "0.43": (1256, 2930),
    "0.48": (1330, 2770),
    "0.50": (1358, 2716),
    "0.53": (1398, 2640),
    "0.54": (1412, 2612),
    "0.56": (1440, 2560),  # base
    "0.62": (1518, 2428),
    "0.67": (1568, 2352),
    "0.75": (1662, 2216),
    "1.00": (1920, 1920),
    "1.33": (2218, 1664),
    "1.50": (2352, 1568),
    "1.78": (2560, 1440),
    "1.89": (2638, 1396),
    "2.00": (2716, 1358),
    "2.08": (2772, 1330),
}

# S = 2073600 (pixel area of the base entry, 1080 * 1920)
# Keys are H / W as two-decimal strings; values are (height, width).
ASPECT_RATIO_1080P = {
    "0.38": (882, 2352),
    "0.43": (942, 2198),
    "0.48": (998, 2080),
    "0.50": (1018, 2036),
    "0.53": (1048, 1980),
    "0.54": (1058, 1958),
    "0.56": (1080, 1920),  # base
    "0.62": (1138, 1820),
    "0.67": (1176, 1764),
    "0.75": (1248, 1664),
    "1.00": (1440, 1440),
    "1.33": (1662, 1246),
    "1.50": (1764, 1176),
    "1.78": (1920, 1080),
    "1.89": (1980, 1048),
    "2.00": (2036, 1018),
    "2.08": (2078, 998),
}

# S = 921600 (pixel area of the base entry, 720 * 1280)
# Keys are H / W as two-decimal strings; values are (height, width).
ASPECT_RATIO_720P = {
    "0.38": (588, 1568),
    "0.43": (628, 1466),
    "0.48": (666, 1388),
    "0.50": (678, 1356),
    "0.53": (698, 1318),
    "0.54": (706, 1306),
    "0.56": (720, 1280),  # base
    "0.62": (758, 1212),
    "0.67": (784, 1176),
    "0.75": (832, 1110),
    "1.00": (960, 960),
    "1.33": (1108, 832),
    "1.50": (1176, 784),
    "1.78": (1280, 720),
    "1.89": (1320, 698),
    "2.00": (1358, 680),
    "2.08": (1386, 666),
}

# S = 409920 (pixel area of the base entry, 480 * 854)
# Keys are H / W as two-decimal strings; values are (height, width).
ASPECT_RATIO_480P = {
    "0.38": (392, 1046),
    "0.43": (420, 980),
    "0.48": (444, 925),
    "0.50": (452, 904),
    "0.53": (466, 880),
    "0.54": (470, 870),
    "0.56": (480, 854),  # base
    "0.62": (506, 810),
    "0.67": (522, 784),
    "0.75": (554, 738),
    "1.00": (640, 640),
    "1.33": (740, 555),
    "1.50": (784, 522),
    "1.78": (854, 480),
    "1.89": (880, 466),
    "2.00": (906, 454),
    "2.08": (924, 444),
}

# S = 230400 (pixel area of the base entry, 360 * 640)
# Keys are H / W as two-decimal strings; values are (height, width).
ASPECT_RATIO_360P = {
    "0.38": (294, 784),
    "0.43": (314, 732),
    "0.48": (332, 692),
    "0.50": (340, 680),
    "0.53": (350, 662),
    "0.54": (352, 652),
    "0.56": (360, 640),  # base
    "0.62": (380, 608),
    "0.67": (392, 588),
    "0.75": (416, 554),
    "1.00": (480, 480),
    "1.33": (554, 416),
    "1.50": (588, 392),
    "1.78": (640, 360),
    "1.89": (660, 350),
    "2.00": (678, 340),
    "2.08": (692, 332),
}

# S = 102240 (pixel area of the base entry, 240 * 426)
# Keys are H / W as two-decimal strings; values are (height, width).
ASPECT_RATIO_240P = {
    "0.38": (196, 522),
    "0.43": (210, 490),
    "0.48": (222, 462),
    "0.50": (226, 452),
    "0.53": (232, 438),
    "0.54": (236, 436),
    "0.56": (240, 426),  # base
    "0.62": (252, 404),
    "0.67": (262, 393),
    "0.75": (276, 368),
    "1.00": (320, 320),
    "1.33": (370, 278),
    "1.50": (392, 262),
    "1.78": (426, 240),
    "1.89": (440, 232),
    "2.00": (452, 226),
    "2.08": (462, 222),
}

# S = 36864 (pixel area of the base entry, 144 * 256)
# Keys are H / W as two-decimal strings; values are (height, width).
ASPECT_RATIO_144P = {
    "0.38": (117, 312),
    "0.43": (125, 291),
    "0.48": (133, 277),
    "0.50": (135, 270),
    "0.53": (139, 262),
    "0.54": (141, 260),
    "0.56": (144, 256),  # base
    "0.62": (151, 241),
    "0.67": (156, 234),
    "0.75": (166, 221),
    "1.00": (192, 192),
    "1.33": (221, 165),
    "1.50": (235, 156),
    "1.78": (256, 144),
    "1.89": (263, 139),
    "2.00": (271, 135),
    "2.08": (277, 132),
}

# from PixArt
# S = 8294400 (2880 * 2880)
# Keys are H / W as strings without zero padding ("1.0", not "1.00");
# values are (height, width).
ASPECT_RATIO_2880 = {
    "0.25": (1408, 5760),
    "0.26": (1408, 5568),
    "0.27": (1408, 5376),
    "0.28": (1408, 5184),
    "0.32": (1600, 4992),
    "0.33": (1600, 4800),
    "0.34": (1600, 4672),
    "0.4": (1792, 4480),
    "0.42": (1792, 4288),
    "0.47": (1920, 4096),
    "0.49": (1920, 3904),
    "0.51": (1920, 3776),
    "0.55": (2112, 3840),
    "0.59": (2112, 3584),
    "0.68": (2304, 3392),
    "0.72": (2304, 3200),
    "0.78": (2496, 3200),
    "0.83": (2496, 3008),
    "0.89": (2688, 3008),
    "0.93": (2688, 2880),
    "1.0": (2880, 2880),
    "1.07": (2880, 2688),
    "1.12": (3008, 2688),
    "1.21": (3008, 2496),
    "1.28": (3200, 2496),
    "1.39": (3200, 2304),
    "1.47": (3392, 2304),
    "1.7": (3584, 2112),
    "1.82": (3840, 2112),
    "2.03": (3904, 1920),
    "2.13": (4096, 1920),
    "2.39": (4288, 1792),
    "2.5": (4480, 1792),
    "2.92": (4672, 1600),
    "3.0": (4800, 1600),
    "3.12": (4992, 1600),
    "3.68": (5184, 1408),
    "3.82": (5376, 1408),
    "3.95": (5568, 1408),
    "4.0": (5760, 1408),
}

# S = 4194304 (2048 * 2048)
# Keys are H / W as strings without zero padding; values are (height, width).
ASPECT_RATIO_2048 = {
    "0.25": (1024, 4096),
    "0.26": (1024, 3968),
    "0.27": (1024, 3840),
    "0.28": (1024, 3712),
    "0.32": (1152, 3584),
    "0.33": (1152, 3456),
    "0.35": (1152, 3328),
    "0.4": (1280, 3200),
    "0.42": (1280, 3072),
    "0.48": (1408, 2944),
    "0.5": (1408, 2816),
    "0.52": (1408, 2688),
    "0.57": (1536, 2688),
    "0.6": (1536, 2560),
    "0.68": (1664, 2432),
    "0.72": (1664, 2304),
    "0.78": (1792, 2304),
    "0.82": (1792, 2176),
    "0.88": (1920, 2176),
    "0.94": (1920, 2048),
    "1.0": (2048, 2048),
    "1.07": (2048, 1920),
    "1.13": (2176, 1920),
    "1.21": (2176, 1792),
    "1.29": (2304, 1792),
    "1.38": (2304, 1664),
    "1.46": (2432, 1664),
    "1.67": (2560, 1536),
    "1.75": (2688, 1536),
    "2.0": (2816, 1408),
    "2.09": (2944, 1408),
    "2.4": (3072, 1280),
    "2.5": (3200, 1280),
    "2.89": (3328, 1152),
    "3.0": (3456, 1152),
    "3.11": (3584, 1152),
    "3.62": (3712, 1024),
    "3.75": (3840, 1024),
    "3.88": (3968, 1024),
    "4.0": (4096, 1024),
}

# S = 1048576 (1024 * 1024)
# Keys are H / W as strings without zero padding; values are (height, width).
ASPECT_RATIO_1024 = {
    "0.25": (512, 2048),
    "0.26": (512, 1984),
    "0.27": (512, 1920),
    "0.28": (512, 1856),
    "0.32": (576, 1792),
    "0.33": (576, 1728),
    "0.35": (576, 1664),
    "0.4": (640, 1600),
    "0.42": (640, 1536),
    "0.48": (704, 1472),
    "0.5": (704, 1408),
    "0.52": (704, 1344),
    "0.57": (768, 1344),
    "0.6": (768, 1280),
    "0.68": (832, 1216),
    "0.72": (832, 1152),
    "0.78": (896, 1152),
    "0.82": (896, 1088),
    "0.88": (960, 1088),
    "0.94": (960, 1024),
    "1.0": (1024, 1024),
    "1.07": (1024, 960),
    "1.13": (1088, 960),
    "1.21": (1088, 896),
    "1.29": (1152, 896),
    "1.38": (1152, 832),
    "1.46": (1216, 832),
    "1.67": (1280, 768),
    "1.75": (1344, 768),
    "2.0": (1408, 704),
    "2.09": (1472, 704),
    "2.4": (1536, 640),
    "2.5": (1600, 640),
    "2.89": (1664, 576),
    "3.0": (1728, 576),
    "3.11": (1792, 576),
    "3.62": (1856, 512),
    "3.75": (1920, 512),
    "3.88": (1984, 512),
    "4.0": (2048, 512),
}

# S = 262144 (512 * 512)
# Keys are H / W as strings without zero padding; values are (height, width).
ASPECT_RATIO_512 = {
    "0.25": (256, 1024),
    "0.26": (256, 992),
    "0.27": (256, 960),
    "0.28": (256, 928),
    "0.32": (288, 896),
    "0.33": (288, 864),
    "0.35": (288, 832),
    "0.4": (320, 800),
    "0.42": (320, 768),
    "0.48": (352, 736),
    "0.5": (352, 704),
    "0.52": (352, 672),
    "0.57": (384, 672),
    "0.6": (384, 640),
    "0.68": (416, 608),
    "0.72": (416, 576),
    "0.78": (448, 576),
    "0.82": (448, 544),
    "0.88": (480, 544),
    "0.94": (480, 512),
    "1.0": (512, 512),
    "1.07": (512, 480),
    "1.13": (544, 480),
    "1.21": (544, 448),
    "1.29": (576, 448),
    "1.38": (576, 416),
    "1.46": (608, 416),
    "1.67": (640, 384),
    "1.75": (672, 384),
    "2.0": (704, 352),
    "2.09": (736, 352),
    "2.4": (768, 320),
    "2.5": (800, 320),
    "2.89": (832, 288),
    "3.0": (864, 288),
    "3.11": (896, 288),
    "3.62": (928, 256),
    "3.75": (960, 256),
    "3.88": (992, 256),
    "4.0": (1024, 256),
}

# S = 65536 (256 * 256)
# Keys are H / W as strings without zero padding; values are (height, width).
ASPECT_RATIO_256 = {
    "0.25": (128, 512),
    "0.26": (128, 496),
    "0.27": (128, 480),
    "0.28": (128, 464),
    "0.32": (144, 448),
    "0.33": (144, 432),
    "0.35": (144, 416),
    "0.4": (160, 400),
    "0.42": (160, 384),
    "0.48": (176, 368),
    "0.5": (176, 352),
    "0.52": (176, 336),
    "0.57": (192, 336),
    "0.6": (192, 320),
    "0.68": (208, 304),
    "0.72": (208, 288),
    "0.78": (224, 288),
    "0.82": (224, 272),
    "0.88": (240, 272),
    "0.94": (240, 256),
    "1.0": (256, 256),
    "1.07": (256, 240),
    "1.13": (272, 240),
    "1.21": (272, 224),
    "1.29": (288, 224),
    "1.38": (288, 208),
    "1.46": (304, 208),
    "1.67": (320, 192),
    "1.75": (336, 192),
    "2.0": (352, 176),
    "2.09": (368, 176),
    "2.4": (384, 160),
    "2.5": (400, 160),
    "2.89": (416, 144),
    "3.0": (432, 144),
    "3.11": (448, 144),
    "3.62": (464, 128),
    "3.75": (480, 128),
    "3.88": (496, 128),
    "4.0": (512, 128),
}


def get_closest_ratio(height: float, width: float, ratios: dict):
    """Return the key of ``ratios`` whose numeric value is nearest to ``height / width``.

    Args:
        height: Image or frame height.
        width: Image or frame width; must be non-zero.
        ratios: Mapping whose keys are strings parseable by ``float``. The
            values are not inspected.

    Returns:
        The closest key, exactly as stored in ``ratios``. On a tie the key
        that comes first in the mapping's iteration order wins.

    Raises:
        ZeroDivisionError: If ``width`` is zero.
        ValueError: If ``ratios`` is empty or a key cannot be parsed as a float.
    """
    aspect_ratio = height / width
    return min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))


# Registry of resolution name -> (S, table), where S is the pixel area noted
# above each table and table maps an aspect-ratio key to (height, width).
# Entries are listed in order of increasing S.
ASPECT_RATIOS = {
    "144p": (36864, ASPECT_RATIO_144P),
    "256": (65536, ASPECT_RATIO_256),
    "240p": (102240, ASPECT_RATIO_240P),
    "360p": (230400, ASPECT_RATIO_360P),
    "512": (262144, ASPECT_RATIO_512),
    "480p": (409920, ASPECT_RATIO_480P),
    "720p": (921600, ASPECT_RATIO_720P),
    "1024": (1048576, ASPECT_RATIO_1024),
    "1080p": (2073600, ASPECT_RATIO_1080P),
    "2k": (3686400, ASPECT_RATIO_2K),
    "2048": (4194304, ASPECT_RATIO_2048),
    "2880": (8294400, ASPECT_RATIO_2880),
    "4k": (8294400, ASPECT_RATIO_4K),
}


def get_image_size(resolution, ar_ratio):
    """Look up the ``(height, width)`` for a named resolution and "H:W" aspect ratio.

    Args:
        resolution: Key of ``ASPECT_RATIOS`` (e.g. ``"720p"``).
        ar_ratio: Key of ``ASPECT_RATIO_MAP`` (e.g. ``"9:16"``).

    Returns:
        The ``(height, width)`` tuple stored in the resolution's table.

    Raises:
        KeyError: If ``ar_ratio`` or ``resolution`` is not a known key.
        AssertionError: If the resolution's table has no entry for the ratio.
            The numeric tables ("256", "512", ...) use un-padded keys such as
            "1.0", so most ``ASPECT_RATIO_MAP`` values are absent from them.
    """
    ar_key = ASPECT_RATIO_MAP[ar_ratio]
    rs_dict = ASPECT_RATIOS[resolution][1]
    if ar_key not in rs_dict:
        # Raised explicitly (rather than via ``assert``) so the check and its
        # message survive ``python -O``; the exception type is unchanged.
        raise AssertionError(f"Aspect ratio {ar_ratio} not found for resolution {resolution}")
    return rs_dict[ar_key]


# Symbolic clip lengths accepted by get_num_frames: "<k>x" multipliers and
# "<k>s" durations, each mapped to a frame count.
NUM_FRAMES_MAP = {
    "1x": 51,
    "2x": 102,
    "4x": 204,
    "8x": 408,
    "16x": 816,
    "2s": 51,
    "4s": 102,
    "8s": 204,
    "16s": 408,
    "32s": 816,
}


def get_num_frames(num_frames):
    """Resolve a frame-count specification to an integer.

    Args:
        num_frames: Either a key of ``NUM_FRAMES_MAP`` (e.g. ``"4s"``) or any
            value accepted by ``int`` (e.g. ``16`` or ``"16"``).

    Returns:
        The mapped frame count, or ``int(num_frames)`` when it is not a key.

    Raises:
        ValueError: If ``num_frames`` is neither a known key nor int-convertible.
    """
    if num_frames in NUM_FRAMES_MAP:
        return NUM_FRAMES_MAP[num_frames]
    return int(num_frames)


def save_sample(x, save_path=None, fps=8, normalize=True, value_range=(-1, 1), force_video=False, verbose=True):
    """Write a generated sample to disk as a PNG image or an H.264 MP4 video.

    Args:
        x (Tensor): shape [C, T, H, W]
        save_path (str): Output path WITHOUT extension; ".png" or ".mp4" is
            appended. Required despite the ``None`` default.
        fps (int): Frame rate of the written video.
        normalize (bool): If true, clamp to ``value_range`` and rescale to [0, 1].
        value_range (tuple): ``(low, high)`` range of the values in ``x``.
        force_video (bool): Write a video even when ``T == 1``.
        verbose (bool): Print the output path after saving.

    Returns:
        str: The final path including the appended extension.

    Raises:
        AssertionError: If ``x`` is not 4-dimensional.
        TypeError: If ``save_path`` is ``None``.
    """
    if x.ndim != 4:
        raise AssertionError(f"x must have shape [C, T, H, W], got {x.ndim} dimensions")
    if save_path is None:
        # Previously failed later with an opaque "NoneType += str" TypeError.
        raise TypeError("save_path is required (output path without file extension)")

    if not force_video and x.shape[1] == 1:  # T = 1: save as image
        save_path += ".png"
        x = x.squeeze(1)
        save_image([x], save_path, normalize=normalize, value_range=value_range)
    else:
        save_path += ".mp4"
        if normalize:
            low, high = value_range
            # Out-of-place clamp creates a copy, so the in-place rescale below
            # no longer modifies the caller's tensor.
            x = x.clamp(min=low, max=high)
            x.sub_(low).div_(max(high - low, 1e-5))

        # [C, T, H, W] in [0, 1] -> uint8 [T, H, W, C] on the CPU for write_video.
        x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
        write_video(save_path, x, fps=fps, video_codec="h264")
    if verbose:
        print(f"Saved to {save_path}")
    return save_path


def is_url(url):
    """Return True if ``url`` is an absolute http(s)/ftp(s) URL matched by ``regex``.

    Args:
        url (str): Candidate string; a plain filesystem path yields False.
    """
    return regex.match(url) is not None


def download_url(input_path):
    """Download ``input_path`` into the local ``cache`` directory.

    Args:
        input_path (str): URL of the file to fetch.

    Returns:
        str: Path of the downloaded file, ``cache/<last URL path segment>``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeout.
    """
    output_dir = "cache"
    os.makedirs(output_dir, exist_ok=True)
    # Name the file after the URL's path only, so a query string or fragment
    # ("img.png?sig=...") does not leak into the name and hide the extension.
    base_name = os.path.basename(urlparse(input_path).path)
    output_path = os.path.join(output_dir, base_name)
    # Bounded wait instead of hanging forever on an unresponsive server.
    response = requests.get(input_path, timeout=60)
    # Fail loudly rather than saving an HTML error page as media.
    response.raise_for_status()
    with open(output_path, "wb") as handler:
        handler.write(response.content)
    print(f"URL {input_path} downloaded to {output_path}")
    return output_path


def get_transforms_video(name="center", image_size=(256, 256)):
    """Build the preprocessing pipeline applied to a decoded video clip.

    Args:
        name: ``"center"`` (short-edge resize, then square center crop),
            ``"resize_crop"`` (resize and crop to fill ``image_size``), or
            ``None`` for no transform.
        image_size: Target ``(height, width)``; must be square for ``"center"``.

    Returns:
        A ``transforms.Compose`` that converts uint8 clips to float, crops, and
        normalises each channel with mean 0.5 / std 0.5; ``None`` if ``name``
        is ``None``.

    Raises:
        AssertionError: If ``name == "center"`` and ``image_size`` is not square.
        NotImplementedError: If ``name`` is not recognised.
    """
    if name is None:
        return None

    if name == "center":
        if image_size[0] != image_size[1]:
            # Explicit raise so the check also runs under ``python -O``.
            raise AssertionError("image_size must be square for center crop")
        steps = [
            ToTensorVideo(),  # TCHW
            UCFCenterCropVideo(image_size[0]),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    elif name == "resize_crop":
        steps = [
            ToTensorVideo(),  # TCHW
            ResizeCrop(image_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    else:
        raise NotImplementedError(f"Transform {name} not implemented")
    return transforms.Compose(steps)


def crop(clip, i, j, h, w):
    """Cut a spatial window out of every frame of a clip.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): Top row of the window.
        j (int): Left column of the window.
        h (int): Window height.
        w (int): Window width.

    Returns:
        torch.tensor: View of ``clip`` with size (T, C, h, w).

    Raises:
        ValueError: If ``clip`` is not 4-dimensional.
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i : i + h, j : j + w]


def center_crop(clip, crop_size):
    """Crop the central ``crop_size`` window from every frame of a clip.

    Args:
        clip (torch.tensor): 4D clip whose last two dimensions are (H, W).
        crop_size (tuple): Target ``(height, width)``.

    Returns:
        torch.tensor: The cropped clip with spatial size ``crop_size``.

    Raises:
        TypeError: If ``clip`` is not a tensor.
        ValueError: If ``clip`` is not 4D or is smaller than ``crop_size``.
    """
    # The helper raises on invalid input and otherwise returns True, so the
    # former "if not ...: raise" wrapper around it was unreachable.
    _is_tensor_video_clip(clip)
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")

    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)


def resize_scale(clip, target_size, interpolation_mode):
    """Rescale a clip so that its SHORT edge equals ``target_size[0]``.

    The aspect ratio is preserved. Only ``target_size[0]`` is used;
    ``target_size[1]`` is ignored apart from the length check.

    Args:
        clip (torch.tensor): 4D float clip whose last two dims are (H, W).
        target_size (tuple): ``(height, width)``.
        interpolation_mode (str): Mode passed to ``torch.nn.functional.interpolate``.

    Returns:
        torch.tensor: The rescaled clip.

    Raises:
        ValueError: If ``target_size`` does not have exactly two elements.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    H, W = clip.size(-2), clip.size(-1)
    scale_ = target_size[0] / min(H, W)
    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)


class UCFCenterCropVideo:
    """
    First scale to the specified size in equal proportion to the short edge,
    then center cropping

    Args:
        size: Target ``(height, width)`` as a 2-element tuple or list, or a
            single value used for both dimensions.
        interpolation_mode (str): Interpolation mode used for the resize step.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # Lists are accepted as well: a list used to fall through to the
        # scalar branch and become a pair of lists, which broke the later
        # arithmetic on the target size.
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = tuple(size)
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return center_crop(clip_resize, self.size)

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"


def _is_tensor_video_clip(clip):
|
|
if not torch.is_tensor(clip):
|
|
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
|
|
|
if not clip.ndimension() == 4:
|
|
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
|
|
|
return True
|
|
|
|
|
|
def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float and divide value by 255.0.
    The dimension order is left unchanged (no permute is applied).

    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    Raises:
        TypeError: If ``clip`` is not a tensor or its dtype is not uint8.
        ValueError: If ``clip`` is not 4D.
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError(f"clip tensor should have data type uint8. Got {clip.dtype}")
    return clip.float() / 255.0


class ToTensorVideo:
    """
    Convert tensor data type from uint8 to float and divide value by 255.0.
    The dimension order of the clip is left unchanged.
    """

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        return to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__


class ResizeCrop:
    """Resize the input so it covers ``size``, then center-crop to exactly ``size``.

    Args:
        size: Target ``(height, width)``, or a single number for a square target.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        """
        Args:
            clip: A 4D float tensor clip whose last two dims are (H, W), or a
                PIL image.
        Returns:
            The resized and cropped clip (tensor) or image (PIL).
        """
        # resize_crop_to_fill unpacks ``.size`` as a PIL-style (W, H) tuple;
        # on a tensor ``.size`` is a method, so tensor clips (which the video
        # pipeline passes here) are handled separately.
        if torch.is_tensor(clip):
            return self._resize_crop_tensor(clip)
        return resize_crop_to_fill(clip, self.size)

    def _resize_crop_tensor(self, clip):
        """Tensor counterpart of ``resize_crop_to_fill`` using bilinear resizing."""
        if clip.ndimension() != 4:
            raise ValueError(f"clip should be 4D. Got {clip.dim()}D")
        h, w = clip.size(-2), clip.size(-1)
        th, tw = self.size
        rh, rw = th / h, tw / w
        if rh > rw:
            # Height is the binding dimension: match it, crop surplus width.
            sh, sw = th, round(w * rh)
            top, left = 0, int(round((sw - tw) / 2.0))
        else:
            # Width is the binding dimension: match it, crop surplus height.
            sh, sw = round(h * rw), tw
            top, left = int(round((sh - th) / 2.0)), 0
        resized = torch.nn.functional.interpolate(clip, size=(sh, sw), mode="bilinear", align_corners=False)
        return resized[..., top : top + th, left : left + tw]

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


def get_transforms_image(name="center", image_size=(256, 256)):
    """Build the preprocessing pipeline applied to a PIL image.

    Args:
        name: ``"center"`` (square center crop via ``center_crop_arr``),
            ``"resize_crop"`` (resize and crop to fill ``image_size``), or
            ``None`` for no transform.
        image_size: Target ``(height, width)``; must be square for ``"center"``.

    Returns:
        A ``transforms.Compose`` that crops, converts to a tensor and normalises
        each channel with mean 0.5 / std 0.5; ``None`` if ``name`` is ``None``.

    Raises:
        AssertionError: If ``name == "center"`` and ``image_size`` is not square.
        NotImplementedError: If ``name`` is not recognised.
    """
    if name is None:
        return None

    if name == "center":
        if image_size[0] != image_size[1]:
            # Explicit raise so the check also runs under ``python -O``.
            raise AssertionError("Image size must be square for center crop")
        crop_step = transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0]))
    elif name == "resize_crop":
        crop_step = transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size))
    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    return transforms.Compose(
        [
            crop_step,
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image (PIL.Image.Image): Input image.
        image_size (int): Side length of the square output.

    Returns:
        PIL.Image.Image: Center crop of ``image_size`` x ``image_size`` pixels.
    """
    # Halve with a box filter while the short side is at least twice the target.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)

    # Final bicubic resize so the short side equals image_size.
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])


def resize_crop_to_fill(pil_image, image_size):
    """Resize a PIL image so it covers ``image_size``, then center-crop the excess.

    The aspect ratio is preserved: the image is scaled (bicubic) by the larger
    of the two per-axis ratios, and the surplus along the other axis is cropped
    evenly from both sides.

    Args:
        pil_image (PIL.Image.Image): Input image.
        image_size (tuple): Target ``(height, width)``.

    Returns:
        PIL.Image.Image: Image of exactly ``image_size``.
    """
    w, h = pil_image.size  # PIL is (W, H)
    th, tw = image_size
    rh, rw = th / h, tw / w
    if rh > rw:
        # Height is the binding dimension: match it, crop surplus width.
        sh, sw = th, round(w * rh)
        image = pil_image.resize((sw, sh), Image.BICUBIC)
        i = 0
        j = int(round((sw - tw) / 2.0))
    else:
        # Width is the binding dimension: match it, crop surplus height.
        sh, sw = round(h * rw), tw
        image = pil_image.resize((sw, sh), Image.BICUBIC)
        i = int(round((sh - th) / 2.0))
        j = 0
    arr = np.array(image)
    # Internal invariant: the crop window lies inside the resized image.
    assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
    return Image.fromarray(arr[i : i + th, j : j + tw])


def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
    """Decode a video file and return it as a transformed (C, T, H, W) tensor.

    Args:
        path (str): Video file to read.
        transform: Callable applied to the (T, C, H, W) frames; when ``None``
            one is built with ``get_transforms_video``.
        transform_name (str): Transform name used when ``transform`` is ``None``.
        image_size (tuple): Target ``(height, width)`` for the built transform.

    Returns:
        torch.Tensor: The transformed clip permuted to (C, T, H, W).
    """
    # Audio frames and stream metadata are not needed.
    vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    if transform is None:
        transform = get_transforms_video(image_size=image_size, name=transform_name)
    video = transform(vframes)  # T C H W
    return video.permute(1, 0, 2, 3)


def read_from_path(path, image_size, transform_name="center"):
    """Load an image or video from a local path or URL as a (C, T, H, W) tensor.

    URLs are first downloaded with ``download_url``. The loader is chosen from
    the lower-cased file extension.

    Args:
        path (str): Local file path or URL.
        image_size (tuple): Target ``(height, width)``.
        transform_name (str): Name passed to the transform factory.

    Returns:
        torch.Tensor: The loaded and transformed sample.

    Raises:
        AssertionError: If the extension is neither a known video nor image type.
    """
    if is_url(path):
        path = download_url(path)
    ext = os.path.splitext(path)[-1].lower()
    if ext in VID_EXTENSIONS:
        return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
    if ext not in IMG_EXTENSIONS:
        # Explicit raise so the check also runs under ``python -O``.
        raise AssertionError(f"Unsupported file format: {ext}")
    return read_image_from_path(path, image_size=image_size, transform_name=transform_name)


def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
    """Load an image and return it as a (C, T, H, W) tensor of repeated frames.

    Args:
        path (str): Image file to read.
        transform: Callable applied to the PIL image; when ``None`` one is
            built with ``get_transforms_image``.
        transform_name (str): Transform name used when ``transform`` is ``None``.
        num_frames (int): Number of times the image is repeated along T.
        image_size (tuple): Target ``(height, width)`` for the built transform.

    Returns:
        torch.Tensor: The transformed image stacked ``num_frames`` times,
        permuted to (C, T, H, W).
    """
    image = pil_loader(path)
    if transform is None:
        transform = get_transforms_image(image_size=image_size, name=transform_name)
    image = transform(image)
    # (C, H, W) -> (T, C, H, W) -> (C, T, H, W)
    video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
    return video.permute(1, 0, 2, 3)


def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
    """Build the per-batch resolution conditioning tensors for a model family.

    Args:
        info_type: ``None``, ``"PixArtMS"``, ``"STDiT2"`` or ``"OpenSora"``.
        batch_size (int): Number of samples; every tensor is repeated this often.
        image_size (tuple): ``(height, width)``.
        num_frames (int): Frames per sample.
        fps: Frame rate; replaced by ``IMG_FPS`` when ``num_frames <= 1``.
        device: Device of the created tensors.
        dtype: Dtype of the created tensors.

    Returns:
        dict: ``{}`` for ``None``; ``ar`` (B, 1) and ``hw`` (B, 2) for
        ``"PixArtMS"``; ``height``, ``width``, ``num_frames``, ``ar`` and
        ``fps`` (each of shape (B,)) for ``"STDiT2"`` / ``"OpenSora"``.

    Raises:
        NotImplementedError: If ``info_type`` is not recognised.
    """
    if info_type is None:
        return {}

    height, width = image_size[0], image_size[1]

    if info_type == "PixArtMS":
        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
        ar = torch.tensor([[height / width]], device=device, dtype=dtype).repeat(batch_size, 1)
        return dict(ar=ar, hw=hw)

    if info_type in ["STDiT2", "OpenSora"]:

        def _per_sample(value):
            # One scalar repeated into a (batch_size,) tensor.
            return torch.tensor([value], device=device, dtype=dtype).repeat(batch_size)

        # Single-frame (image) samples use the fixed image FPS.
        effective_fps = fps if num_frames > 1 else IMG_FPS
        return dict(
            height=_per_sample(height),
            width=_per_sample(width),
            num_frames=_per_sample(num_frames),
            ar=_per_sample(height / width),
            fps=_per_sample(effective_fps),
        )

    raise NotImplementedError(f"Unknown multi-resolution info type: {info_type!r}")