Cleanup unused dependencies

This commit is contained in:
kijai 2024-09-21 16:37:54 +03:00
parent 33e67e0c98
commit 73fa4be48f

View File

@ -1,20 +1,10 @@
import os import os
import gc import gc
import imageio
import numpy as np import numpy as np
import torch import torch
import torchvision
import cv2
from einops import rearrange
from PIL import Image from PIL import Image
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os
import cv2
import numpy as np
import torch
from PIL import Image
def tensor2pil(image):
    """Convert a float image tensor (values in [0, 1]) to a PIL Image.

    The tensor is moved to CPU, scaled to [0, 255], clamped, and cast to
    uint8 before handing off to Pillow.
    """
    arr = image.cpu().numpy()
    arr = np.clip(255. * arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr)
@ -73,60 +63,6 @@ def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
height_slider = round(original_height * ratio) height_slider = round(original_height * ratio)
return height_slider, width_slider return height_slider, width_slider
def color_transfer(sc, dc):
    """
    Transfer color distribution from of sc, referred to dc.

    Args:
        sc (numpy.ndarray): input image to be transfered.
        dc (numpy.ndarray): reference image

    Returns:
        numpy.ndarray: Transferred color distribution on the sc.
    """

    def _channel_stats(img):
        # Per-channel mean and std-dev, rounded to 2 decimals and flattened
        # to shape (3,) so they broadcast over an (H, W, 3) image.
        mean, std = cv2.meanStdDev(img)
        return np.hstack(np.around(mean, 2)), np.hstack(np.around(std, 2))

    # Work in LAB space: lightness and color channels are decorrelated there,
    # which makes per-channel moment matching behave sensibly.
    src_lab = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
    ref_lab = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
    src_mean, src_std = _channel_stats(src_lab)
    ref_mean, ref_std = _channel_stats(ref_lab)

    # Match first and second moments of the source to the reference,
    # then clamp back into the valid 8-bit range.
    matched = (src_lab - src_mean) * (ref_std / src_std) + ref_mean
    matched = np.clip(matched, 0, 255)

    return cv2.cvtColor(cv2.convertScaleAbs(matched), cv2.COLOR_LAB2RGB)
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
    """Tile a batch of videos into a grid and write it to disk as mp4/gif.

    Args:
        videos: tensor laid out as (batch, channels, time, height, width).
        path: output file path; parent directories are created as needed.
        rescale: if True, map values from [-1, 1] to [0, 1] before writing.
        n_rows: images per row in each grid frame.
        fps: playback frame rate for the imageio backend.
        imageio_backend: if True use imageio (mp4 or gif by extension);
            otherwise save a GIF via Pillow.
        color_transfer_post_process: if True, match every frame's color
            distribution to the first frame to reduce flicker.
    """
    # Time-major layout so we can iterate frame by frame.
    frames_tensor = rearrange(videos, "b c t h w -> t b c h w")
    frames = []
    for frame in frames_tensor:
        grid = torchvision.utils.make_grid(frame, nrow=n_rows)
        # CHW -> HWC for PIL.
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        grid = (grid * 255).numpy().astype(np.uint8)
        frames.append(Image.fromarray(grid))

    if color_transfer_post_process:
        # Anchor all later frames' color statistics to the first frame.
        reference = np.uint8(frames[0])
        for idx, frame in enumerate(frames):
            if idx == 0:
                continue
            frames[idx] = Image.fromarray(color_transfer(np.uint8(frame), reference))

    os.makedirs(os.path.dirname(path), exist_ok=True)

    if imageio_backend:
        if path.endswith("mp4"):
            imageio.mimsave(path, frames, fps=fps)
        else:
            # GIF path: imageio takes per-frame duration in milliseconds.
            imageio.mimsave(path, frames, duration=(1000 * 1/fps))
    else:
        # Pillow fallback only writes GIFs; rewrite an mp4 target accordingly.
        if path.endswith("mp4"):
            path = path.replace('.mp4', '.gif')
        frames[0].save(path, format='GIF', append_images=frames, save_all=True, duration=100, loop=0)
def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size): def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
if validation_image_start is not None and validation_image_end is not None: if validation_image_start is not None and validation_image_end is not None:
if type(validation_image_start) is str and os.path.isfile(validation_image_start): if type(validation_image_start) is str and os.path.isfile(validation_image_start):
@ -224,18 +160,7 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide
return input_video, input_video_mask, clip_image return input_video, input_video_mask, clip_image
def get_video_to_video_latent(input_video_path, video_length, sample_size): def get_video_to_video_latent(input_video_path, video_length, sample_size):
if type(input_video_path) is str: input_video = input_video_path
cap = cv2.VideoCapture(input_video_path)
input_video = []
while True:
ret, frame = cap.read()
if not ret:
break
frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()
else:
input_video = input_video_path
input_video = torch.from_numpy(np.array(input_video))[:video_length] input_video = torch.from_numpy(np.array(input_video))[:video_length]
input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255