mirror of https://git.datalinker.icu/ali-vilab/TeaCache (synced 2026-01-24 11:54:28 +08:00)

update the version of videosys

commit 21ea48e71a (parent 66038c512f)
30
eval/teacache/README.md
Normal file
@@ -0,0 +1,30 @@
# Evaluation of TeaCache

We first generate videos from VBench's prompts.

We then calculate VBench, PSNR, LPIPS, and SSIM scores on the generated videos.

1. Generate videos

```
cd eval/teacache
python experiments/latte.py
python experiments/opensora.py
python experiments/open_sora_plan.py
```

2. Calculate the VBench score

```
# VBench is calculated independently
# get scores for all metrics
python vbench/run_vbench.py --video_path aaa --save_path bbb
# calculate the final score
python vbench/cal_vbench.py --score_dir bbb
```

3. Calculate the other metrics

```
# these metrics are calculated against the original model's outputs
# the gt videos come from the original model
# the generated videos are our method's results
python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
```
205
eval/teacache/common_metrics/batch_eval.py
Normal file
@@ -0,0 +1,205 @@
import argparse
import os

import imageio
import torch
import torchvision.transforms.functional as F
import tqdm
from calculate_lpips import calculate_lpips
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim


def load_video(video_path):
    """
    Load a video from the given path and convert it to a PyTorch tensor.
    """
    # Read the video using imageio
    reader = imageio.get_reader(video_path, "ffmpeg")

    # Extract frames and convert to a list of tensors
    frames = []
    for frame in reader:
        # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
        frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
        frames.append(frame_tensor)

    # Stack the list of tensors into a single tensor with shape (T, C, H, W)
    video_tensor = torch.stack(frames)

    return video_tensor


def resize_video(video, target_height, target_width):
    resized_frames = []
    for frame in video:
        resized_frame = F.resize(frame, [target_height, target_width])
        resized_frames.append(resized_frame)
    return torch.stack(resized_frames)


def resize_gt_video(gt_video, gen_video):
    gen_video_shape = gen_video.shape
    T_gen, _, H_gen, W_gen = gen_video_shape
    T_eval, _, H_eval, W_eval = gt_video.shape

    if T_eval < T_gen:
        raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")

    if H_eval < H_gen or W_eval < W_gen:
        # Resize the video while maintaining the aspect ratio
        resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
        resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
        gt_video = resize_video(gt_video, resize_height, resize_width)
        # Recalculate the dimensions
        T_eval, _, H_eval, W_eval = gt_video.shape

    # Center crop to the generated resolution
    start_h = (H_eval - H_gen) // 2
    start_w = (W_eval - W_gen) // 2
    cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]

    return cropped_video


def get_video_ids(gt_video_dirs, gen_video_dirs):
    video_ids = []
    for f in os.listdir(gt_video_dirs[0]):
        if f.endswith(".mp4"):
            video_ids.append(f.replace(".mp4", ""))
    video_ids.sort()

    # every directory must contain exactly the same set of video IDs
    for video_dir in gt_video_dirs + gen_video_dirs:
        tmp_video_ids = []
        for f in os.listdir(video_dir):
            if f.endswith(".mp4"):
                tmp_video_ids.append(f.replace(".mp4", ""))
        tmp_video_ids.sort()
        if tmp_video_ids != video_ids:
            raise ValueError(f"Video IDs in {video_dir} are different.")
    return video_ids


def get_videos(video_ids, gt_video_dirs, gen_video_dirs):
    gt_videos = {}
    generated_videos = {}

    for gt_video_dir in gt_video_dirs:
        tmp_gt_videos_tensor = []
        for video_id in video_ids:
            gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4"))
            tmp_gt_videos_tensor.append(gt_video)
        gt_videos[gt_video_dir] = tmp_gt_videos_tensor

    for generated_video_dir in gen_video_dirs:
        tmp_generated_videos_tensor = []
        for video_id in video_ids:
            generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4"))
            tmp_generated_videos_tensor.append(generated_video)
        generated_videos[generated_video_dir] = tmp_generated_videos_tensor

    return gt_videos, generated_videos


def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs):
    out_str = ""

    for gt_video_dir in gt_video_dirs:
        for generated_video_dir in gen_video_dirs:
            if gt_video_dir == generated_video_dir:
                continue
            lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len(
                lpips_results[gt_video_dir][generated_video_dir]
            )
            psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len(
                psnr_results[gt_video_dir][generated_video_dir]
            )
            ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len(
                ssim_results[gt_video_dir][generated_video_dir]
            )
            out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}"

    return out_str


def main(args):
    device = "cuda"
    gt_video_dirs = args.gt_video_dirs
    gen_video_dirs = args.gen_video_dirs

    video_ids = get_video_ids(gt_video_dirs, gen_video_dirs)
    print(f"Found {len(video_ids)} videos")

    prompt_interval = 1
    batch_size = 8
    calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True

    lpips_results = {}
    psnr_results = {}
    ssim_results = {}
    for gt_video_dir in gt_video_dirs:
        lpips_results[gt_video_dir] = {}
        psnr_results[gt_video_dir] = {}
        ssim_results[gt_video_dir] = {}
        for generated_video_dir in gen_video_dirs:
            lpips_results[gt_video_dir][generated_video_dir] = []
            psnr_results[gt_video_dir][generated_video_dir] = []
            ssim_results[gt_video_dir][generated_video_dir] = []

    total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)

    for idx in tqdm.tqdm(range(total_len)):
        video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size]
        gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs)

        for gt_video_dir, gt_videos_tensor in gt_videos.items():
            for generated_video_dir, generated_videos_tensor in generated_videos.items():
                if gt_video_dir == generated_video_dir:
                    continue

                if not isinstance(gt_videos_tensor, torch.Tensor):
                    # align each gt video to the generated resolution, then normalize to [0, 1]
                    for i in range(len(gt_videos_tensor)):
                        gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0])
                    gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()

                generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()

                if calculate_lpips_flag:
                    result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
                    result = result["value"].values()
                    result = float(sum(result) / len(result))
                    lpips_results[gt_video_dir][generated_video_dir].append(result)

                if calculate_psnr_flag:
                    result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
                    result = result["value"].values()
                    result = float(sum(result) / len(result))
                    psnr_results[gt_video_dir][generated_video_dir].append(result)

                if calculate_ssim_flag:
                    result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
                    result = result["value"].values()
                    result = float(sum(result) / len(result))
                    ssim_results[gt_video_dir][generated_video_dir].append(result)

        if (idx + 1) % prompt_interval == 0:
            out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
            print(f"Processed {idx + 1} / {total_len} batches. {out_str}")

    out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)

    # save the final results
    with open("./batch_eval.txt", "w+") as f:
        f.write(out_str)

    print(f"Processed all videos. {out_str}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gt_video_dirs", type=str, nargs="+")
    parser.add_argument("--gen_video_dirs", type=str, nargs="+")

    args = parser.parse_args()

    main(args)
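As a usage reference, here is a minimal sketch of driving the script above programmatically rather than via the CLI. The paths are hypothetical placeholders; the sample directories mirror the `./samples/...` outputs produced by the experiment scripts below.

```
# Minimal sketch: call batch_eval's main() directly (run from eval/teacache/common_metrics).
# Directory paths are placeholders; substitute your own sample directories.
import argparse
from batch_eval import main

args = argparse.Namespace(
    gt_video_dirs=["../samples/latte_base"],
    gen_video_dirs=["../samples/latte_teacache_slow", "../samples/latte_teacache_fast"],
)
main(args)  # writes averaged LPIPS/PSNR/SSIM to ./batch_eval.txt
```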
@@ -5,12 +5,7 @@ from einops import rearrange, repeat
 from torch import nn
 import numpy as np
 from typing import Any, Dict, Optional, Tuple
-from videosys.core.parallel_mgr import (
-    enable_sequence_parallel,
-    get_cfg_parallel_size,
-    get_data_parallel_group,
-    get_sequence_parallel_group,
-)
 from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence

 def teacache_forward(
     self,
@@ -67,7 +62,7 @@ def teacache_forward(
     """

     # 0. Split batch for data parallelism
-    if get_cfg_parallel_size() > 1:
+    if self.parallel_manager.cp_size > 1:
         (
             hidden_states,
             timestep,
@@ -77,7 +72,7 @@ def teacache_forward(
             attention_mask,
             encoder_attention_mask,
         ) = batch_func(
-            partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
             hidden_states,
             timestep,
             encoder_hidden_states,
@@ -193,14 +188,14 @@ def teacache_forward(
         if not should_calc:
             hidden_states += self.previous_residual
         else:
-            if enable_sequence_parallel():
-                set_temporal_pad(frame + use_image_num)
-                set_spatial_pad(num_patches)
+            if self.parallel_manager.sp_size > 1:
+                set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+                set_pad("spatial", num_patches, self.parallel_manager.sp_group)
                 hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
                 encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
                 timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
                 temp_pos_embed = split_sequence(
-                    self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+                    self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
                 )
             else:
                 temp_pos_embed = self.temp_pos_embed
@@ -323,14 +318,14 @@ def teacache_forward(
             ).contiguous()
             self.previous_residual = hidden_states - hidden_states_origin
     else:
-        if enable_sequence_parallel():
-            set_temporal_pad(frame + use_image_num)
-            set_spatial_pad(num_patches)
+        if self.parallel_manager.sp_size > 1:
+            set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+            set_pad("spatial", num_patches, self.parallel_manager.sp_group)
             hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
             encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
             timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
             temp_pos_embed = split_sequence(
-                self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+                self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
             )
         else:
             temp_pos_embed = self.temp_pos_embed
@@ -451,7 +446,8 @@ def teacache_forward(
         hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
     ).contiguous()

-    if enable_sequence_parallel():
+
+    if self.parallel_manager.sp_size > 1:
         if self.enable_teacache:
             if should_calc:
                 hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
@@ -488,8 +484,8 @@ def teacache_forward(
     output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()

     # 3. Gather batch for data parallelism
-    if get_cfg_parallel_size() > 1:
-        output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+    if self.parallel_manager.cp_size > 1:
+        output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

     if not return_dict:
         return (output,)
@@ -497,10 +493,6 @@ def teacache_forward(
     return Transformer3DModelOutput(sample=output)


-def eval_base(prompt_list):
-    config = LatteConfig()
-    engine = VideoSysEngine(config)
-    generate_func(engine, prompt_list, "./samples/latte_base", loop=5)
-
 def eval_teacache_slow(prompt_list):
     config = LatteConfig()
@@ -523,10 +515,18 @@ def eval_teacache_fast(prompt_list):
     engine.driver_worker.transformer.previous_residual = None
     engine.driver_worker.transformer.__class__.forward = teacache_forward
     generate_func(engine, prompt_list, "./samples/latte_teacache_fast", loop=5)


+def eval_base(prompt_list):
+    config = LatteConfig()
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/latte_base", loop=5)
+
+
 if __name__ == "__main__":
     prompt_list = read_prompt_list("vbench/VBench_full_info.json")
-    # eval_base(prompt_list)
+    eval_base(prompt_list)
     eval_teacache_slow(prompt_list)
-    # eval_teacache_fast(prompt_list)
+    eval_teacache_fast(prompt_list)
@@ -3,23 +3,23 @@ from videosys import OpenSoraConfig, VideoSysEngine
 import torch
 from einops import rearrange
 from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint
-from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_temporal_pad, set_spatial_pad, set_temporal_pad, split_sequence
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
 import numpy as np
 from videosys.utils.utils import batch_func
-from videosys.core.parallel_mgr import (
-    enable_sequence_parallel,
-    get_cfg_parallel_size,
-    get_data_parallel_group,
-    get_sequence_parallel_group,
-)
 from functools import partial

 def teacache_forward(
     self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
 ):
     # === Split batch ===
-    if get_cfg_parallel_size() > 1:
+    if self.parallel_manager.cp_size > 1:
         x, timestep, y, x_mask, mask = batch_func(
-            partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
+            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
+            x,
+            timestep,
+            y,
+            x_mask,
+            mask,
         )

     dtype = self.x_embedder.proj.weight.dtype
@@ -60,7 +60,7 @@ def teacache_forward(
     # === get x embed ===
     x = self.x_embedder(x)  # [B, N, C]
     x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
     x = x + pos_emb

     if self.enable_teacache:
         inp = x.clone()
@@ -84,7 +84,6 @@ def teacache_forward(
                 self.accumulated_rel_l1_distance = 0
             self.previous_modulated_input = modulated_inp

-
     # === blocks ===
     if self.enable_teacache:
         if not should_calc:
@@ -92,16 +91,15 @@ def teacache_forward(
             x += self.previous_residual
         else:
             # shard over the sequence dim if sp is enabled
-            if enable_sequence_parallel():
-                set_temporal_pad(T)
-                set_spatial_pad(S)
-                x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+            if self.parallel_manager.sp_size > 1:
+                set_pad("temporal", T, self.parallel_manager.sp_group)
+                set_pad("spatial", S, self.parallel_manager.sp_group)
+                x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
                 T = x.shape[1]
                 x_mask_org = x_mask
                 x_mask = split_sequence(
-                    x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+                    x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
                 )

             x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
             origin_x = x.clone().detach()
             for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
@@ -135,16 +133,17 @@ def teacache_forward(
             self.previous_residual = x - origin_x
     else:
         # shard over the sequence dim if sp is enabled
-        if enable_sequence_parallel():
-            set_temporal_pad(T)
-            set_spatial_pad(S)
-            x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+        if self.parallel_manager.sp_size > 1:
+            set_pad("temporal", T, self.parallel_manager.sp_group)
+            set_pad("spatial", S, self.parallel_manager.sp_group)
+            x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
            T = x.shape[1]
            x_mask_org = x_mask
            x_mask = split_sequence(
-                x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+                x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
             )
         x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
+
         for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
             x = auto_grad_checkpoint(
                 spatial_block,
@@ -174,25 +173,23 @@ def teacache_forward(
                 all_timesteps=all_timesteps,
             )

-    if enable_sequence_parallel():
+    if self.parallel_manager.sp_size > 1:
         if self.enable_teacache:
             if should_calc:
                 x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
                 self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S)
-                x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
-                self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
+                x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
+                self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
                 T, S = x.shape[1], x.shape[2]
                 x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
                 self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S)
                 x_mask = x_mask_org
         else:
             x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
-            x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
+            x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
             T, S = x.shape[1], x.shape[2]
             x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
             x_mask = x_mask_org

-
     # === final layer ===
     x = self.final_layer(x, t, x_mask, t0, T, S)
     x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
@@ -201,12 +198,11 @@ def teacache_forward(
     x = x.to(torch.float32)

     # === Gather Output ===
-    if get_cfg_parallel_size() > 1:
-        x = gather_sequence(x, get_data_parallel_group(), dim=0)
+    if self.parallel_manager.cp_size > 1:
+        x = gather_sequence(x, self.parallel_manager.cp_group, dim=0)

     return x

-
 def eval_base(prompt_list):
     config = OpenSoraConfig()
     engine = VideoSysEngine(config)
@@ -235,9 +231,9 @@ def eval_teacache_fast(prompt_list):
     generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5)


 if __name__ == "__main__":
     prompt_list = read_prompt_list("vbench/VBench_full_info.json")
-    # eval_base(prompt_list)
+    eval_base(prompt_list)
     eval_teacache_slow(prompt_list)
-    # eval_teacache_fast(prompt_list)
+    eval_teacache_fast(prompt_list)
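The hunks above show the TeaCache bookkeeping carried on the transformer (`accumulated_rel_l1_distance`, `previous_modulated_input`, `previous_residual`). A minimal, self-contained sketch of that reuse rule follows, assuming a simple relative-L1 estimator; this commit does not show TeaCache's exact estimator or its rescaling, so `rel_l1_distance` is an assumption for illustration.

```
import torch

def rel_l1_distance(cur: torch.Tensor, prev: torch.Tensor) -> float:
    # assumed estimator: mean relative L1 change between consecutive modulated inputs
    return ((cur - prev).abs().mean() / prev.abs().mean()).item()

class TeaCacheState:
    def __init__(self):
        self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = None
        self.previous_residual = None  # reused via x += previous_residual when skipping

def should_calc(state: TeaCacheState, modulated_inp: torch.Tensor, rel_l1_thresh: float) -> bool:
    # the first step always computes; afterwards the relative change accumulates
    if state.previous_modulated_input is None:
        calc = True
    else:
        state.accumulated_rel_l1_distance += rel_l1_distance(modulated_inp, state.previous_modulated_input)
        calc = state.accumulated_rel_l1_distance >= rel_l1_thresh
    if calc:
        state.accumulated_rel_l1_distance = 0.0  # mirrors the reset in the hunk above
    state.previous_modulated_input = modulated_inp
    return calc
```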
@@ -5,12 +5,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 import numpy as np
 from typing import Any, Dict, Optional, Tuple
-from videosys.core.parallel_mgr import (
-    enable_sequence_parallel,
-    get_cfg_parallel_group,
-    get_cfg_parallel_size,
-    get_sequence_parallel_group,
-)
 from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence

 def teacache_forward(
     self,
@@ -66,7 +61,7 @@ def teacache_forward(
         `tuple` where the first element is the sample tensor.
     """
     # 0. Split batch
-    if get_cfg_parallel_size() > 1:
+    if self.parallel_manager.cp_size > 1:
         (
             hidden_states,
             timestep,
@@ -75,7 +70,7 @@ def teacache_forward(
             attention_mask,
             encoder_attention_mask,
         ) = batch_func(
-            partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
             hidden_states,
             timestep,
             encoder_hidden_states,
@@ -204,20 +199,20 @@ def teacache_forward(
         if not should_calc:
             hidden_states += self.previous_residual
         else:
-            if enable_sequence_parallel():
-                set_temporal_pad(frame + use_image_num)
-                set_spatial_pad(num_patches)
+            if self.parallel_manager.sp_size > 1:
+                set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+                set_pad("spatial", num_patches, self.parallel_manager.sp_group)
                 hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
                 encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
                 timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
                 attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
                 attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
                 temp_pos_embed = split_sequence(
-                    self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+                    self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
                 )
             else:
                 temp_pos_embed = self.temp_pos_embed
-            ori_hidden_states = hidden_states.clone()
+            ori_hidden_states = hidden_states.clone().detach()
             for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
                 if self.training and self.gradient_checkpointing:
                     hidden_states = torch.utils.checkpoint.checkpoint(
@@ -359,20 +354,20 @@ def teacache_forward(
             ).contiguous()
             self.previous_residual = hidden_states - ori_hidden_states
     else:
-        if enable_sequence_parallel():
-            set_temporal_pad(frame + use_image_num)
-            set_spatial_pad(num_patches)
+        if self.parallel_manager.sp_size > 1:
+            set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+            set_pad("spatial", num_patches, self.parallel_manager.sp_group)
             hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
+            self.previous_residual = self.split_from_second_dim(self.previous_residual, input_batch_size) if self.previous_residual is not None else None
             encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
             timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
             attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
             attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
             temp_pos_embed = split_sequence(
-                self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+                self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
             )
         else:
             temp_pos_embed = self.temp_pos_embed

     for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
         if self.training and self.gradient_checkpointing:
             hidden_states = torch.utils.checkpoint.checkpoint(
@@ -512,8 +507,8 @@ def teacache_forward(
     hidden_states = rearrange(
         hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
     ).contiguous()

-    if enable_sequence_parallel():
+    if self.parallel_manager.sp_size > 1:
         if self.enable_teacache:
             if should_calc:
                 hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
@@ -550,8 +545,8 @@ def teacache_forward(
     output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()

     # 3. Gather batch for data parallelism
-    if get_cfg_parallel_size() > 1:
-        output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+    if self.parallel_manager.cp_size > 1:
+        output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

     if not return_dict:
         return (output,)
@@ -559,14 +554,8 @@ def teacache_forward(
     return Transformer3DModelOutput(sample=output)


-def eval_base(prompt_list):
-    config = OpenSoraPlanConfig()
-    engine = VideoSysEngine(config)
-    generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
-
 def eval_teacache_slow(prompt_list):
-    config = OpenSoraPlanConfig()
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
     engine = VideoSysEngine(config)
     engine.driver_worker.transformer.enable_teacache = True
     engine.driver_worker.transformer.rel_l1_thresh = 0.1
@@ -577,7 +566,7 @@ def eval_teacache_slow(prompt_list):
     generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5)

 def eval_teacache_fast(prompt_list):
-    config = OpenSoraPlanConfig()
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
     engine = VideoSysEngine(config)
     engine.driver_worker.transformer.enable_teacache = True
     engine.driver_worker.transformer.rel_l1_thresh = 0.2
@@ -587,8 +576,16 @@ def eval_teacache_fast(prompt_list):
     engine.driver_worker.transformer.__class__.forward = teacache_forward
     generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5)


+def eval_base(prompt_list):
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
+
+
 if __name__ == "__main__":
     prompt_list = read_prompt_list("vbench/VBench_full_info.json")
-    # eval_base(prompt_list)
+    eval_base(prompt_list)
     eval_teacache_slow(prompt_list)
-    # eval_teacache_fast(prompt_list)
+    eval_teacache_fast(prompt_list)
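The `slow` and `fast` presets above differ only in the reuse threshold: a larger `rel_l1_thresh` lets TeaCache skip more steps at some cost in fidelity. A hypothetical helper just to make that pairing explicit:

```
# Hypothetical helper; the 0.1/0.2 values come from eval_teacache_slow/eval_teacache_fast above.
def rel_l1_threshold(mode: str) -> float:
    return {"slow": 0.1, "fast": 0.2}[mode]
```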
@@ -1,5 +1,7 @@
 accelerate>0.17.0
 bs4
 click
-colossalai
+colossalai==0.4.0
+diffusers==0.30.0
 einops
 fabric
@@ -16,7 +18,9 @@ pydantic
 ray
 rich
 safetensors
+sentencepiece
 timm
 torch>=1.13
 tqdm
-transformers
+peft==0.13.2
+transformers==4.39.3
38
setup.py
@@ -1,6 +1,9 @@
 from typing import List

 from setuptools import find_packages, setup
+from setuptools.command.develop import develop
+from setuptools.command.egg_info import egg_info
+from setuptools.command.install import install


 def fetch_requirements(path) -> List[str]:
@@ -14,7 +17,9 @@ def fetch_requirements(path) -> List[str]:
         The lines in the requirements file.
     """
     with open(path, "r") as fd:
-        return [r.strip() for r in fd.readlines()]
+        requirements = [r.strip() for r in fd.readlines()]
+        # requirements.remove("colossalai")
+        return requirements


 def fetch_readme() -> str:
@@ -28,6 +33,28 @@ def fetch_readme() -> str:
         return f.read()


+def custom_install():
+    return ["pip", "install", "colossalai", "--no-deps"]
+
+
+class CustomInstallCommand(install):
+    def run(self):
+        install.run(self)
+        self.spawn(custom_install())
+
+
+class CustomDevelopCommand(develop):
+    def run(self):
+        develop.run(self)
+        self.spawn(custom_install())
+
+
+class CustomEggInfoCommand(egg_info):
+    def run(self):
+        egg_info.run(self)
+        self.spawn(custom_install())
+
+
 setup(
     name="videosys",
     version="2.0.0",
@@ -39,12 +66,17 @@ setup(
             "*.egg-info",
         )
-    )
+    ),
-    description="VideoSys",
+    description="TeaCache",
     long_description=fetch_readme(),
     long_description_content_type="text/markdown",
     license="Apache Software License 2.0",
     install_requires=fetch_requirements("requirements.txt"),
-    python_requires=">=3.6",
+    python_requires=">=3.7",
+    # cmdclass={
+    #     "install": CustomInstallCommand,
+    #     "develop": CustomDevelopCommand,
+    #     "egg_info": CustomEggInfoCommand,
+    # },
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: Apache Software License",
@@ -3,13 +3,20 @@ from .core.parallel_mgr import initialize
 from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
 from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
 from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
-from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
+from .pipelines.open_sora_plan import (
+    OpenSoraPlanConfig,
+    OpenSoraPlanPipeline,
+    OpenSoraPlanV110PABConfig,
+    OpenSoraPlanV120PABConfig,
+)
+from .pipelines.vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline

 __all__ = [
     "initialize",
     "VideoSysEngine",
     "LattePipeline", "LatteConfig", "LattePABConfig",
-    "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
+    "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig",
     "OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
-    "CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"
+    "CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig",
+    "VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"
 ]  # fmt: skip
@@ -7,8 +7,6 @@ from einops import rearrange
 from torch import Tensor
 from torch.distributed import ProcessGroup

-from videosys.core.parallel_mgr import get_sequence_parallel_size
-
 # ======================================================
 # Model
 # ======================================================
@@ -369,30 +367,18 @@ def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
 # Pad
 # ==============================

-SPTIAL_PAD = 0
-TEMPORAL_PAD = 0
+PAD_DICT = {}


-def set_spatial_pad(dim_size: int):
-    sp_size = get_sequence_parallel_size()
+def set_pad(name: str, dim_size: int, parallel_group: dist.ProcessGroup):
+    sp_size = dist.get_world_size(parallel_group)
     pad = (sp_size - (dim_size % sp_size)) % sp_size
-    global SPTIAL_PAD
-    SPTIAL_PAD = pad
+    global PAD_DICT
+    PAD_DICT[name] = pad


-def get_spatial_pad() -> int:
-    return SPTIAL_PAD
-
-
-def set_temporal_pad(dim_size: int):
-    sp_size = get_sequence_parallel_size()
-    pad = (sp_size - (dim_size % sp_size)) % sp_size
-    global TEMPORAL_PAD
-    TEMPORAL_PAD = pad
-
-
-def get_temporal_pad() -> int:
-    return TEMPORAL_PAD
+def get_pad(name) -> int:
+    return PAD_DICT[name]


 def all_to_all_with_pad(
@@ -418,3 +404,17 @@ def all_to_all_with_pad(
         input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)

     return input_
+
+
+def split_from_second_dim(x, batch_size, parallel_group):
+    x = x.view(batch_size, -1, *x.shape[1:])
+    x = split_sequence(x, parallel_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
+    x = x.reshape(-1, *x.shape[2:])
+    return x
+
+
+def gather_from_second_dim(x, batch_size, parallel_group):
+    x = x.view(batch_size, -1, *x.shape[1:])
+    x = gather_sequence(x, parallel_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
+    x = x.reshape(-1, *x.shape[2:])
+    return x
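The padding amount computed by the new `set_pad` rounds each sharded dimension up to a multiple of the sequence-parallel world size. A minimal sketch of that arithmetic, with the process group replaced by a plain integer so it runs without `torch.distributed` (illustrative only):

```
def compute_pad(dim_size: int, sp_size: int) -> int:
    # padding needed so dim_size becomes divisible by sp_size
    return (sp_size - (dim_size % sp_size)) % sp_size

assert compute_pad(17, 4) == 3  # 17 + 3 = 20 splits evenly across 4 ranks
assert compute_pad(16, 4) == 0  # already divisible, no padding needed
```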
@@ -66,7 +66,7 @@ class VideoSysEngine:

     # TODO: add more options here for pipeline, or wrap all options into config
     def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
-        videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method, seed=42)
+        videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method)

         pipeline = pipeline_cls(self.config)
         return pipeline
@@ -6,7 +6,6 @@ PAB_MANAGER = None
 class PABConfig:
     def __init__(
         self,
-        steps: int,
         cross_broadcast: bool = False,
         cross_threshold: list = None,
         cross_range: int = None,
@@ -20,7 +19,7 @@ class PABConfig:
         mlp_spatial_broadcast_config: dict = None,
         mlp_temporal_broadcast_config: dict = None,
     ):
-        self.steps = steps
+        self.steps = None

         self.cross_broadcast = cross_broadcast
         self.cross_threshold = cross_threshold
@@ -45,7 +44,7 @@ class PABManager:
     def __init__(self, config: PABConfig):
         self.config: PABConfig = config

-        init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
+        init_prompt = f"Init Pyramid Attention Broadcast."
         init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
         init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
         init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
@@ -78,7 +77,7 @@ class PABManager:
         count = (count + 1) % self.config.steps
         return flag, count

-    def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
+    def if_broadcast_spatial(self, timestep: int, count: int):
         if (
             self.config.spatial_broadcast
             and (timestep is not None)
@@ -213,10 +212,10 @@ def if_broadcast_temporal(timestep: int, count: int):
     return PAB_MANAGER.if_broadcast_temporal(timestep, count)


-def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
+def if_broadcast_spatial(timestep: int, count: int):
     if not enable_pab():
         return False, count
-    return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
+    return PAB_MANAGER.if_broadcast_spatial(timestep, count)


 def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
@@ -1,14 +1,9 @@
-from typing import Optional
-
 import torch
 import torch.distributed as dist
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from torch.distributed import ProcessGroup

 from videosys.utils.logging import init_dist_logger, logger
-from videosys.utils.utils import set_seed
-
-PARALLEL_MANAGER = None


 class ParallelManager(ProcessGroupMesh):
@@ -21,71 +16,28 @@ class ParallelManager(ProcessGroupMesh):
         self.dp_rank = dist.get_rank(self.dp_group)

         self.cp_size = cp_size
-        self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
-        self.cp_rank = dist.get_rank(self.cp_group)
+        if cp_size > 1:
+            self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
+            self.cp_rank = dist.get_rank(self.cp_group)
+        else:
+            self.cp_group = None
+            self.cp_rank = None

         self.sp_size = sp_size
-        self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
-        self.sp_rank = dist.get_rank(self.sp_group)
-        self.enable_sp = sp_size > 1
+        if sp_size > 1:
+            self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
+            self.sp_rank = dist.get_rank(self.sp_group)
+        else:
+            self.sp_group = None
+            self.sp_rank = None

         logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")


-def set_parallel_manager(dp_size, cp_size, sp_size):
-    global PARALLEL_MANAGER
-    PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)
-
-
-def get_data_parallel_group():
-    return PARALLEL_MANAGER.dp_group
-
-
-def get_data_parallel_size():
-    return PARALLEL_MANAGER.dp_size
-
-
-def get_data_parallel_rank():
-    return PARALLEL_MANAGER.dp_rank
-
-
-def get_sequence_parallel_group():
-    return PARALLEL_MANAGER.sp_group
-
-
-def get_sequence_parallel_size():
-    return PARALLEL_MANAGER.sp_size
-
-
-def get_sequence_parallel_rank():
-    return PARALLEL_MANAGER.sp_rank
-
-
-def get_cfg_parallel_group():
-    return PARALLEL_MANAGER.cp_group
-
-
-def get_cfg_parallel_size():
-    return PARALLEL_MANAGER.cp_size
-
-
-def enable_sequence_parallel():
-    if PARALLEL_MANAGER is None:
-        return False
-    return PARALLEL_MANAGER.enable_sp
-
-
-def get_parallel_manager():
-    return PARALLEL_MANAGER
-
-
 def initialize(
     rank=0,
     world_size=1,
     init_method=None,
-    seed: Optional[int] = None,
-    sp_size: Optional[int] = None,
-    enable_cp: bool = False,
 ):
     if not dist.is_initialized():
         try:
@@ -97,24 +49,3 @@ def initialize(
     init_dist_logger()
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
-
-    # init sequence parallel
-    if sp_size is None:
-        sp_size = dist.get_world_size()
-        dp_size = 1
-    else:
-        assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size"
-        dp_size = dist.get_world_size() // sp_size
-
-    # update cfg parallel
-    # NOTE: enable cp parallel will be slower. disable it for now.
-    if False and enable_cp and sp_size % 2 == 0:
-        sp_size = sp_size // 2
-        cp_size = 2
-    else:
-        cp_size = 1
-
-    set_parallel_manager(dp_size, cp_size, sp_size)
-
-    if seed is not None:
-        set_seed(seed + get_data_parallel_rank())
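After this change, call sites no longer go through module-level getters: each model reads `self.parallel_manager` directly, and a group is only created when its size is greater than 1. A minimal sketch of the attribute contract the new call sites assume (illustrative; real groups come from `torch.distributed`):

```
# Illustrative stand-in for ParallelManager's attribute layout after this commit:
# *_group is None exactly when the corresponding *_size is 1.
class ParallelManagerSketch:
    def __init__(self, dp_size=1, cp_size=1, sp_size=1):
        self.dp_size, self.cp_size, self.sp_size = dp_size, cp_size, sp_size
        self.cp_group = object() if cp_size > 1 else None  # placeholder for a ProcessGroup
        self.sp_group = object() if sp_size > 1 else None

mgr = ParallelManagerSketch(sp_size=2)
assert mgr.sp_size > 1 and mgr.sp_group is not None  # sequence parallelism enabled
assert mgr.cp_size == 1 and mgr.cp_group is None     # cfg parallelism disabled
```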
@@ -13,9 +13,10 @@ class VideoSysPipeline(DiffusionPipeline):

     @staticmethod
     def set_eval_and_device(device: torch.device, *modules):
-        for module in modules:
-            module.eval()
-            module.to(device)
+        modules = list(modules)
+        for i in range(len(modules)):
+            modules[i] = modules[i].eval()
+            modules[i] = modules[i].to(device)

     @abstractmethod
     def generate(self, *args, **kwargs):
@@ -694,7 +694,10 @@ class VideoAutoencoderPipeline(PreTrainedModel):
         else:
             return x

-    def forward(self, x):
+    def forward(self, x, decode_only=False, **kwargs):
+        if decode_only:
+            return self.decode(x, **kwargs)
+
         assert self.cal_loss, "This method is only available when cal_loss is True"
         z, posterior, x_z = self.encode(x)
         x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
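With the new `decode_only` flag, callers can run just the decoder on latents without tripping the `cal_loss` assertion. A hedged usage sketch, where `vae`, `latents`, and `num_frames` are placeholders:

```
# hypothetical call pattern enabled by the decode_only path above;
# vae is a VideoAutoencoderPipeline instance
# video = vae(latents, decode_only=True, num_frames=num_frames)
```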
1643
videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v110.py
Normal file
File diff suppressed because it is too large
1139
videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v120.py
Normal file
File diff suppressed because it is too large
@@ -1,12 +1,21 @@
+import inspect
 from dataclasses import dataclass
-from typing import Iterable, List, Tuple
+from typing import Iterable, List, Optional, Tuple

 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from diffusers.models.attention import Attention
+from diffusers.models.attention_processor import AttnProcessor
 from einops import rearrange
 from torch import nn
+from torch.amp import autocast

-from videosys.models.modules.normalization import LlamaRMSNorm
+from videosys.core.comm import all_to_all_with_pad, get_pad, set_pad
+from videosys.core.pab_mgr import enable_pab, if_broadcast_cross, if_broadcast_spatial, if_broadcast_temporal
+from videosys.models.modules.normalization import LlamaRMSNorm, VchitectSpatialNorm
+from videosys.utils.logging import logger


 class OpenSoraAttention(nn.Module):
@@ -203,3 +212,634 @@ class _SeqLenInfo:
             seqstart=seqstart,
             seqstart_py=seqstart_py,
         )
|
||||
|
||||
|
||||
class VchitectAttention(nn.Module):
|
||||
r"""
|
||||
A cross attention layer.
|
||||
|
||||
Parameters:
|
||||
query_dim (`int`):
|
||||
The number of channels in the query.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
||||
heads (`int`, *optional*, defaults to 8):
|
||||
The number of heads to use for multi-head attention.
|
||||
dim_head (`int`, *optional*, defaults to 64):
|
||||
The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability to use.
|
||||
bias (`bool`, *optional*, defaults to False):
|
||||
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
||||
upcast_attention (`bool`, *optional*, defaults to False):
|
||||
Set to `True` to upcast the attention computation to `float32`.
|
||||
upcast_softmax (`bool`, *optional*, defaults to False):
|
||||
Set to `True` to upcast the softmax computation to `float32`.
|
||||
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
||||
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
||||
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the group norm in the cross attention.
|
||||
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
||||
norm_num_groups (`int`, *optional*, defaults to `None`):
|
||||
The number of groups to use for the group norm in the attention.
|
||||
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the spatial normalization.
|
||||
out_bias (`bool`, *optional*, defaults to `True`):
|
||||
Set to `True` to use a bias in the output linear layer.
|
||||
scale_qk (`bool`, *optional*, defaults to `True`):
|
||||
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
||||
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
||||
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
||||
`added_kv_proj_dim` is not `None`.
|
||||
eps (`float`, *optional*, defaults to 1e-5):
|
||||
An additional value added to the denominator in group normalization that is used for numerical stability.
|
||||
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
||||
A factor to rescale the output by dividing it with this value.
|
||||
residual_connection (`bool`, *optional*, defaults to `False`):
|
||||
Set to `True` to add the residual connection to the output.
|
||||
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
||||
Set to `True` if the attention block is loaded from a deprecated state dict.
|
||||
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
||||
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
||||
`AttnProcessor` otherwise.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
cross_attention_norm: Optional[str] = None,
|
||||
cross_attention_norm_num_groups: int = 32,
|
||||
qk_norm: Optional[str] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
spatial_norm_dim: Optional[int] = None,
|
||||
out_bias: bool = True,
|
||||
scale_qk: bool = True,
|
||||
only_cross_attention: bool = False,
|
||||
eps: float = 1e-5,
|
||||
rescale_output_factor: float = 1.0,
|
||||
residual_connection: bool = False,
|
||||
_from_deprecated_attn_block: bool = False,
|
||||
processor: Optional[AttnProcessor] = None,
|
||||
out_dim: int = None,
|
||||
context_pre_only: bool = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
self.residual_connection = residual_connection
|
||||
self.dropout = dropout
|
||||
self.fused_projections = False
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
# we make use of this private variable to know whether this class is loaded
|
||||
# with an deprecated state dict so that we can convert it on the fly
|
||||
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
self.sliceable_head_dim = heads
|
||||
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
||||
raise ValueError(
|
||||
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
||||
)
|
||||
|
||||
if norm_num_groups is not None:
|
||||
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
||||
else:
|
||||
self.group_norm = None
|
||||
|
||||
if spatial_norm_dim is not None:
|
||||
self.spatial_norm = VchitectSpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
||||
else:
|
||||
self.spatial_norm = None
|
||||
|
||||
if qk_norm is None:
|
||||
self.norm_q = None
|
||||
self.norm_k = None
|
||||
elif qk_norm == "layer_norm":
|
||||
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
|
||||
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
|
||||
else:
|
||||
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
|
||||
|
||||
if cross_attention_norm is None:
|
||||
self.norm_cross = None
|
||||
elif cross_attention_norm == "layer_norm":
|
||||
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
||||
elif cross_attention_norm == "group_norm":
|
||||
if self.added_kv_proj_dim is not None:
|
||||
# The given `encoder_hidden_states` are initially of shape
|
||||
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
||||
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
||||
# before the projection, so we need to use `added_kv_proj_dim` as
|
||||
# the number of channels for the group norm.
|
||||
norm_cross_num_channels = added_kv_proj_dim
|
||||
else:
|
||||
norm_cross_num_channels = self.cross_attention_dim
|
||||
|
||||
self.norm_cross = nn.GroupNorm(
|
||||
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
||||
)
|
||||
|
||||
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
else:
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
self.to_q_cross = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_q_temp = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k_temp = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
self.to_v_temp = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
else:
|
||||
self.to_k_temp = None
|
||||
self.to_v_temp = None
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
self.to_out_temporal = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
|
||||
nn.init.constant_(self.to_out_temporal.weight, 0)
|
||||
nn.init.constant_(self.to_out_temporal.bias, 0)
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
|
||||
self.to_add_out_temporal = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
|
||||
nn.init.constant_(self.to_add_out_temporal.weight, 0)
|
||||
nn.init.constant_(self.to_add_out_temporal.bias, 0)
|
||||
|
||||
self.to_out_context = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
|
||||
nn.init.constant_(self.to_out_context.weight, 0)
|
||||
nn.init.constant_(self.to_out_context.bias, 0)
|
||||
|
||||
# set attention processor
|
||||
self.set_processor(processor)
|
||||
|
||||
# parallel
|
||||
self.parallel_manager = None
|
||||
|
||||
# pab
|
||||
self.spatial_count = 0
|
||||
self.last_spatial = None
|
||||
self.temporal_count = 0
|
||||
self.last_temporal = None
|
||||
self.cross_count = 0
|
||||
self.last_cross = None
|
||||
|
||||
def set_processor(self, processor: "AttnProcessor") -> None:
|
||||
r"""
|
||||
Set the attention processor to use.
|
||||
|
||||
Args:
|
||||
processor (`AttnProcessor`):
|
||||
The attention processor to use.
|
||||
"""
|
||||
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
||||
# pop `processor` from `self._modules`
|
||||
if (
|
||||
hasattr(self, "processor")
|
||||
and isinstance(self.processor, torch.nn.Module)
|
||||
and not isinstance(processor, torch.nn.Module)
|
||||
):
|
||||
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
||||
self._modules.pop("processor")
|
||||
self.processor = processor
|
||||
|
||||
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        full_seqlen: Optional[int] = None,
        Frame: Optional[int] = None,
        timestep: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty

        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks"}
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}

        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            full_seqlen=full_seqlen,
            Frame=Frame,
            timestep=timestep,
            **cross_attention_kwargs,
        )

    @torch.no_grad()
    def fuse_projections(self, fuse=True):
        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        if not self.is_cross_attention:
            # fetch weight matrices.
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            # create a new single projection layer and copy over the weights.
            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_qkv.weight.copy_(concatenated_weights)
            if self.use_bias:
                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
                self.to_qkv.bias.copy_(concatenated_bias)

        else:
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_kv.weight.copy_(concatenated_weights)
            if self.use_bias:
                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                self.to_kv.bias.copy_(concatenated_bias)

        self.fused_projections = fuse

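`fuse_projections` above simply concatenates the separate projection weights row-wise so that a single matmul replaces three. A minimal standalone sketch of that equivalence, using throwaway `nn.Linear` layers as stand-ins for `to_q`/`to_k`/`to_v` (the sizes here are arbitrary assumptions, not the module's real dimensions):

```python
import torch
import torch.nn as nn

dim, inner = 8, 8  # hypothetical feature sizes
to_q, to_k, to_v = (nn.Linear(dim, inner, bias=True) for _ in range(3))

# fuse: concatenate weights (and biases) along the output dimension
to_qkv = nn.Linear(dim, 3 * inner, bias=True)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

x = torch.randn(2, dim)
q, k, v = to_qkv(x).split(inner, dim=-1)  # one matmul, three outputs
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```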
class VchitectAttnProcessor:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("VchitectAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

@autocast("cuda", enabled=False)
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
):
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
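`apply_rotary_emb` treats each (even, odd) feature pair as one complex number and multiplies it by a unit phase, i.e. a 2-D rotation. A small self-contained check of that view under assumed shapes (nothing here is the module's API, and the sizes are illustrative):

```python
import torch

# a unit complex number e^{i*theta} rotates each (even, odd) feature pair;
# this mirrors the view_as_complex / view_as_real trick used above
B, S, H, D = 1, 4, 2, 8  # hypothetical batch / seq / heads / head_dim
xq = torch.randn(B, S, H, D)
theta = torch.rand(S, D // 2)
freqs_cis = torch.polar(torch.ones_like(theta), theta)  # (S, D//2), complex64

xq_ = torch.view_as_complex(xq.float().reshape(B, S, H, D // 2, 2))
xq_rot = torch.view_as_real(xq_ * freqs_cis.view(1, S, 1, D // 2)).flatten(3)

# rotation preserves the norm of every 2-D pair, hence of the whole vector
assert torch.allclose(xq_rot.norm(dim=-1), xq.norm(dim=-1), atol=1e-5)
```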
    def spatial_attn(
        self,
        attn,
        hidden_states,
        encoder_hidden_states_query_proj,
        encoder_hidden_states_key_proj,
        encoder_hidden_states_value_proj,
        batch_size,
        head_dim,
    ):
        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)
        xq, xk = query.to(value.dtype), key.to(value.dtype)

        query_spatial, key_spatial, value_spatial = xq, xk, value
        query_spatial = query_spatial.transpose(1, 2)
        key_spatial = key_spatial.transpose(1, 2)
        value_spatial = value_spatial.transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query_spatial, key_spatial, value_spatial, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        return hidden_states

    def temporal_attention(
        self,
        attn,
        hidden_states,
        residual,
        batch_size,
        batchsize,
        Frame,
        head_dim,
        freqs_cis,
        encoder_hidden_states_query_proj,
        encoder_hidden_states_key_proj,
        encoder_hidden_states_value_proj,
    ):
        query_t = attn.to_q_temp(hidden_states)
        key_t = attn.to_k_temp(hidden_states)
        value_t = attn.to_v_temp(hidden_states)

        query_t = torch.cat([query_t, encoder_hidden_states_query_proj], dim=1)
        key_t = torch.cat([key_t, encoder_hidden_states_key_proj], dim=1)
        value_t = torch.cat([value_t, encoder_hidden_states_value_proj], dim=1)

        query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
        key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
        value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

        query_t, key_t = query_t.to(value_t.dtype), key_t.to(value_t.dtype)

        if attn.parallel_manager.sp_size > 1:
            func = lambda x: self.dynamic_switch(attn, x, batchsize, to_spatial_shard=True)
            query_t, key_t, value_t = map(func, [query_t, key_t, value_t])

        func = lambda x: rearrange(x, "(B T) S H C -> (B S) T H C", T=Frame, B=batchsize)
        xq_gather_temporal, xk_gather_temporal, xv_gather_temporal = map(func, [query_t, key_t, value_t])

        freqs_cis_temporal = freqs_cis[: xq_gather_temporal.shape[1], :]
        xq_gather_temporal, xk_gather_temporal = self.apply_rotary_emb(
            xq_gather_temporal, xk_gather_temporal, freqs_cis=freqs_cis_temporal
        )

        xq_gather_temporal = xq_gather_temporal.transpose(1, 2)
        xk_gather_temporal = xk_gather_temporal.transpose(1, 2)
        xv_gather_temporal = xv_gather_temporal.transpose(1, 2)

        batch_size_temp = xv_gather_temporal.shape[0]
        hidden_states_temp = F.scaled_dot_product_attention(
            xq_gather_temporal, xk_gather_temporal, xv_gather_temporal, dropout_p=0.0, is_causal=False
        )
        hidden_states_temp = hidden_states_temp.transpose(1, 2).reshape(batch_size_temp, -1, attn.heads * head_dim)
        hidden_states_temp = hidden_states_temp.to(value_t.dtype)
        hidden_states_temp = rearrange(hidden_states_temp, "(B S) T C -> (B T) S C", T=Frame, B=batchsize)
        if attn.parallel_manager.sp_size > 1:
            hidden_states_temp = self.dynamic_switch(attn, hidden_states_temp, batchsize, to_spatial_shard=False)

        hidden_states_temporal, encoder_hidden_states_temporal = (
            hidden_states_temp[:, : residual.shape[1]],
            hidden_states_temp[:, residual.shape[1] :],
        )
        hidden_states_temporal = attn.to_out_temporal(hidden_states_temporal)
        return hidden_states_temporal, encoder_hidden_states_temporal

    def cross_attention(
        self,
        attn,
        hidden_states,
        encoder_hidden_states_query_proj,
        encoder_hidden_states_key_proj,
        encoder_hidden_states_value_proj,
        batch_size,
        head_dim,
        cur_frame,
        batchsize,
    ):
        query_cross = attn.to_q_cross(hidden_states)
        query_cross = torch.cat([query_cross, encoder_hidden_states_query_proj], dim=1)

        key_y = encoder_hidden_states_key_proj[0].unsqueeze(0)
        value_y = encoder_hidden_states_value_proj[0].unsqueeze(0)

        query_y = query_cross.view(batch_size, -1, attn.heads, head_dim)
        key_y = key_y.view(batchsize, -1, attn.heads, head_dim)
        value_y = value_y.view(batchsize, -1, attn.heads, head_dim)

        query_y = rearrange(query_y, "(B T) S H C -> B (S T) H C", T=cur_frame, B=batchsize)

        query_y = query_y.transpose(1, 2)
        key_y = key_y.transpose(1, 2)
        value_y = value_y.transpose(1, 2)

        cross_output = F.scaled_dot_product_attention(query_y, key_y, value_y, dropout_p=0.0, is_causal=False)
        cross_output = cross_output.transpose(1, 2).reshape(batchsize, -1, attn.heads * head_dim)
        cross_output = cross_output.to(query_cross.dtype)

        cross_output = rearrange(cross_output, "B (S T) C -> (B T) S C", T=cur_frame, B=batchsize)
        cross_output = attn.to_out_context(cross_output)
        return cross_output

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        full_seqlen: Optional[int] = None,
        Frame: Optional[int] = None,
        timestep: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        batch_size = encoder_hidden_states.shape[0]
        inner_dim = encoder_hidden_states_key_proj.shape[-1]
        head_dim = inner_dim // attn.heads
        batchsize = full_seqlen // Frame
        # equals Frame when not sharded; otherwise the per-rank (sharded) frame count
        cur_frame = batch_size // batchsize

        # temporal attention
        if enable_pab():
            broadcast_temporal, attn.temporal_count = if_broadcast_temporal(int(timestep[0]), attn.temporal_count)
        if enable_pab() and broadcast_temporal:
            hidden_states_temporal, encoder_hidden_states_temporal = attn.last_temporal
        else:
            hidden_states_temporal, encoder_hidden_states_temporal = self.temporal_attention(
                attn,
                hidden_states,
                residual,
                batch_size,
                batchsize,
                Frame,
                head_dim,
                freqs_cis,
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
            )
            if enable_pab():
                attn.last_temporal = (hidden_states_temporal, encoder_hidden_states_temporal)

        # cross attn
        if enable_pab():
            broadcast_cross, attn.cross_count = if_broadcast_cross(int(timestep[0]), attn.cross_count)
        if enable_pab() and broadcast_cross:
            cross_output = attn.last_cross
        else:
            cross_output = self.cross_attention(
                attn,
                hidden_states,
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
                batch_size,
                head_dim,
                cur_frame,
                batchsize,
            )
            if enable_pab():
                attn.last_cross = cross_output

        # spatial attn
        if enable_pab():
            broadcast_spatial, attn.spatial_count = if_broadcast_spatial(int(timestep[0]), attn.spatial_count)
        if enable_pab() and broadcast_spatial:
            hidden_states = attn.last_spatial
        else:
            hidden_states = self.spatial_attn(
                attn,
                hidden_states,
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
                batch_size,
                head_dim,
            )
            if enable_pab():
                attn.last_spatial = hidden_states

        # process attention outputs.
        hidden_states = hidden_states * 1.1 + cross_output
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if cur_frame == 1:
            hidden_states_temporal = hidden_states_temporal * 0
        hidden_states = hidden_states + hidden_states_temporal

        # encoder
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
            encoder_hidden_states_temporal = attn.to_add_out_temporal(encoder_hidden_states_temporal)
            if cur_frame == 1:
                encoder_hidden_states_temporal = encoder_hidden_states_temporal * 0
            encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_temporal

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states

    def dynamic_switch(self, attn, x, batchsize, to_spatial_shard: bool):
        if to_spatial_shard:
            scatter_dim, gather_dim = 2, 1
            set_pad("spatial", x.shape[1], attn.parallel_manager.sp_group)
            scatter_pad = get_pad("spatial")
            gather_pad = get_pad("temporal")
        else:
            scatter_dim, gather_dim = 1, 2
            scatter_pad = get_pad("temporal")
            gather_pad = get_pad("spatial")

        x = rearrange(x, "(B T) S ... -> B T S ...", B=batchsize)
        x = all_to_all_with_pad(
            x,
            attn.parallel_manager.sp_group,
            scatter_dim=scatter_dim,
            gather_dim=gather_dim,
            scatter_pad=scatter_pad,
            gather_pad=gather_pad,
        )
        x = rearrange(x, "B T ... -> (B T) ...")
        return x

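The `if_broadcast_*` / `last_*` pairs used in `__call__` above implement the PAB-style reuse-or-recompute cache: at selected timesteps the processor returns the output saved on the previous step instead of recomputing attention. A simplified sketch of just that control flow, with `TinyPabCache` as a hypothetical stand-in for the `pab_mgr` helpers (not the library's API):

```python
import torch

class TinyPabCache:
    """Sketch of the reuse-or-recompute pattern; not the library's API."""

    def __init__(self, interval: int = 2):
        self.interval = interval  # reuse the cached output on in-between steps
        self.count = 0
        self.last = None

    def should_reuse(self) -> bool:
        reuse = self.last is not None and self.count % self.interval != 0
        self.count += 1
        return reuse

def expensive_attention(x):
    return x * 2  # stand-in for a real attention call

cache = TinyPabCache()
for step in range(4):
    x = torch.ones(1) * step
    if cache.should_reuse():
        out = cache.last              # broadcast: reuse last step's output
    else:
        out = expensive_attention(x)  # recompute and refresh the cache
        cache.last = out
```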
@@ -2,6 +2,7 @@ from typing import Optional, Tuple

import torch
import torch.nn as nn
+import torch.nn.functional as F


class LlamaRMSNorm(nn.Module):
@@ -100,3 +101,32 @@ class AdaLayerNorm(nn.Module):

        x = self.norm(x) * (1 + scale) + shift
        return x
+
+
+class VchitectSpatialNorm(nn.Module):
+    """
+    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+
+    Args:
+        f_channels (`int`):
+            The number of channels for input to group normalization layer, and output of the spatial norm layer.
+        zq_channels (`int`):
+            The number of channels for the quantized vector as described in the paper.
+    """
+
+    def __init__(
+        self,
+        f_channels: int,
+        zq_channels: int,
+    ):
+        super().__init__()
+        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+        f_size = f.shape[-2:]
+        zq = F.interpolate(zq, size=f_size, mode="nearest")
+        norm_f = self.norm_layer(f)
+        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+        return new_f
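`VchitectSpatialNorm` group-normalizes `f` and then modulates it with a scale and shift predicted from `zq` after nearest-neighbour resizing. A hedged usage sketch, assuming the class above is in scope and using arbitrary channel counts:

```python
import torch

norm = VchitectSpatialNorm(f_channels=64, zq_channels=32)
f = torch.randn(1, 64, 16, 16)   # feature map to normalize
zq = torch.randn(1, 32, 4, 4)    # quantized conditioning, any spatial size
out = norm(f, zq)                # zq is upsampled to 16x16 internally
assert out.shape == f.shape
```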
@@ -22,15 +22,9 @@ from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn

-from videosys.core.comm import all_to_all_comm, gather_sequence, get_spatial_pad, set_spatial_pad, split_sequence
+from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence
from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
-from videosys.core.parallel_mgr import (
-    enable_sequence_parallel,
-    get_cfg_parallel_group,
-    get_cfg_parallel_size,
-    get_sequence_parallel_group,
-    get_sequence_parallel_size,
-)
+from videosys.core.parallel_mgr import ParallelManager
from videosys.models.modules.embeddings import apply_rotary_emb
from videosys.utils.utils import batch_func

@@ -48,6 +42,49 @@ class CogVideoXAttnProcessor2_0:
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

+    def _remove_extra_encoder(self, hidden_states, text_seq_length, attn):
+        # current layout is [text, 1/n seq, text, 1/n seq, ...]
+        # we want to remove all the extra text info, leaving [text, seq]
+        sp_size = attn.parallel_manager.sp_size
+        split_seq = hidden_states.split(hidden_states.size(2) // sp_size, dim=2)
+        encoder_hidden_states = split_seq[0][:, :, :text_seq_length]
+        new_seq = [encoder_hidden_states]
+        for i in range(sp_size):
+            new_seq.append(split_seq[i][:, :, text_seq_length:])
+        hidden_states = torch.cat(new_seq, dim=2)
+
+        # remove padding added for all2all;
+        # if the pad were removed earlier than this, the split size would be wrong
+        pad = get_pad("pad")
+        if pad > 0:
+            hidden_states = hidden_states.narrow(2, 0, hidden_states.size(2) - pad)
+        return hidden_states
+
+    def _add_extra_encoder(self, hidden_states, text_seq_length, attn):
+        # add padding for the split and the later all2all;
+        # if the pad were added later than this, the split size would be wrong
+        pad = get_pad("pad")
+        if pad > 0:
+            pad_shape = list(hidden_states.shape)
+            pad_shape[1] = pad
+            pad_tensor = torch.zeros(pad_shape, device=hidden_states.device, dtype=hidden_states.dtype)
+            hidden_states = torch.cat([hidden_states, pad_tensor], dim=1)
+
+        # current layout is [text, seq]
+        # we want to add the extra encoder info: [text, 1/n seq, text, 1/n seq, ...]
+        sp_size = attn.parallel_manager.sp_size
+        encoder = hidden_states[:, :text_seq_length]
+        seq = hidden_states[:, text_seq_length:]
+        seq = seq.split(seq.size(1) // sp_size, dim=1)
+        new_seq = []
+        for i in range(sp_size):
+            new_seq.append(encoder)
+            new_seq.append(seq[i])
+        hidden_states = torch.cat(new_seq, dim=1)
+        return hidden_states
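These two helpers convert between the per-rank layout `[text, 1/n seq, text, 1/n seq, ...]` produced by all-to-all and the compact `[text, seq]` layout attention expects. A toy round trip on lists of labels (`sp_size = 2`; purely illustrative, not the tensor code above):

```python
text = ["t0", "t1"]
seq = ["s0", "s1", "s2", "s3"]
sp_size = 2
chunk = len(seq) // sp_size

# _add_extra_encoder direction: [text, seq] -> [text, seq/n, text, seq/n, ...]
interleaved = []
for i in range(sp_size):
    interleaved += text + seq[i * chunk:(i + 1) * chunk]
assert interleaved == ["t0", "t1", "s0", "s1", "t0", "t1", "s2", "s3"]

# _remove_extra_encoder direction: keep one text copy, drop the rest
stride = len(text) + chunk
restored = text + [tok for i in range(sp_size)
                   for tok in interleaved[i * stride + len(text):(i + 1) * stride]]
assert restored == text + seq
```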
    def __call__(
        self,
        attn: Attention,
@@ -72,13 +109,15 @@ class CogVideoXAttnProcessor2_0:
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

-        if enable_sequence_parallel():
+        if attn.parallel_manager.sp_size > 1:
            assert (
-                attn.heads % get_sequence_parallel_size() == 0
-            ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {get_sequence_parallel_size()}"
-            attn_heads = attn.heads // get_sequence_parallel_size()
+                attn.heads % attn.parallel_manager.sp_size == 0
+            ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
+            attn_heads = attn.heads // attn.parallel_manager.sp_size
+            # normally we would pad for every all2all, but for a more convenient
+            # implementation we move the pad operation to encoder add/remove in cogvideo
            query, key, value = map(
-                lambda x: all_to_all_comm(x, get_sequence_parallel_group(), scatter_dim=2, gather_dim=1),
+                lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
                [query, key, value],
            )
        else:
@@ -96,6 +135,13 @@ class CogVideoXAttnProcessor2_0:
        if attn.norm_k is not None:
            key = attn.norm_k(key)

+        if attn.parallel_manager.sp_size > 1:
+            # remove extra encoder for attention
+            query, key, value = map(
+                lambda x: self._remove_extra_encoder(x, text_seq_length, attn),
+                [query, key, value],
+            )
+
        # Apply RoPE if needed
        if image_rotary_emb is not None:
            emb_len = image_rotary_emb[0].shape[0]
@@ -113,77 +159,10 @@ class CogVideoXAttnProcessor2_0:

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)

-        if enable_sequence_parallel():
-            hidden_states = all_to_all_comm(hidden_states, get_sequence_parallel_group(), scatter_dim=1, gather_dim=2)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        encoder_hidden_states, hidden_states = hidden_states.split(
-            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-        )
-        return hidden_states, encoder_hidden_states
-
-
-class FusedCogVideoXAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-    query and key vectors, but does not include spatial normalization.
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
-
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-            if not attn.is_cross_attention:
-                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        if attn.parallel_manager.sp_size > 1:
+            # add extra encoder for all_to_all
+            hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn)
+            hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
@@ -266,6 +245,9 @@ class CogVideoXBlock(nn.Module):
            processor=CogVideoXAttnProcessor2_0(),
        )

+        # parallel
+        self.attn1.parallel_manager = None
+
        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

@@ -300,7 +282,7 @@

        # attention
        if enable_pab():
-            broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
+            broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count)
        if enable_pab() and broadcast_attn:
            attn_hidden_states, attn_encoder_hidden_states = self.last_attn
        else:
@@ -474,6 +456,23 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):

        self.gradient_checkpointing = False

+        # parallel
+        self.parallel_manager = None
+
+    def enable_parallel(self, dp_size, sp_size, enable_cp):
+        # update cfg parallel
+        if enable_cp and sp_size % 2 == 0:
+            sp_size = sp_size // 2
+            cp_size = 2
+        else:
+            cp_size = 1
+
+        self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size)
+
+        for _, module in self.named_modules():
+            if hasattr(module, "parallel_manager"):
+                module.parallel_manager = self.parallel_manager
+
    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

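`enable_parallel` steals a factor of two from the sequence-parallel size when classifier-free-guidance parallelism is requested, so the cond/uncond halves of the batch run on disjoint groups. A pure-Python sketch of that sizing arithmetic (`plan_parallel` is a hypothetical helper, not part of videosys; no process groups are created):

```python
def plan_parallel(dp_size: int, sp_size: int, enable_cp: bool):
    # mirrors the sizing logic above: cp borrows a factor of 2 from sp
    if enable_cp and sp_size % 2 == 0:
        return dp_size, 2, sp_size // 2  # (dp, cp, sp)
    return dp_size, 1, sp_size

assert plan_parallel(1, 4, True) == (1, 2, 2)
assert plan_parallel(1, 3, True) == (1, 1, 3)  # odd sp cannot be split
assert plan_parallel(2, 2, False) == (2, 1, 2)
```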
@@ -485,9 +484,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_dict: bool = True,
-        all_timesteps=None,
    ):
-        if get_cfg_parallel_size() > 1:
+        if self.parallel_manager.cp_size > 1:
            (
                hidden_states,
                encoder_hidden_states,
@@ -495,7 +493,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                timestep_cond,
                image_rotary_emb,
            ) = batch_func(
-                partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+                partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
                hidden_states,
                encoder_hidden_states,
                timestep,
@@ -530,9 +528,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

-        if enable_sequence_parallel():
-            set_spatial_pad(hidden_states.shape[1])
-            hidden_states = split_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
+        if self.parallel_manager.sp_size > 1:
+            set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
+            hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))

        # 4. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
@@ -562,8 +560,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                timestep=timesteps if enable_pab() else None,
            )

-        if enable_sequence_parallel():
-            hidden_states = gather_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
+        if self.parallel_manager.sp_size > 1:
+            hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
@@ -583,8 +581,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

-        if get_cfg_parallel_size() > 1:
-            output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+        if self.parallel_manager.cp_size > 1:
+            output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

        if not return_dict:
            return (output,)

@@ -33,15 +33,7 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange, repeat
from torch import nn

-from videosys.core.comm import (
-    all_to_all_with_pad,
-    gather_sequence,
-    get_spatial_pad,
-    get_temporal_pad,
-    set_spatial_pad,
-    set_temporal_pad,
-    split_sequence,
-)
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.core.pab_mgr import (
    enable_pab,
    get_mlp_output,
@@ -51,12 +43,7 @@ from videosys.core.pab_mgr import (
    if_broadcast_temporal,
    save_mlp_output,
)
-from videosys.core.parallel_mgr import (
-    enable_sequence_parallel,
-    get_cfg_parallel_group,
-    get_cfg_parallel_size,
-    get_sequence_parallel_group,
-)
+from videosys.core.parallel_mgr import ParallelManager
from videosys.utils.utils import batch_func

@@ -388,9 +375,7 @@ class BasicTransformerBlock(nn.Module):
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        if enable_pab():
-            broadcast_spatial, self.spatial_count = if_broadcast_spatial(
-                int(org_timestep[0]), self.spatial_count, self.block_idx
-            )
+            broadcast_spatial, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count)

        if enable_pab() and broadcast_spatial:
            attn_output = self.spatial_last
@@ -681,6 +666,9 @@ class BasicTransformerBlock_(nn.Module):
        self.block_idx = block_idx
        self.count = 0

+        # parallel
+        self.parallel_manager: ParallelManager = None
+
    def set_last_out(self, last_out: torch.Tensor):
        self.last_out = last_out

@@ -743,7 +731,7 @@ class BasicTransformerBlock_(nn.Module):
        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

-        if enable_sequence_parallel():
+        if self.parallel_manager.sp_size > 1:
            norm_hidden_states = self.dynamic_switch(norm_hidden_states, to_spatial_shard=True)

        attn_output = self.attn1(
@@ -753,7 +741,7 @@ class BasicTransformerBlock_(nn.Module):
            **cross_attention_kwargs,
        )

-        if enable_sequence_parallel():
+        if self.parallel_manager.sp_size > 1:
            attn_output = self.dynamic_switch(attn_output, to_spatial_shard=False)

        if self.use_ada_layer_norm_zero:
@@ -838,15 +826,15 @@ class BasicTransformerBlock_(nn.Module):
    def dynamic_switch(self, x, to_spatial_shard: bool):
        if to_spatial_shard:
            scatter_dim, gather_dim = 0, 1
-            scatter_pad = get_spatial_pad()
-            gather_pad = get_temporal_pad()
+            scatter_pad = get_pad("spatial")
+            gather_pad = get_pad("temporal")
        else:
            scatter_dim, gather_dim = 1, 0
-            scatter_pad = get_temporal_pad()
-            gather_pad = get_spatial_pad()
+            scatter_pad = get_pad("temporal")
+            gather_pad = get_pad("spatial")
        x = all_to_all_with_pad(
            x,
-            get_sequence_parallel_group(),
+            self.parallel_manager.sp_group,
            scatter_dim=scatter_dim,
            gather_dim=gather_dim,
            scatter_pad=scatter_pad,
@@ -1133,6 +1121,23 @@ class LatteT2V(ModelMixin, ConfigMixin):
        temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length)  # 1152 hidden size
        self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)

+        # parallel
+        self.parallel_manager = None
+
+    def enable_parallel(self, dp_size, sp_size, enable_cp):
+        # update cfg parallel
+        if enable_cp and sp_size % 2 == 0:
+            sp_size = sp_size // 2
+            cp_size = 2
+        else:
+            cp_size = 1
+
+        self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
+
+        for _, module in self.named_modules():
+            if hasattr(module, "parallel_manager"):
+                module.parallel_manager = self.parallel_manager
+
    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

@@ -1191,7 +1196,7 @@ class LatteT2V(ModelMixin, ConfigMixin):
        """

        # 0. Split batch for data parallelism
-        if get_cfg_parallel_size() > 1:
+        if self.parallel_manager.cp_size > 1:
            (
                hidden_states,
                timestep,
@@ -1201,7 +1206,7 @@ class LatteT2V(ModelMixin, ConfigMixin):
                attention_mask,
                encoder_attention_mask,
            ) = batch_func(
-                partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+                partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
                hidden_states,
                timestep,
                encoder_hidden_states,
@@ -1292,14 +1297,14 @@ class LatteT2V(ModelMixin, ConfigMixin):
        timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
        timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()

-        if enable_sequence_parallel():
-            set_temporal_pad(frame + use_image_num)
-            set_spatial_pad(num_patches)
+        if self.parallel_manager.sp_size > 1:
+            set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+            set_pad("spatial", num_patches, self.parallel_manager.sp_group)
            hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
            encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
            timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
            temp_pos_embed = split_sequence(
-                self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+                self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
            )
        else:
            temp_pos_embed = self.temp_pos_embed
@@ -1420,7 +1425,7 @@ class LatteT2V(ModelMixin, ConfigMixin):
                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
            ).contiguous()

-        if enable_sequence_parallel():
+        if self.parallel_manager.sp_size > 1:
            hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)

        if self.is_input_patches:
@@ -1452,8 +1457,8 @@ class LatteT2V(ModelMixin, ConfigMixin):
        output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()

        # 3. Gather batch for data parallelism
-        if get_cfg_parallel_size() > 1:
-            output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+        if self.parallel_manager.cp_size > 1:
+            output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

        if not return_dict:
            return (output,)
@@ -1466,12 +1471,12 @@

    def split_from_second_dim(self, x, batch_size):
        x = x.view(batch_size, -1, *x.shape[1:])
-        x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+        x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
        x = x.reshape(-1, *x.shape[2:])
        return x

    def gather_from_second_dim(self, x, batch_size):
        x = x.view(batch_size, -1, *x.shape[1:])
-        x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
+        x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
        x = x.reshape(-1, *x.shape[2:])
        return x

videosys/models/transformers/open_sora_plan_v110_transformer_3d.py (new file, 2826 lines; diff suppressed because it is too large)

videosys/models/transformers/open_sora_plan_v120_transformer_3d.py (new file, 2183 lines; diff suppressed because it is too large)
@@ -20,15 +20,7 @@ from timm.models.vision_transformer import Mlp
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from transformers import PretrainedConfig, PreTrainedModel

-from videosys.core.comm import (
-    all_to_all_with_pad,
-    gather_sequence,
-    get_spatial_pad,
-    get_temporal_pad,
-    set_spatial_pad,
-    set_temporal_pad,
-    split_sequence,
-)
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.core.pab_mgr import (
    enable_pab,
    get_mlp_output,
@@ -38,12 +30,7 @@ from videosys.core.pab_mgr import (
    if_broadcast_temporal,
    save_mlp_output,
)
-from videosys.core.parallel_mgr import (
-    enable_sequence_parallel,
-    get_cfg_parallel_size,
-    get_data_parallel_group,
-    get_sequence_parallel_group,
-)
+from videosys.core.parallel_mgr import ParallelManager
from videosys.models.modules.activations import approx_gelu
from videosys.models.modules.attentions import OpenSoraAttention, OpenSoraMultiHeadCrossAttention
from videosys.models.modules.embeddings import (
@@ -143,6 +130,9 @@ class STDiT3Block(nn.Module):
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)

+        # parallel
+        self.parallel_manager: ParallelManager = None
+
        # pab
        self.block_idx = block_idx
        self.attn_count = 0
@@ -188,9 +178,7 @@ class STDiT3Block(nn.Module):
            if self.temporal:
                broadcast_attn, self.attn_count = if_broadcast_temporal(int(timestep[0]), self.attn_count)
            else:
-                broadcast_attn, self.attn_count = if_broadcast_spatial(
-                    int(timestep[0]), self.attn_count, self.block_idx
-                )
+                broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count)

        if enable_pab() and broadcast_attn:
            x_m_s = self.last_attn
@@ -203,12 +191,12 @@ class STDiT3Block(nn.Module):

            # attention
            if self.temporal:
-                if enable_sequence_parallel():
+                if self.parallel_manager.sp_size > 1:
                    x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=True)
                x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
                x_m = self.attn(x_m)
                x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S)
-                if enable_sequence_parallel():
+                if self.parallel_manager.sp_size > 1:
                    x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=False)
            else:
                x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
@@ -287,17 +275,17 @@ class STDiT3Block(nn.Module):
    def dynamic_switch(self, x, s, t, to_spatial_shard: bool):
        if to_spatial_shard:
            scatter_dim, gather_dim = 2, 1
-            scatter_pad = get_spatial_pad()
-            gather_pad = get_temporal_pad()
+            scatter_pad = get_pad("spatial")
+            gather_pad = get_pad("temporal")
        else:
            scatter_dim, gather_dim = 1, 2
-            scatter_pad = get_temporal_pad()
-            gather_pad = get_spatial_pad()
+            scatter_pad = get_pad("temporal")
+            gather_pad = get_pad("spatial")

        x = rearrange(x, "b (t s) d -> b t s d", t=t, s=s)
        x = all_to_all_with_pad(
            x,
-            get_sequence_parallel_group(),
+            self.parallel_manager.sp_group,
            scatter_dim=scatter_dim,
            gather_dim=gather_dim,
            scatter_pad=scatter_pad,
@@ -449,6 +437,24 @@ class STDiT3(PreTrainedModel):
            for param in self.y_embedder.parameters():
                param.requires_grad = False

+        # parallel
+        self.parallel_manager: ParallelManager = None
+
+    def enable_parallel(self, dp_size, sp_size, enable_cp):
+        # update cfg parallel
+        if enable_cp and sp_size % 2 == 0:
+            sp_size = sp_size // 2
+            cp_size = 2
+        else:
+            cp_size = 1
+
+        self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
+
+        for name, module in self.named_modules():
+            if "spatial_blocks" in name or "temporal_blocks" in name:
+                if hasattr(module, "parallel_manager"):
+                    module.parallel_manager = self.parallel_manager
+
    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
@@ -501,9 +507,14 @@ class STDiT3(PreTrainedModel):
        self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
    ):
        # === Split batch ===
-        if get_cfg_parallel_size() > 1:
+        if self.parallel_manager.cp_size > 1:
            x, timestep, y, x_mask, mask = batch_func(
-                partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
+                partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
+                x,
+                timestep,
+                y,
+                x_mask,
+                mask,
            )

        dtype = self.x_embedder.proj.weight.dtype
@@ -547,14 +558,14 @@ class STDiT3(PreTrainedModel):
        x = x + pos_emb

        # shard over the sequence dim if sp is enabled
-        if enable_sequence_parallel():
-            set_temporal_pad(T)
-            set_spatial_pad(S)
-            x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+        if self.parallel_manager.sp_size > 1:
+            set_pad("temporal", T, self.parallel_manager.sp_group)
+            set_pad("spatial", S, self.parallel_manager.sp_group)
+            x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
            T = x.shape[1]
            x_mask_org = x_mask
            x_mask = split_sequence(
-                x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+                x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
            )

        x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
@@ -589,9 +600,9 @@ class STDiT3(PreTrainedModel):
                all_timesteps=all_timesteps,
            )

-        if enable_sequence_parallel():
+        if self.parallel_manager.sp_size > 1:
            x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
-            x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
+            x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
            T, S = x.shape[1], x.shape[2]
            x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
            x_mask = x_mask_org
@@ -604,8 +615,8 @@ class STDiT3(PreTrainedModel):
        x = x.to(torch.float32)

        # === Gather Output ===
-        if get_cfg_parallel_size() > 1:
-            x = gather_sequence(x, get_data_parallel_group(), dim=0)
+        if self.parallel_manager.cp_size > 1:
+            x = gather_sequence(x, self.parallel_manager.cp_group, dim=0)

        return x

videosys/models/transformers/vchitect_transformer_3d.py (new file, 644 lines)
@@ -0,0 +1,644 @@
# Adapted from Vchitect

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Vchitect: https://github.com/Vchitect/Vchitect-2.0
# diffusers: https://github.com/huggingface/diffusers
# --------------------------------------------------------


from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, deprecate, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange
from torch import nn

from videosys.core.comm import gather_from_second_dim, set_pad, split_from_second_dim
from videosys.core.parallel_mgr import ParallelManager
from videosys.models.modules.attentions import VchitectAttention, VchitectAttnProcessor


def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    # "feed_forward_chunk_size" can be used to save memory
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )

    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    ff_output = torch.cat(
        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
    return ff_output

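Because the feed-forward acts on every token independently, chunking one dimension and concatenating the partial outputs is numerically exact; it trades peak activation memory for extra kernel launches. A quick sketch of that equivalence (all sizes are arbitrary):

```python
import torch
import torch.nn as nn

ff = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
x = torch.randn(2, 6, 8)  # (batch, seq, dim); seq divisible by the chunk count

full = ff(x)
chunked = torch.cat([ff(c) for c in x.chunk(3, dim=1)], dim=1)  # 3 chunks of 2 tokens
assert torch.allclose(full, chunked, atol=1e-6)
```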
@maybe_allow_in_graph
class JointTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

        self.norm1 = AdaLayerNormZero(dim)

        if context_norm_type == "ada_norm_continous":
            self.norm1_context = AdaLayerNormContinuous(
                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
            )
        elif context_norm_type == "ada_norm_zero":
            self.norm1_context = AdaLayerNormZero(dim)
        else:
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
            )
        processor = VchitectAttnProcessor()
        self.attn = VchitectAttention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim // num_attention_heads,
            heads=num_attention_heads,
            out_dim=attention_head_dim,
            context_pre_only=context_pre_only,
            bias=True,
            processor=processor,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        if not context_pre_only:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        else:
            self.norm2_context = None
            self.ff_context = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        freqs_cis: torch.Tensor,
        full_seqlen: int,
        Frame: int,
        timestep: int,
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb
            )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            freqs_cis=freqs_cis,
            full_seqlen=full_seqlen,
            Frame=Frame,
            timestep=timestep,
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output

            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                context_ff_output = _chunked_feed_forward(
                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
                )
            else:
                context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class VchitectXLTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    The Transformer model introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        out_channels (`int`, defaults to 16): Number of output channels.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        rope_scaling_factor: float = 1.0,
    ):
        super().__init__()
        default_out_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )
        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
        # `attention_head_dim` is doubled to account for the mixing.
        # It needs to be crafted when we get the actual checkpoints.
        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.inner_dim,
                    context_pre_only=i == num_layers - 1,
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

        # Video param
        # self.scatter_dim_zero = Identity()
        self.freqs_cis = VchitectXLTransformerModel.precompute_freqs_cis(
            self.inner_dim // self.config.num_attention_heads,
            1000000,
            theta=1e6,
            rope_scaling_factor=rope_scaling_factor,  # todo max pos embeds
        )

        # self.vid_token = nn.Parameter(torch.empty(self.inner_dim))

        # parallel
        self.parallel_manager = None

def enable_parallel(self, dp_size, sp_size, enable_cp):
|
||||
# update cfg parallel
|
||||
if enable_cp and sp_size % 2 == 0:
|
||||
sp_size = sp_size // 2
|
||||
cp_size = 2
|
||||
else:
|
||||
cp_size = 1
|
||||
|
||||
self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
|
||||
|
||||
for _, module in self.named_modules():
|
||||
if hasattr(module, "parallel_manager"):
|
||||
module.parallel_manager = self.parallel_manager
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, rope_scaling_factor: float = 1.0):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float)
|
||||
t = t / rope_scaling_factor
|
||||
freqs = torch.outer(t, freqs).float()
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
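    # How the table is consumed (a sketch under assumed shapes, not the exact Vchitect
    # attention code): consecutive channel pairs of q/k are viewed as complex numbers
    # and rotated by freqs_cis, which is the standard RoPE application.
    #   freqs = VchitectXLTransformerModel.precompute_freqs_cis(64, end=4096, theta=1e6)
    #   x = torch.randn(1, 4096, 18, 64)  # (batch, seq, heads, head_dim)
    #   xc = torch.view_as_complex(x.float().reshape(1, 4096, 18, 32, 2))
    #   x_rot = torch.view_as_real(xc * freqs[:4096, None, :]).flatten(-2)  # same shape as x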

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)
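    # Usage sketch (hypothetical model handle): chunking trades a small amount of
    # speed for lower peak activation memory in the feed-forward layers.
    #   model.enable_forward_chunking(chunk_size=2, dim=1)  # chunk along sequence length
    #   model.enable_forward_chunking()                     # defaults to chunk_size=1, dim=0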

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, VchitectAttnProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str, module: torch.nn.Module, processors: Dict[str, VchitectAttnProcessor]
        ):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[VchitectAttnProcessor, Dict[str, VchitectAttnProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
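    # Usage sketch (processor class assumed to follow the diffusers convention):
    # swap every attention processor at once, or target individual layers via a dict
    # keyed by the weight names returned from `attn_processors`.
    #   model.set_attn_processor(VchitectAttnProcessor())
    #   model.set_attn_processor({k: VchitectAttnProcessor() for k in model.attn_processors})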

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, VchitectAttention):
                module.fuse_projections(fuse=True)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def patchify_and_embed(self, x):
        B, F, C, H, W = x.size()
        x = rearrange(x, "b f c h w -> (b f) c h w")
        x = self.pos_embed(x)  # [B L D]
        return x, F, [(H, W)] * B
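    # Shape walkthrough (illustrative numbers): a latent batch of shape
    # (B=1, F=16, C=16, H=64, W=64) flattens to (16, 16, 64, 64); PatchEmbed with
    # patch_size=2 then yields (16, (64 // 2) * (64 // 2) = 1024, inner_dim) tokens,
    # so frames are folded into the batch dimension before attention.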

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`VchitectXLTransformerModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate the denoising step.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        height, width = hidden_states.shape[-2:]

        batch_size = hidden_states.shape[0]
        hidden_states, F_num, _ = self.patchify_and_embed(
            hidden_states
        )  # takes care of adding positional embeddings too.
        full_seq = batch_size * F_num

        self.freqs_cis = self.freqs_cis.to(hidden_states.device)
        freqs_cis = self.freqs_cis
        # seq_length = hidden_states.size(1)
        # freqs_cis = self.freqs_cis[:hidden_states.size(1)*F_num]
        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # guard: parallelism may not have been initialized when the model is used standalone
        if self.parallel_manager is not None and self.parallel_manager.sp_size > 1:
            set_pad("temporal", F_num, self.parallel_manager.sp_group)
            hidden_states = split_from_second_dim(hidden_states, batch_size, self.parallel_manager.sp_group)
            cur_temb = temb.repeat(hidden_states.shape[0] // batch_size, 1)
        else:
            cur_temb = temb.repeat(F_num, 1)

        for block_idx, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=cur_temb,
                freqs_cis=freqs_cis,
                full_seqlen=full_seq,
                Frame=F_num,
                timestep=timestep,
            )

        if self.parallel_manager is not None and self.parallel_manager.sp_size > 1:
            hidden_states = gather_from_second_dim(hidden_states, batch_size, self.parallel_manager.sp_group)

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # hidden_states = hidden_states[:, :-1]  # Drop the video token

        # unpatchify
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
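    # Inference sketch (shapes assumed from the signature; the Vchitect pipeline
    # normally drives this call):
    #   latents = torch.randn(1, 16, 16, 64, 64)  # (B, F, C, H, W)
    #   out = model(latents, encoder_hidden_states=text_embeds,
    #               pooled_projections=pooled_embeds, timestep=t).sample
    #   # out: (B * F, out_channels, H, W), ready to be unflattened back into frames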

    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        return list(self.transformer_blocks)

    @classmethod
    def from_pretrained_temporal(cls, pretrained_model_path, torch_dtype, logger, subfolder=None, tp_size=1):
        import json
        import os

        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)

        config_file = os.path.join(pretrained_model_path, "config.json")

        with open(config_file, "r") as f:
            config = json.load(f)

        config["tp_size"] = tp_size
        from safetensors.torch import load_file

        model = cls.from_config(config)
        # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)

        model_files = [
            os.path.join(pretrained_model_path, "diffusion_pytorch_model.bin"),
            os.path.join(pretrained_model_path, "diffusion_pytorch_model.safetensors"),
        ]

        model_file = None

        for fp in model_files:
            if os.path.exists(fp):
                model_file = fp

        if not model_file:
            raise RuntimeError(f"no checkpoint found among {model_files}")

        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")

        # NOTE: `load_file` reads safetensors; a `.bin` checkpoint would need `torch.load` instead.
        state_dict = load_file(model_file, device="cpu")
        m, u = model.load_state_dict(state_dict, strict=False)
        model = model.to(torch_dtype)

        params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
        total_params = [p.numel() for n, p in model.named_parameters()]

        if logger is not None:
            logger.info(f"model_file: {model_file}")
            logger.info(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
            logger.info(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
            logger.info(f"### Total Parameters: {sum(total_params) / 1e6} M")

        return model
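    # Loading sketch (the path and subfolder are placeholders, not pinned ids):
    #   model = VchitectXLTransformerModel.from_pretrained_temporal(
    #       "path/to/VchitectXL", torch_dtype=torch.bfloat16, logger=None,
    #       subfolder="transformer",
    #   )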

@@ -13,6 +13,7 @@ import math
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
@@ -26,19 +27,17 @@ from videosys.models.transformers.cogvideox_transformer_3d import CogVideoXTrans
from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
from videosys.utils.utils import save_video, set_seed


class CogVideoXPABConfig(PABConfig):
    def __init__(
        self,
        steps: int = 50,
        spatial_broadcast: bool = True,
        spatial_threshold: list = [100, 850],
        spatial_range: int = 2,
    ):
        super().__init__(
            steps=steps,
            spatial_broadcast=spatial_broadcast,
            spatial_threshold=spatial_threshold,
            spatial_range=spatial_range,
@@ -114,7 +113,7 @@ class CogVideoXConfig:


class CogVideoXPipeline(VideoSysPipeline):
    _optional_components = ["text_encoder", "tokenizer"]
    _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = [
        "latents",
@@ -158,9 +157,6 @@ class CogVideoXPipeline(VideoSysPipeline):
                subfolder="scheduler",
            )

        # set eval and device
        self.set_eval_and_device(self._device, vae, transformer)

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
@@ -169,7 +165,7 @@ class CogVideoXPipeline(VideoSysPipeline):
        if config.cpu_offload:
            self.enable_model_cpu_offload()
        else:
            self.set_eval_and_device(self._device, text_encoder)
            self.set_eval_and_device(self._device, text_encoder, vae, transformer)

        # vae tiling
        if config.vae_tiling:
@@ -187,6 +183,31 @@ class CogVideoXPipeline(VideoSysPipeline):
        )
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

        # parallel
        self._set_parallel()

    def _set_seed(self, seed):
        if dist.get_world_size() == 1:
            set_seed(seed)
        else:
            set_seed(seed, self.transformer.parallel_manager.dp_rank)

    def _set_parallel(
        self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
    ):
        # init sequence parallel
        if sp_size is None:
            sp_size = dist.get_world_size()
            dp_size = 1
        else:
            assert (
                dist.get_world_size() % sp_size == 0
            ), f"world_size {dist.get_world_size()} must be divisible by sp_size"
            dp_size = dist.get_world_size() // sp_size

        # transformer parallel
        self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
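    # Worked example (assumed 8-GPU launch): the default _set_parallel() puts all
    # ranks into sequence parallelism (sp_size=8, dp_size=1), while
    # _set_parallel(sp_size=2) yields dp_size = 8 // 2 = 4 data-parallel replicas;
    # with enable_cp=True and an even sp_size, enable_parallel further halves
    # sp_size to carve out a 2-way classifier-free-guidance parallel group.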

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
@@ -474,6 +495,7 @@ class CogVideoXPipeline(VideoSysPipeline):
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        seed: int = -1,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
@@ -489,7 +511,6 @@ class CogVideoXPipeline(VideoSysPipeline):
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
        verbose=True,
    ) -> Union[VideoSysPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
@@ -572,6 +593,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
            )
        update_steps(num_inference_steps)
        self._set_seed(seed)

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -653,11 +675,10 @@ class CogVideoXPipeline(VideoSysPipeline):
        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
        # with self.progress_bar(total=num_inference_steps) as progress_bar:
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in progress_wrap(list(enumerate(timesteps))):
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

@@ -674,7 +695,6 @@ class CogVideoXPipeline(VideoSysPipeline):
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                    all_timesteps=timesteps,
                )[0]
                noise_pred = noise_pred.float()

@@ -713,8 +733,8 @@ class CogVideoXPipeline(VideoSysPipeline):
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                #     progress_bar.update()
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            video = self.decode_latents(latents)

@@ -29,13 +29,12 @@ from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.transformers.latte_transformer_3d import LatteT2V
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
from videosys.utils.utils import save_video, set_seed


class LattePABConfig(PABConfig):
    def __init__(
        self,
        steps: int = 50,
        spatial_broadcast: bool = True,
        spatial_threshold: list = [100, 800],
        spatial_range: int = 2,
@@ -62,7 +61,6 @@ class LattePABConfig(PABConfig):
        },
    ):
        super().__init__(
            steps=steps,
            spatial_broadcast=spatial_broadcast,
            spatial_threshold=spatial_threshold,
            spatial_range=spatial_range,
@@ -188,7 +186,7 @@ class LattePipeline(VideoSysPipeline):
        r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
    )  # noqa

    _optional_components = ["tokenizer", "text_encoder"]
    _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    def __init__(
@@ -238,9 +236,6 @@ class LattePipeline(VideoSysPipeline):
        if config.enable_pab:
            set_pab_manager(config.pab_config)

        # set eval and device
        self.set_eval_and_device(device, vae, transformer)

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
@@ -249,11 +244,36 @@ class LattePipeline(VideoSysPipeline):
        if config.cpu_offload:
            self.enable_model_cpu_offload()
        else:
            self.set_eval_and_device(device, text_encoder)
            self.set_eval_and_device(device, text_encoder, vae, transformer)

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        # parallel
        self._set_parallel()

    def _set_seed(self, seed):
        if dist.get_world_size() == 1:
            set_seed(seed)
        else:
            set_seed(seed, self.transformer.parallel_manager.dp_rank)

    def _set_parallel(
        self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
    ):
        # init sequence parallel
        if sp_size is None:
            sp_size = dist.get_world_size()
            dp_size = 1
        else:
            assert (
                dist.get_world_size() % sp_size == 0
            ), f"world_size {dist.get_world_size()} must be divisible by sp_size"
            dp_size = dist.get_world_size() // sp_size

        # transformer parallel
        self.transformer.enable_parallel(dp_size, sp_size, enable_cp)

    # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
    def mask_text_embeddings(self, emb, mask):
        if emb.shape[0] == 1:
@@ -658,6 +678,7 @@ class LattePipeline(VideoSysPipeline):
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: int = -1,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -745,6 +766,7 @@ class LattePipeline(VideoSysPipeline):
        width = 512
        update_steps(num_inference_steps)
        self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
        self._set_seed(seed)

        # 2. Default height and width to transformer
        if prompt is not None and isinstance(prompt, str):

@@ -2,19 +2,22 @@ import html
import json
import os
import re
import urllib.parse as ul
from typing import Optional, Tuple, Union

import ftfy
import torch
import torch.distributed as dist
from bs4 import BeautifulSoup
from diffusers.models import AutoencoderKL
from transformers import AutoTokenizer, T5EncoderModel

from videosys.core.pab_mgr import PABConfig, set_pab_manager
from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.autoencoders.autoencoder_kl_open_sora import OpenSoraVAE_V1_2
from videosys.models.transformers.open_sora_transformer_3d import STDiT3
from videosys.schedulers.scheduling_rflow_open_sora import RFLOW
from videosys.utils.utils import save_video
from videosys.utils.utils import save_video, set_seed

from .data_process import get_image_size, get_num_frames, prepare_multi_resolution_info, read_from_path

@@ -29,7 +32,6 @@ BAD_PUNCT_REGEX = re.compile(
class OpenSoraPABConfig(PABConfig):
    def __init__(
        self,
        steps: int = 50,
        spatial_broadcast: bool = True,
        spatial_threshold: list = [450, 930],
        spatial_range: int = 2,
@@ -52,7 +54,6 @@ class OpenSoraPABConfig(PABConfig):
        },
    ):
        super().__init__(
            steps=steps,
            spatial_broadcast=spatial_broadcast,
            spatial_threshold=spatial_threshold,
            spatial_range=spatial_range,
@@ -187,11 +188,7 @@ class OpenSoraPipeline(VideoSysPipeline):
    bad_punct_regex = re.compile(
        r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
    )  # noqa

    _optional_components = [
        "text_encoder",
        "tokenizer",
    ]
    _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    def __init__(
@@ -234,9 +231,6 @@ class OpenSoraPipeline(VideoSysPipeline):
        if config.enable_pab:
            set_pab_manager(config.pab_config)

        # set eval and device
        self.set_eval_and_device(device, vae, transformer)

        self.register_modules(
            text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, tokenizer=tokenizer
        )
@@ -245,7 +239,32 @@ class OpenSoraPipeline(VideoSysPipeline):
        if config.cpu_offload:
            self.enable_model_cpu_offload()
        else:
            self.set_eval_and_device(self._device, text_encoder)
            self.set_eval_and_device(self._device, vae, transformer, text_encoder)

        # parallel
        self._set_parallel()

    def _set_seed(self, seed):
        if dist.get_world_size() == 1:
            set_seed(seed)
        else:
            set_seed(seed, self.transformer.parallel_manager.dp_rank)

    def _set_parallel(
        self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
    ):
        # init sequence parallel
        if sp_size is None:
            sp_size = dist.get_world_size()
            dp_size = 1
        else:
            assert (
                dist.get_world_size() % sp_size == 0
            ), f"world_size {dist.get_world_size()} must be divisible by sp_size"
            dp_size = dist.get_world_size() // sp_size

        # transformer parallel
        self.transformer.enable_parallel(dp_size, sp_size, enable_cp)

    def get_text_embeddings(self, texts):
        text_tokens_and_mask = self.tokenizer(
@@ -283,10 +302,6 @@ class OpenSoraPipeline(VideoSysPipeline):
        return text.strip()

    def _clean_caption(self, caption):
        import urllib.parse as ul

        from bs4 import BeautifulSoup

        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
@@ -418,6 +433,7 @@ class OpenSoraPipeline(VideoSysPipeline):
        loop: int = 1,
        llm_refine: bool = False,
        negative_prompt: str = "",
        seed: int = -1,
        ms: Optional[str] = "",
        refs: Optional[str] = "",
        aes: float = 6.5,
@@ -460,10 +476,6 @@ class OpenSoraPipeline(VideoSysPipeline):
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to self.unet.config.sample_size):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -508,6 +520,8 @@ class OpenSoraPipeline(VideoSysPipeline):
        fps = 24
        image_size = get_image_size(resolution, aspect_ratio)
        num_frames = get_num_frames(num_frames)
        self._set_seed(seed)
        update_steps(self._config.num_sampling_steps)

        # == prepare batch prompts ==
        batch_prompts = [prompt]
@@ -621,7 +635,7 @@ class OpenSoraPipeline(VideoSysPipeline):
                progress=verbose,
                mask=masks,
            )
            samples = self.vae.decode(samples.to(self._dtype), num_frames=num_frames)
            samples = self.vae(samples.to(self._dtype), decode_only=True, num_frames=num_frames)
            video_clips.append(samples)

        for i in range(1, loop):

@@ -1,3 +1,8 @@
from .pipeline_open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
from .pipeline_open_sora_plan import (
    OpenSoraPlanConfig,
    OpenSoraPlanPipeline,
    OpenSoraPlanV110PABConfig,
    OpenSoraPlanV120PABConfig,
)

__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanPABConfig"]
__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig"]

@@ -20,39 +20,27 @@ import torch.distributed as dist
import tqdm
from bs4 import BeautifulSoup
from diffusers.models import AutoencoderKL, Transformer2DModel
from diffusers.schedulers import PNDMScheduler
from diffusers.schedulers import EulerAncestralDiscreteScheduler, PNDMScheduler
from diffusers.utils.torch_utils import randn_tensor
from transformers import T5EncoderModel, T5Tokenizer
from transformers import AutoTokenizer, MT5EncoderModel, T5EncoderModel, T5Tokenizer

from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.autoencoders.autoencoder_kl_open_sora_plan_v110 import (
    CausalVAEModelWrapper as CausalVAEModelWrapperV110,
)
from videosys.models.autoencoders.autoencoder_kl_open_sora_plan_v120 import (
    CausalVAEModelWrapper as CausalVAEModelWrapperV120,
)
from videosys.models.transformers.open_sora_plan_v110_transformer_3d import LatteT2V
from videosys.models.transformers.open_sora_plan_v120_transformer_3d import OpenSoraT2V
from videosys.utils.logging import logger
from videosys.utils.utils import save_video

from ...models.autoencoders.autoencoder_kl_open_sora_plan import ae_stride_config, getae_wrapper
from ...models.transformers.open_sora_plan_transformer_3d import LatteT2V

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import PixArtAlphaPipeline

        >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
        >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
        >>> # Enable memory optimizations.
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = "A small cactus with a happy face in the Sahara desert."
        >>> image = pipe(prompt).images[0]
        ```
"""
from videosys.utils.utils import save_video, set_seed


class OpenSoraPlanPABConfig(PABConfig):
class OpenSoraPlanV110PABConfig(PABConfig):
    def __init__(
        self,
        steps: int = 150,
        spatial_broadcast: bool = True,
        spatial_threshold: list = [100, 850],
        spatial_range: int = 2,
@@ -97,7 +85,6 @@ class OpenSoraPlanPABConfig(PABConfig):
        },
    ):
        super().__init__(
            steps=steps,
            spatial_broadcast=spatial_broadcast,
            spatial_threshold=spatial_threshold,
            spatial_range=spatial_range,
@@ -113,6 +100,26 @@ class OpenSoraPlanPABConfig(PABConfig):
        )


class OpenSoraPlanV120PABConfig(PABConfig):
    def __init__(
        self,
        spatial_broadcast: bool = True,
        spatial_threshold: list = [100, 850],
        spatial_range: int = 2,
        cross_broadcast: bool = True,
        cross_threshold: list = [100, 850],
        cross_range: int = 6,
    ):
        super().__init__(
            spatial_broadcast=spatial_broadcast,
            spatial_threshold=spatial_threshold,
            spatial_range=spatial_range,
            cross_broadcast=cross_broadcast,
            cross_threshold=cross_threshold,
            cross_range=cross_range,
        )


class OpenSoraPlanConfig:
    """
    This config is to instantiate a `OpenSoraPlanPipeline` class for video generation.
@@ -163,10 +170,10 @@ class OpenSoraPlanConfig:

    def __init__(
        self,
        transformer: str = "LanguageBind/Open-Sora-Plan-v1.1.0",
        ae: str = "CausalVAEModel_4x8x8",
        text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
        num_frames: int = 65,
        version: str = "v120",
        transformer_type: str = "29x480p",
        transformer: str = None,
        text_encoder: str = None,
        # ======= distributed ========
        num_gpus: int = 1,
        # ======= memory =======
@@ -175,15 +182,32 @@ class OpenSoraPlanConfig:
        tile_overlap_factor: float = 0.25,
        # ======= pab ========
        enable_pab: bool = False,
        pab_config: PABConfig = OpenSoraPlanPABConfig(),
        pab_config: PABConfig = None,
    ):
        self.pipeline_cls = OpenSoraPlanPipeline
        self.ae = ae
        self.text_encoder = text_encoder
        self.transformer = transformer
        assert num_frames in [65, 221], "num_frames must be one of [65, 221]"
        self.num_frames = num_frames
        self.version = f"{num_frames}x512x512"

        # get version
        assert version in ["v110", "v120"], f"Unknown Open-Sora-Plan version: {version}"
        self.version = version
        self.transformer_type = transformer_type

        # check transformer_type
        if version == "v110":
            assert transformer_type in ["65x512x512", "221x512x512"]
        elif version == "v120":
            assert transformer_type in ["93x480p", "93x720p", "29x480p", "29x720p"]
        self.num_frames = int(transformer_type.split("x")[0])

        # set default values according to version
        if version == "v110":
            transformer_default = "LanguageBind/Open-Sora-Plan-v1.1.0"
            text_encoder_default = "DeepFloyd/t5-v1_1-xxl"
        elif version == "v120":
            transformer_default = "LanguageBind/Open-Sora-Plan-v1.2.0"
            text_encoder_default = "google/mt5-xxl"
        self.text_encoder = text_encoder or text_encoder_default
        self.transformer = transformer or transformer_default

        # ======= distributed ========
        self.num_gpus = num_gpus
        # ======= memory ========
@@ -192,7 +216,13 @@ class OpenSoraPlanConfig:
        self.tile_overlap_factor = tile_overlap_factor
        # ======= pab ========
        self.enable_pab = enable_pab
        self.pab_config = pab_config
        if self.enable_pab and pab_config is None:
            if version == "v110":
                self.pab_config = OpenSoraPlanV110PABConfig()
            elif version == "v120":
                self.pab_config = OpenSoraPlanV120PABConfig()
        else:
            self.pab_config = pab_config


class OpenSoraPlanPipeline(VideoSysPipeline):
@@ -221,7 +251,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
        r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
    )  # noqa

    _optional_components = ["tokenizer", "text_encoder"]
    _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    def __init__(
@@ -240,26 +270,57 @@ class OpenSoraPlanPipeline(VideoSysPipeline):

        # init
        if tokenizer is None:
            tokenizer = T5Tokenizer.from_pretrained(config.text_encoder)
            if config.version == "v110":
                tokenizer = T5Tokenizer.from_pretrained(config.text_encoder)
            elif config.version == "v120":
                tokenizer = AutoTokenizer.from_pretrained(config.text_encoder)

        if text_encoder is None:
            text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=torch.float16)
            if config.version == "v110":
                text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=dtype)
            elif config.version == "v120":
                text_encoder = MT5EncoderModel.from_pretrained(
                    config.text_encoder, low_cpu_mem_usage=True, torch_dtype=dtype
                )

        if vae is None:
            vae = getae_wrapper(config.ae)(config.transformer, subfolder="vae").to(dtype=dtype)
            if config.version == "v110":
                vae = CausalVAEModelWrapperV110(config.transformer, subfolder="vae").to(dtype=dtype)
            elif config.version == "v120":
                vae = CausalVAEModelWrapperV120(config.transformer, subfolder="vae").to(dtype=dtype)

        if transformer is None:
            transformer = LatteT2V.from_pretrained(config.transformer, subfolder=config.version, torch_dtype=dtype)
            if config.version == "v110":
                transformer = LatteT2V.from_pretrained(
                    config.transformer, subfolder=config.transformer_type, torch_dtype=dtype
                )
            elif config.version == "v120":
                transformer = OpenSoraT2V.from_pretrained(
                    config.transformer, subfolder=config.transformer_type, torch_dtype=dtype
                )

        if scheduler is None:
            scheduler = PNDMScheduler()
            if config.version == "v110":
                scheduler = PNDMScheduler()
            elif config.version == "v120":
                scheduler = EulerAncestralDiscreteScheduler()

        # setting
        if config.enable_tiling:
            vae.vae.enable_tiling()
            vae.vae.tile_overlap_factor = config.tile_overlap_factor
        vae.vae_scale_factor = ae_stride_config[config.ae]
            vae.vae.tile_sample_min_size = 512
            vae.vae.tile_latent_min_size = 64
            vae.vae.tile_sample_min_size_t = 29
            vae.vae.tile_latent_min_size_t = 8
            # if low_mem:
            #     vae.vae.tile_sample_min_size = 256
            #     vae.vae.tile_latent_min_size = 32
            #     vae.vae.tile_sample_min_size_t = 29
            #     vae.vae.tile_latent_min_size_t = 8
        vae.vae_scale_factor = [4, 8, 8]
        transformer.force_images = False

        # set eval and device
        self.set_eval_and_device(device, vae, transformer)

        # pab
        if config.enable_pab:
            set_pab_manager(config.pab_config)
@@ -272,10 +333,35 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
        if config.cpu_offload:
            self.enable_model_cpu_offload()
        else:
            self.set_eval_and_device(device, text_encoder)
            self.set_eval_and_device(device, text_encoder, vae, transformer)

        # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        # parallel
        self._set_parallel()

    def _set_seed(self, seed):
        if dist.get_world_size() == 1:
            set_seed(seed)
        else:
            set_seed(seed, self.transformer.parallel_manager.dp_rank)

    def _set_parallel(
        self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
    ):
        # init sequence parallel
        if sp_size is None:
            sp_size = dist.get_world_size()
            dp_size = 1
        else:
            assert (
                dist.get_world_size() % sp_size == 0
            ), f"world_size {dist.get_world_size()} must be divisible by sp_size"
            dp_size = dist.get_world_size() // sp_size

        # transformer parallel
        self.transformer.enable_parallel(dp_size, sp_size, enable_cp)

    # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
    def mask_text_embeddings(self, emb, mask):
        if emb.shape[0] == 1:
@@ -449,6 +535,143 @@ class OpenSoraPlanPipeline(VideoSysPipeline):

        return prompt_embeds, negative_prompt_embeds

    def encode_prompt_v120(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool = True,
        negative_prompt: str = "",
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_attention_mask: Optional[torch.FloatTensor] = None,
        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
        clean_caption: bool = False,
        max_sequence_length: int = 120,
        **kwargs,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
                PixArt-Alpha, this should be "".
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Alpha, it should be the embeddings of the ""
                string.
            clean_caption (`bool`, defaults to `False`):
                If `True`, the function will preprocess and clean the provided caption before encoding.
            max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
        """
        if device is None:
            device = getattr(self, "_execution_device", None) or getattr(self, "device", None) or torch.device("cuda")

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # See Section 3.1. of the paper.
        max_length = max_sequence_length

        if prompt_embeds is None:
            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because the text encoder can only handle sequences up to"
                    f" {max_length} tokens: {removed_text}"
                )

            prompt_attention_mask = text_inputs.attention_mask
            prompt_attention_mask = prompt_attention_mask.to(device)

            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
            prompt_embeds = prompt_embeds[0]

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens = [negative_prompt] * batch_size
            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            negative_prompt_attention_mask = uncond_input.attention_mask
            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
        else:
            negative_prompt_embeds = None
            negative_prompt_attention_mask = None

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
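    # Call sketch (values illustrative): for one prompt with CFG enabled,
    #   pe, pm, ne, nm = pipe.encode_prompt_v120("a red panda", max_sequence_length=512)
    # returns embeddings plus matching attention masks; __call__ later concatenates
    # them as [negative, positive] along the batch dimension.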
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
@ -476,6 +699,8 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
callback_steps,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
@ -519,6 +744,12 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
@ -685,10 +916,13 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
seed: int = -1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@ -697,6 +931,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
mask_feature: bool = True,
|
||||
enable_temporal_attentions: bool = True,
|
||||
verbose: bool = True,
|
||||
max_sequence_length: int = 512,
|
||||
) -> Union[VideoSysPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@ -768,11 +1003,22 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
returned where the first element is a list with the generated images
|
||||
"""
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
height = 512
|
||||
width = 512
|
||||
height = self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]
|
||||
width = self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]
|
||||
num_frames = self._config.num_frames
|
||||
update_steps(num_inference_steps)
|
||||
self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_steps,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
self._set_seed(seed)
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@ -782,7 +1028,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
device = getattr(self, "_execution_device", None) or getattr(self, "device", None) or torch.device("cuda")
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@ -790,23 +1036,50 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
clean_caption=clean_caption,
|
||||
mask_feature=mask_feature,
|
||||
)
|
||||
if self._config.version == "v110":
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
clean_caption=clean_caption,
|
||||
mask_feature=mask_feature,
|
||||
)
|
||||
prompt_attention_mask = None
|
||||
negative_prompt_attention_mask = None
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt_v120(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
if prompt_attention_mask is not None and negative_prompt_attention_mask is not None:
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
if self._config.version == "v110":
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
else:
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
@ -827,15 +1100,10 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
|
||||
# 6.1 Prepare micro-conditions.
|
||||
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
||||
# if self.transformer.config.sample_size == 128:
|
||||
# resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
||||
# aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
||||
# resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
|
||||
# aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
|
||||
# added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
|
||||
for i, t in progress_wrap(list(enumerate(timesteps))):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
@ -856,9 +1124,15 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
||||
|
||||
if prompt_embeds.ndim == 3:
|
||||
if prompt_embeds is not None and prompt_embeds.ndim == 3:
|
||||
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
|
||||
|
||||
if prompt_attention_mask is not None and prompt_attention_mask.ndim == 2:
|
||||
prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l
|
||||
# prepare attention_mask.
|
||||
# b c t h w -> b t h w
|
||||
attention_mask = torch.ones_like(latent_model_input)[:, 0]
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
@ -867,6 +1141,8 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
|
||||
timestep=current_timestep,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
enable_temporal_attentions=enable_temporal_attentions,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
@@ -906,10 +1182,57 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
        return VideoSysPipelineOutput(video=video)

    def decode_latents(self, latents):
        video = self.vae.decode(latents)  # b t c h w
        # b t c h w -> b t h w c
        video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
        video = self.vae(latents.to(self.vae.vae.dtype))
        video = (
            ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
        )  # b t h w c
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        return video
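The arithmetic in `decode_latents` maps VAE output from [-1, 1] to uint8 pixels in [0, 255]; a quick sanity check of just that mapping:

```python
import torch

vals = torch.tensor([-1.0, 0.0, 1.0])  # representative decoded values
out = ((vals / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8)
print(out)  # tensor([  0, 127, 255], dtype=torch.uint8)
```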

    def save_video(self, video, output_path):
        save_video(video, output_path, fps=24)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

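A hedged usage sketch of `retrieve_timesteps` as defined above. The scheduler class is only an example; whether the custom-timesteps path is available depends on that scheduler's `set_timesteps` signature:

```python
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
# Common path: derive the schedule from a step count.
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")
assert n == 50 and len(timesteps) == 50
# Passing explicit `timesteps` instead requires this scheduler's
# set_timesteps to accept a `timesteps` kwarg; otherwise a ValueError
# is raised, per the check above.
```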
3
videosys/pipelines/vchitect/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .pipeline_vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline

__all__ = ["VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"]
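The new `__init__.py` re-exports the Vchitect pipeline and its configs at the package root; a hedged import sketch (the config-then-pipeline construction is assumed from the usual VideoSys pattern, since the pipeline's own diff is suppressed below):

```python
from videosys.pipelines.vchitect import VchitectConfig, VchitectXLPipeline

config = VchitectConfig()              # hypothetical default construction
pipeline = VchitectXLPipeline(config)  # assumed config-first constructor
```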
1057
videosys/pipelines/vchitect/pipeline_vchitect.py
Normal file
File diff suppressed because it is too large
0
videosys/utils/__init__.py
Normal file
12
videosys/utils/test.py
Normal file
@@ -0,0 +1,12 @@
import functools

import torch


def empty_cache(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        torch.cuda.empty_cache()
        return func(*args, **kwargs)

    return wrapper
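A short usage sketch of the `empty_cache` decorator added above: it releases the CUDA allocator's cached blocks back to the driver before each decorated call (the decorated function is illustrative):

```python
@empty_cache
def generate(pipeline, prompt):
    # torch.cuda.empty_cache() has already run by the time we get here,
    # freeing cached blocks ahead of the pipeline's large allocations.
    return pipeline.generate(prompt)
```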
@@ -16,7 +16,17 @@ def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
        p.requires_grad = flag


def set_seed(seed):
def set_seed(seed, dp_rank=None):
    if seed == -1:
        seed = random.randint(0, 1000000)

    if dp_rank is not None:
        seed = torch.tensor(seed, dtype=torch.int64).cuda()
        if dist.get_world_size() > 1:
            dist.broadcast(seed, 0)
        seed = seed + dp_rank

    seed = int(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
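For the new `dp_rank` parameter above: rank 0's seed is broadcast so every process agrees on one base value, then each data-parallel rank adds its own offset so different ranks draw different noise while staying reproducible. The same pattern in isolation (a initialized distributed process group is assumed):

```python
import torch
import torch.distributed as dist


def per_rank_seed(base_seed: int, dp_rank: int) -> int:
    seed = torch.tensor(base_seed, dtype=torch.int64).cuda()
    if dist.get_world_size() > 1:
        dist.broadcast(seed, 0)  # everyone adopts rank 0's base seed
    return int(seed) + dp_rank   # distinct, reproducible stream per rank
```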