diff --git a/eval/teacache/README.md b/eval/teacache/README.md
new file mode 100644
index 0000000..2fc8e65
--- /dev/null
+++ b/eval/teacache/README.md
@@ -0,0 +1,30 @@
+# Evaluation of TeaCache
+
+We first generate videos according to VBench's prompts.
+
+Then we compute VBench, PSNR, LPIPS, and SSIM scores on the generated videos.
+
+1. Generate videos
+```
+cd eval/teacache
+python experiments/latte.py
+python experiments/opensora.py
+python experiments/opensora_plan.py
+```
+
+2. Calculate VBench score
+```
+# vbench is calculated independently
+# get scores for all metrics
+python vbench/run_vbench.py --video_path aaa --save_path bbb
+# calculate final score
+python vbench/cal_vbench.py --score_dir bbb
+```
+
+3. Calculate other metrics
+```
+# these metrics are computed against the original model's outputs
+# gt videos are generated by the original (base) model
+# generated videos are produced by our method
+python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
+```
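+
+`common_metrics/batch_eval.py` can also compare several generated variants against the base model's outputs in one run; the sample directories below assume the default output paths of the experiment scripts:
+```
+python common_metrics/batch_eval.py \
+    --gt_video_dirs ./samples/latte_base \
+    --gen_video_dirs ./samples/latte_teacache_slow ./samples/latte_teacache_fast
+```
+The aggregated scores are also written to `./batch_eval.txt`.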
diff --git a/eval/teacache/common_metrics/batch_eval.py b/eval/teacache/common_metrics/batch_eval.py
new file mode 100644
index 0000000..1696235
--- /dev/null
+++ b/eval/teacache/common_metrics/batch_eval.py
@@ -0,0 +1,205 @@
+import argparse
+import os
+
+import imageio
+import torch
+import torchvision.transforms.functional as F
+import tqdm
+from calculate_lpips import calculate_lpips
+from calculate_psnr import calculate_psnr
+from calculate_ssim import calculate_ssim
+
+
+def load_video(video_path):
+ """
+ Load a video from the given path and convert it to a PyTorch tensor.
+ """
+ # Read the video using imageio
+ reader = imageio.get_reader(video_path, "ffmpeg")
+
+ # Extract frames and convert to a list of tensors
+ frames = []
+ for frame in reader:
+ # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
+ frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
+ frames.append(frame_tensor)
+
+ # Stack the list of tensors into a single tensor with shape (T, C, H, W)
+ video_tensor = torch.stack(frames)
+
+ return video_tensor
+
+
+def resize_video(video, target_height, target_width):
+ resized_frames = []
+ for frame in video:
+ resized_frame = F.resize(frame, [target_height, target_width])
+ resized_frames.append(resized_frame)
+ return torch.stack(resized_frames)
+
+
+def resize_gt_video(gt_video, gen_video):
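+    """Resize (if needed) and center-crop the ground-truth video so it matches the generated video's frame count and spatial size."""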
+ gen_video_shape = gen_video.shape
+ T_gen, _, H_gen, W_gen = gen_video_shape
+ T_eval, _, H_eval, W_eval = gt_video.shape
+
+ if T_eval < T_gen:
+        raise ValueError(f"GT video has fewer frames ({T_eval}) than the generated video ({T_gen}).")
+
+ if H_eval < H_gen or W_eval < W_gen:
+ # Resize the video maintaining the aspect ratio
+ resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
+ resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
+ gt_video = resize_video(gt_video, resize_height, resize_width)
+ # Recalculate the dimensions
+ T_eval, _, H_eval, W_eval = gt_video.shape
+
+ # Center crop
+ start_h = (H_eval - H_gen) // 2
+ start_w = (W_eval - W_gen) // 2
+ cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
+
+ return cropped_video
+
+
+def get_video_ids(gt_video_dirs, gen_video_dirs):
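+    """Collect video ids (mp4 basenames) from the first gt directory and check that every gt/gen directory contains the same ids."""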
+ video_ids = []
+ for f in os.listdir(gt_video_dirs[0]):
+        if f.endswith(".mp4"):
+            video_ids.append(f.replace(".mp4", ""))
+ video_ids.sort()
+
+ for video_dir in gt_video_dirs + gen_video_dirs:
+ tmp_video_ids = []
+ for f in os.listdir(video_dir):
+            if f.endswith(".mp4"):
+                tmp_video_ids.append(f.replace(".mp4", ""))
+ tmp_video_ids.sort()
+ if tmp_video_ids != video_ids:
+ raise ValueError(f"Video IDs in {video_dir} are different.")
+ return video_ids
+
+
+def get_videos(video_ids, gt_video_dirs, gen_video_dirs):
+ gt_videos = {}
+ generated_videos = {}
+
+ for gt_video_dir in gt_video_dirs:
+ tmp_gt_videos_tensor = []
+ for video_id in video_ids:
+ gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4"))
+ tmp_gt_videos_tensor.append(gt_video)
+ gt_videos[gt_video_dir] = tmp_gt_videos_tensor
+
+ for generated_video_dir in gen_video_dirs:
+ tmp_generated_videos_tensor = []
+ for video_id in video_ids:
+ generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4"))
+ tmp_generated_videos_tensor.append(generated_video)
+ generated_videos[generated_video_dir] = tmp_generated_videos_tensor
+
+ return gt_videos, generated_videos
+
+
+def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs):
+ out_str = ""
+
+ for gt_video_dir in gt_video_dirs:
+ for generated_video_dir in gen_video_dirs:
+ if gt_video_dir == generated_video_dir:
+ continue
+ lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len(
+ lpips_results[gt_video_dir][generated_video_dir]
+ )
+ psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len(
+ psnr_results[gt_video_dir][generated_video_dir]
+ )
+ ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len(
+ ssim_results[gt_video_dir][generated_video_dir]
+ )
+ out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}"
+
+ return out_str
+
+
+def main(args):
+ device = "cuda"
+ gt_video_dirs = args.gt_video_dirs
+ gen_video_dirs = args.gen_video_dirs
+
+ video_ids = get_video_ids(gt_video_dirs, gen_video_dirs)
+    print(f"Found {len(video_ids)} videos")
+
+ prompt_interval = 1
+ batch_size = 8
+ calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
+
+ lpips_results = {}
+ psnr_results = {}
+ ssim_results = {}
+ for gt_video_dir in gt_video_dirs:
+ lpips_results[gt_video_dir] = {}
+ psnr_results[gt_video_dir] = {}
+ ssim_results[gt_video_dir] = {}
+ for generated_video_dir in gen_video_dirs:
+ lpips_results[gt_video_dir][generated_video_dir] = []
+ psnr_results[gt_video_dir][generated_video_dir] = []
+ ssim_results[gt_video_dir][generated_video_dir] = []
+
+ total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
+
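+    # evaluate in batches of batch_size ids; per-batch averages are accumulated and averaged again over all batches in print_results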
+ for idx in tqdm.tqdm(range(total_len)):
+ video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size]
+ gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs)
+
+ for gt_video_dir, gt_videos_tensor in gt_videos.items():
+ for generated_video_dir, generated_videos_tensor in generated_videos.items():
+ if gt_video_dir == generated_video_dir:
+ continue
+
+ if not isinstance(gt_videos_tensor, torch.Tensor):
+ for i in range(len(gt_videos_tensor)):
+ gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0])
+ gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
+
+ generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
+
+ if calculate_lpips_flag:
+ result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
+ result = result["value"].values()
+ result = float(sum(result) / len(result))
+ lpips_results[gt_video_dir][generated_video_dir].append(result)
+
+ if calculate_psnr_flag:
+ result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
+ result = result["value"].values()
+ result = float(sum(result) / len(result))
+ psnr_results[gt_video_dir][generated_video_dir].append(result)
+
+ if calculate_ssim_flag:
+ result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
+ result = result["value"].values()
+ result = float(sum(result) / len(result))
+ ssim_results[gt_video_dir][generated_video_dir].append(result)
+
+ if (idx + 1) % prompt_interval == 0:
+ out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
+            print(f"Processed {idx + 1} / {total_len} batches. {out_str}")
+
+ out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
+
+ # save
+    with open("./batch_eval.txt", "w+") as f:
+ f.write(out_str)
+
+ print(f"Processed all videos. {out_str}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--gt_video_dirs", type=str, nargs="+")
+ parser.add_argument("--gen_video_dirs", type=str, nargs="+")
+
+ args = parser.parse_args()
+
+ main(args)
diff --git a/eval/teacache/experiments/latte.py b/eval/teacache/experiments/latte.py
index 962da93..fbd6498 100644
--- a/eval/teacache/experiments/latte.py
+++ b/eval/teacache/experiments/latte.py
@@ -5,12 +5,7 @@ from einops import rearrange, repeat
from torch import nn
import numpy as np
from typing import Any, Dict, Optional, Tuple
-from videosys.core.parallel_mgr import (
- enable_sequence_parallel,
- get_cfg_parallel_size,
- get_data_parallel_group,
- get_sequence_parallel_group,
-)
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
def teacache_forward(
self,
@@ -67,7 +62,7 @@ def teacache_forward(
"""
# 0. Split batch for data parallelism
- if get_cfg_parallel_size() > 1:
+ if self.parallel_manager.cp_size > 1:
(
hidden_states,
timestep,
@@ -77,7 +72,7 @@ def teacache_forward(
attention_mask,
encoder_attention_mask,
) = batch_func(
- partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+ partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
timestep,
encoder_hidden_states,
@@ -193,14 +188,14 @@ def teacache_forward(
if not should_calc:
hidden_states += self.previous_residual
else:
- if enable_sequence_parallel():
- set_temporal_pad(frame + use_image_num)
- set_spatial_pad(num_patches)
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+ set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
- self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
@@ -323,14 +318,14 @@ def teacache_forward(
).contiguous()
self.previous_residual = hidden_states - hidden_states_origin
else:
- if enable_sequence_parallel():
- set_temporal_pad(frame + use_image_num)
- set_spatial_pad(num_patches)
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+ set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
- self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
@@ -451,7 +446,8 @@ def teacache_forward(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
- if enable_sequence_parallel():
+
+ if self.parallel_manager.sp_size > 1:
if self.enable_teacache:
if should_calc:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
@@ -488,8 +484,8 @@ def teacache_forward(
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
- if get_cfg_parallel_size() > 1:
- output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+ if self.parallel_manager.cp_size > 1:
+ output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
@@ -497,10 +493,6 @@ def teacache_forward(
return Transformer3DModelOutput(sample=output)
-def eval_base(prompt_list):
- config = LatteConfig()
- engine = VideoSysEngine(config)
- generate_func(engine, prompt_list, "./samples/latte_base", loop=5)
def eval_teacache_slow(prompt_list):
config = LatteConfig()
@@ -523,10 +515,18 @@ def eval_teacache_fast(prompt_list):
engine.driver_worker.transformer.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/latte_teacache_fast", loop=5)
+
+
+
+def eval_base(prompt_list):
+ config = LatteConfig()
+ engine = VideoSysEngine(config)
+ generate_func(engine, prompt_list, "./samples/latte_base", loop=5)
+
+
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
- # eval_base(prompt_list)
+ eval_base(prompt_list)
eval_teacache_slow(prompt_list)
- # eval_teacache_fast(prompt_list)
-
+ eval_teacache_fast(prompt_list)
\ No newline at end of file
diff --git a/eval/teacache/experiments/opensora.py b/eval/teacache/experiments/opensora.py
index a001c5a..3ef72fd 100644
--- a/eval/teacache/experiments/opensora.py
+++ b/eval/teacache/experiments/opensora.py
@@ -3,23 +3,23 @@ from videosys import OpenSoraConfig, VideoSysEngine
import torch
from einops import rearrange
from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint
-from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_temporal_pad, set_spatial_pad, set_temporal_pad, split_sequence
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
import numpy as np
from videosys.utils.utils import batch_func
-from videosys.core.parallel_mgr import (
- enable_sequence_parallel,
- get_cfg_parallel_size,
- get_data_parallel_group,
- get_sequence_parallel_group,
-)
+from functools import partial
def teacache_forward(
self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
):
# === Split batch ===
- if get_cfg_parallel_size() > 1:
+ if self.parallel_manager.cp_size > 1:
x, timestep, y, x_mask, mask = batch_func(
- partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
+ partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
+ x,
+ timestep,
+ y,
+ x_mask,
+ mask,
)
dtype = self.x_embedder.proj.weight.dtype
@@ -60,7 +60,7 @@ def teacache_forward(
# === get x embed ===
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
- x = x + pos_emb
+ x = x + pos_emb
if self.enable_teacache:
inp = x.clone()
@@ -84,7 +84,6 @@ def teacache_forward(
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
-
# === blocks ===
if self.enable_teacache:
if not should_calc:
@@ -92,16 +91,15 @@ def teacache_forward(
x += self.previous_residual
else:
# shard over the sequence dim if sp is enabled
- if enable_sequence_parallel():
- set_temporal_pad(T)
- set_spatial_pad(S)
- x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", T, self.parallel_manager.sp_group)
+ set_pad("spatial", S, self.parallel_manager.sp_group)
+ x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
T = x.shape[1]
x_mask_org = x_mask
x_mask = split_sequence(
- x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
-
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
origin_x = x.clone().detach()
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
@@ -135,16 +133,17 @@ def teacache_forward(
self.previous_residual = x - origin_x
else:
# shard over the sequence dim if sp is enabled
- if enable_sequence_parallel():
- set_temporal_pad(T)
- set_spatial_pad(S)
- x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", T, self.parallel_manager.sp_group)
+ set_pad("spatial", S, self.parallel_manager.sp_group)
+ x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
T = x.shape[1]
x_mask_org = x_mask
x_mask = split_sequence(
- x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
+
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
x = auto_grad_checkpoint(
spatial_block,
@@ -174,25 +173,23 @@ def teacache_forward(
all_timesteps=all_timesteps,
)
- if enable_sequence_parallel():
+ if self.parallel_manager.sp_size > 1:
if self.enable_teacache:
if should_calc:
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S)
- x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
- self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
+ x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
+ self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
T, S = x.shape[1], x.shape[2]
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S)
x_mask = x_mask_org
else:
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
- x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
+ x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
T, S = x.shape[1], x.shape[2]
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
x_mask = x_mask_org
-
-
# === final layer ===
x = self.final_layer(x, t, x_mask, t0, T, S)
x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
@@ -201,12 +198,11 @@ def teacache_forward(
x = x.to(torch.float32)
# === Gather Output ===
- if get_cfg_parallel_size() > 1:
- x = gather_sequence(x, get_data_parallel_group(), dim=0)
+ if self.parallel_manager.cp_size > 1:
+ x = gather_sequence(x, self.parallel_manager.cp_group, dim=0)
return x
-
def eval_base(prompt_list):
config = OpenSoraConfig()
engine = VideoSysEngine(config)
@@ -235,9 +231,9 @@ def eval_teacache_fast(prompt_list):
generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5)
-
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
- # eval_base(prompt_list)
+ eval_base(prompt_list)
eval_teacache_slow(prompt_list)
- # eval_teacache_fast(prompt_list)
+ eval_teacache_fast(prompt_list)
+
\ No newline at end of file
diff --git a/eval/teacache/experiments/opensora_plan.py b/eval/teacache/experiments/opensora_plan.py
index 377f307..b056687 100644
--- a/eval/teacache/experiments/opensora_plan.py
+++ b/eval/teacache/experiments/opensora_plan.py
@@ -5,12 +5,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from typing import Any, Dict, Optional, Tuple
-from videosys.core.parallel_mgr import (
- enable_sequence_parallel,
- get_cfg_parallel_group,
- get_cfg_parallel_size,
- get_sequence_parallel_group,
-)
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
def teacache_forward(
self,
@@ -66,7 +61,7 @@ def teacache_forward(
`tuple` where the first element is the sample tensor.
"""
# 0. Split batch
- if get_cfg_parallel_size() > 1:
+ if self.parallel_manager.cp_size > 1:
(
hidden_states,
timestep,
@@ -75,7 +70,7 @@ def teacache_forward(
attention_mask,
encoder_attention_mask,
) = batch_func(
- partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+ partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
timestep,
encoder_hidden_states,
@@ -204,20 +199,20 @@ def teacache_forward(
if not should_calc:
hidden_states += self.previous_residual
else:
- if enable_sequence_parallel():
- set_temporal_pad(frame + use_image_num)
- set_spatial_pad(num_patches)
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+ set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
temp_pos_embed = split_sequence(
- self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
- ori_hidden_states = hidden_states.clone()
+ ori_hidden_states = hidden_states.clone().detach()
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
@@ -359,20 +354,20 @@ def teacache_forward(
).contiguous()
self.previous_residual = hidden_states - ori_hidden_states
else:
- if enable_sequence_parallel():
- set_temporal_pad(frame + use_image_num)
- set_spatial_pad(num_patches)
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+ set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
+ self.previous_residual = self.split_from_second_dim(self.previous_residual, input_batch_size) if self.previous_residual is not None else None
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
temp_pos_embed = split_sequence(
- self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
-
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
@@ -512,8 +507,8 @@ def teacache_forward(
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
-
- if enable_sequence_parallel():
+
+ if self.parallel_manager.sp_size > 1:
if self.enable_teacache:
if should_calc:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
@@ -550,8 +545,8 @@ def teacache_forward(
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
- if get_cfg_parallel_size() > 1:
- output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+ if self.parallel_manager.cp_size > 1:
+ output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
@@ -559,14 +554,8 @@ def teacache_forward(
return Transformer3DModelOutput(sample=output)
-
-def eval_base(prompt_list):
- config = OpenSoraPlanConfig()
- engine = VideoSysEngine(config)
- generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
-
def eval_teacache_slow(prompt_list):
- config = OpenSoraPlanConfig()
+ config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.1
@@ -577,7 +566,7 @@ def eval_teacache_slow(prompt_list):
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5)
def eval_teacache_fast(prompt_list):
- config = OpenSoraPlanConfig()
+ config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.2
@@ -587,8 +576,16 @@ def eval_teacache_fast(prompt_list):
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5)
+
+def eval_base(prompt_list):
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
+ engine = VideoSysEngine(config)
+ generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
+
+
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
- # eval_base(prompt_list)
+ eval_base(prompt_list)
eval_teacache_slow(prompt_list)
- # eval_teacache_fast(prompt_list)
+ eval_teacache_fast(prompt_list)
+
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index ea6d973..3b8a98d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,7 @@
+accelerate>0.17.0
+bs4
click
-colossalai
+colossalai==0.4.0
diffusers==0.30.0
einops
fabric
@@ -16,7 +18,9 @@ pydantic
ray
rich
safetensors
+sentencepiece
timm
torch>=1.13
tqdm
-transformers
+peft==0.13.2
+transformers==4.39.3
diff --git a/setup.py b/setup.py
index 50b9537..9a2ed05 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,9 @@
from typing import List
from setuptools import find_packages, setup
+from setuptools.command.develop import develop
+from setuptools.command.egg_info import egg_info
+from setuptools.command.install import install
def fetch_requirements(path) -> List[str]:
@@ -14,7 +17,9 @@ def fetch_requirements(path) -> List[str]:
The lines in the requirements file.
"""
with open(path, "r") as fd:
- return [r.strip() for r in fd.readlines()]
+ requirements = [r.strip() for r in fd.readlines()]
+ # requirements.remove("colossalai")
+ return requirements
def fetch_readme() -> str:
@@ -28,6 +33,28 @@ def fetch_readme() -> str:
return f.read()
+def custom_install():
+ return ["pip", "install", "colossalai", "--no-deps"]
+
+
+class CustomInstallCommand(install):
+ def run(self):
+ install.run(self)
+ self.spawn(custom_install())
+
+
+class CustomDevelopCommand(develop):
+ def run(self):
+ develop.run(self)
+ self.spawn(custom_install())
+
+
+class CustomEggInfoCommand(egg_info):
+ def run(self):
+ egg_info.run(self)
+ self.spawn(custom_install())
+
+
setup(
name="videosys",
version="2.0.0",
@@ -39,12 +66,17 @@ setup(
"*.egg-info",
)
),
- description="VideoSys",
+ description="TeaCache",
long_description=fetch_readme(),
long_description_content_type="text/markdown",
license="Apache Software License 2.0",
install_requires=fetch_requirements("requirements.txt"),
- python_requires=">=3.6",
+ python_requires=">=3.7",
+ # cmdclass={
+ # "install": CustomInstallCommand,
+ # "develop": CustomDevelopCommand,
+ # "egg_info": CustomEggInfoCommand,
+ # },
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
diff --git a/videosys/__init__.py b/videosys/__init__.py
index 859fb7c..6c539c0 100644
--- a/videosys/__init__.py
+++ b/videosys/__init__.py
@@ -3,13 +3,20 @@ from .core.parallel_mgr import initialize
from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
-from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
+from .pipelines.open_sora_plan import (
+ OpenSoraPlanConfig,
+ OpenSoraPlanPipeline,
+ OpenSoraPlanV110PABConfig,
+ OpenSoraPlanV120PABConfig,
+)
+from .pipelines.vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline
__all__ = [
"initialize",
"VideoSysEngine",
"LattePipeline", "LatteConfig", "LattePABConfig",
- "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
+ "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig",
"OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
- "CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"
+ "CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig",
+ "VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"
] # fmt: skip
diff --git a/videosys/core/comm.py b/videosys/core/comm.py
index 175fba5..75c242b 100644
--- a/videosys/core/comm.py
+++ b/videosys/core/comm.py
@@ -7,8 +7,6 @@ from einops import rearrange
from torch import Tensor
from torch.distributed import ProcessGroup
-from videosys.core.parallel_mgr import get_sequence_parallel_size
-
# ======================================================
# Model
# ======================================================
@@ -369,30 +367,18 @@ def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
# Pad
# ==============================
-SPTIAL_PAD = 0
-TEMPORAL_PAD = 0
+PAD_DICT = {}
-def set_spatial_pad(dim_size: int):
- sp_size = get_sequence_parallel_size()
+def set_pad(name: str, dim_size: int, parallel_group: dist.ProcessGroup):
+ sp_size = dist.get_world_size(parallel_group)
pad = (sp_size - (dim_size % sp_size)) % sp_size
- global SPTIAL_PAD
- SPTIAL_PAD = pad
+ global PAD_DICT
+ PAD_DICT[name] = pad
-def get_spatial_pad() -> int:
- return SPTIAL_PAD
-
-
-def set_temporal_pad(dim_size: int):
- sp_size = get_sequence_parallel_size()
- pad = (sp_size - (dim_size % sp_size)) % sp_size
- global TEMPORAL_PAD
- TEMPORAL_PAD = pad
-
-
-def get_temporal_pad() -> int:
- return TEMPORAL_PAD
+def get_pad(name) -> int:
+ return PAD_DICT[name]
def all_to_all_with_pad(
@@ -418,3 +404,17 @@ def all_to_all_with_pad(
input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
return input_
+
+
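+# shard a (B*N, ...) tensor along its second (N) dimension across the sequence-parallel group, using the registered "temporal" pad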
+def split_from_second_dim(x, batch_size, parallel_group):
+ x = x.view(batch_size, -1, *x.shape[1:])
+ x = split_sequence(x, parallel_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
+ x = x.reshape(-1, *x.shape[2:])
+ return x
+
+
+def gather_from_second_dim(x, batch_size, parallel_group):
+ x = x.view(batch_size, -1, *x.shape[1:])
+ x = gather_sequence(x, parallel_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
+ x = x.reshape(-1, *x.shape[2:])
+ return x
diff --git a/videosys/core/engine.py b/videosys/core/engine.py
index 6d1868f..086b193 100644
--- a/videosys/core/engine.py
+++ b/videosys/core/engine.py
@@ -66,7 +66,7 @@ class VideoSysEngine:
# TODO: add more options here for pipeline, or wrap all options into config
def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
- videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method, seed=42)
+ videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method)
pipeline = pipeline_cls(self.config)
return pipeline
diff --git a/videosys/core/pab_mgr.py b/videosys/core/pab_mgr.py
index ecd387b..a582105 100644
--- a/videosys/core/pab_mgr.py
+++ b/videosys/core/pab_mgr.py
@@ -6,7 +6,6 @@ PAB_MANAGER = None
class PABConfig:
def __init__(
self,
- steps: int,
cross_broadcast: bool = False,
cross_threshold: list = None,
cross_range: int = None,
@@ -20,7 +19,7 @@ class PABConfig:
mlp_spatial_broadcast_config: dict = None,
mlp_temporal_broadcast_config: dict = None,
):
- self.steps = steps
+ self.steps = None
self.cross_broadcast = cross_broadcast
self.cross_threshold = cross_threshold
@@ -45,7 +44,7 @@ class PABManager:
def __init__(self, config: PABConfig):
self.config: PABConfig = config
- init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
+        init_prompt = "Init Pyramid Attention Broadcast."
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
@@ -78,7 +77,7 @@ class PABManager:
count = (count + 1) % self.config.steps
return flag, count
- def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
+ def if_broadcast_spatial(self, timestep: int, count: int):
if (
self.config.spatial_broadcast
and (timestep is not None)
@@ -213,10 +212,10 @@ def if_broadcast_temporal(timestep: int, count: int):
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
-def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
+def if_broadcast_spatial(timestep: int, count: int):
if not enable_pab():
return False, count
- return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
+ return PAB_MANAGER.if_broadcast_spatial(timestep, count)
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
diff --git a/videosys/core/parallel_mgr.py b/videosys/core/parallel_mgr.py
index fbb9ccf..2036930 100644
--- a/videosys/core/parallel_mgr.py
+++ b/videosys/core/parallel_mgr.py
@@ -1,14 +1,9 @@
-from typing import Optional
-
import torch
import torch.distributed as dist
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from torch.distributed import ProcessGroup
from videosys.utils.logging import init_dist_logger, logger
-from videosys.utils.utils import set_seed
-
-PARALLEL_MANAGER = None
class ParallelManager(ProcessGroupMesh):
@@ -21,71 +16,28 @@ class ParallelManager(ProcessGroupMesh):
self.dp_rank = dist.get_rank(self.dp_group)
self.cp_size = cp_size
- self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
- self.cp_rank = dist.get_rank(self.cp_group)
+ if cp_size > 1:
+ self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
+ self.cp_rank = dist.get_rank(self.cp_group)
+ else:
+ self.cp_group = None
+ self.cp_rank = None
self.sp_size = sp_size
- self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
- self.sp_rank = dist.get_rank(self.sp_group)
- self.enable_sp = sp_size > 1
+ if sp_size > 1:
+ self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
+ self.sp_rank = dist.get_rank(self.sp_group)
+ else:
+ self.sp_group = None
+ self.sp_rank = None
logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")
-def set_parallel_manager(dp_size, cp_size, sp_size):
- global PARALLEL_MANAGER
- PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)
-
-
-def get_data_parallel_group():
- return PARALLEL_MANAGER.dp_group
-
-
-def get_data_parallel_size():
- return PARALLEL_MANAGER.dp_size
-
-
-def get_data_parallel_rank():
- return PARALLEL_MANAGER.dp_rank
-
-
-def get_sequence_parallel_group():
- return PARALLEL_MANAGER.sp_group
-
-
-def get_sequence_parallel_size():
- return PARALLEL_MANAGER.sp_size
-
-
-def get_sequence_parallel_rank():
- return PARALLEL_MANAGER.sp_rank
-
-
-def get_cfg_parallel_group():
- return PARALLEL_MANAGER.cp_group
-
-
-def get_cfg_parallel_size():
- return PARALLEL_MANAGER.cp_size
-
-
-def enable_sequence_parallel():
- if PARALLEL_MANAGER is None:
- return False
- return PARALLEL_MANAGER.enable_sp
-
-
-def get_parallel_manager():
- return PARALLEL_MANAGER
-
-
def initialize(
rank=0,
world_size=1,
init_method=None,
- seed: Optional[int] = None,
- sp_size: Optional[int] = None,
- enable_cp: bool = False,
):
if not dist.is_initialized():
try:
@@ -97,24 +49,3 @@ def initialize(
init_dist_logger()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
-
- # init sequence parallel
- if sp_size is None:
- sp_size = dist.get_world_size()
- dp_size = 1
- else:
- assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size"
- dp_size = dist.get_world_size() // sp_size
-
- # update cfg parallel
- # NOTE: enable cp parallel will be slower. disable it for now.
- if False and enable_cp and sp_size % 2 == 0:
- sp_size = sp_size // 2
- cp_size = 2
- else:
- cp_size = 1
-
- set_parallel_manager(dp_size, cp_size, sp_size)
-
- if seed is not None:
- set_seed(seed + get_data_parallel_rank())
diff --git a/videosys/core/pipeline.py b/videosys/core/pipeline.py
index 75b79d3..0aafb96 100644
--- a/videosys/core/pipeline.py
+++ b/videosys/core/pipeline.py
@@ -13,9 +13,10 @@ class VideoSysPipeline(DiffusionPipeline):
@staticmethod
def set_eval_and_device(device: torch.device, *modules):
- for module in modules:
- module.eval()
- module.to(device)
+ modules = list(modules)
+ for i in range(len(modules)):
+ modules[i] = modules[i].eval()
+ modules[i] = modules[i].to(device)
@abstractmethod
def generate(self, *args, **kwargs):
diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora.py b/videosys/models/autoencoders/autoencoder_kl_open_sora.py
index d073fe4..9a69d80 100644
--- a/videosys/models/autoencoders/autoencoder_kl_open_sora.py
+++ b/videosys/models/autoencoders/autoencoder_kl_open_sora.py
@@ -694,7 +694,10 @@ class VideoAutoencoderPipeline(PreTrainedModel):
else:
return x
- def forward(self, x):
+ def forward(self, x, decode_only=False, **kwargs):
+ if decode_only:
+ return self.decode(x, **kwargs)
+
assert self.cal_loss, "This method is only available when cal_loss is True"
z, posterior, x_z = self.encode(x)
x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v110.py b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v110.py
new file mode 100644
index 0000000..30338fc
--- /dev/null
+++ b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v110.py
@@ -0,0 +1,1643 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+import glob
+import os
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers import ConfigMixin, ModelMixin
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import logging
+from einops import rearrange
+from torch import nn
+
+logging.set_verbosity_error()
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+def tensor_to_video(x):
+ x = x.detach().cpu()
+ x = torch.clamp(x, -1, 1)
+ x = (x + 1) / 2
+    x = x.permute(1, 0, 2, 3).float().numpy()  # c t h w -> t c h w
+ x = (255 * x).astype(np.uint8)
+ return x
+
+
+def nonlinearity(x):
+ return x * torch.sigmoid(x)
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def resolve_str_to_obj(str_val, append=True):
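+    # resolve a layer/block name such as "ResnetBlock3D" to the class defined in this module; the append flag is unused here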
+ return globals()[str_val]
+
+
+class VideoBaseAE_PL(ModelMixin, ConfigMixin):
+ config_name = "config.json"
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ def encode(self, x: torch.Tensor, *args, **kwargs):
+ pass
+
+ def decode(self, encoding: torch.Tensor, *args, **kwargs):
+ pass
+
+ @property
+ def num_training_steps(self) -> int:
+ """Total training steps inferred from datamodule and devices."""
+ if self.trainer.max_steps:
+ return self.trainer.max_steps
+
+ limit_batches = self.trainer.limit_train_batches
+ batches = len(self.train_dataloader())
+ batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)
+
+ num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
+ if self.trainer.tpu_cores:
+ num_devices = max(num_devices, self.trainer.tpu_cores)
+
+ effective_accum = self.trainer.accumulate_grad_batches * num_devices
+ return (batches // effective_accum) * self.trainer.max_epochs
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt"))
+ if ckpt_files:
+ # Adapt to PyTorch Lightning
+ last_ckpt_file = ckpt_files[-1]
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ model = cls.from_config(config_file)
+ # print("init from {}".format(last_ckpt_file))
+ model.init_from_ckpt(last_ckpt_file)
+ return model
+ else:
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ z_channels: int,
+ hidden_size: int,
+ hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
+ attn_resolutions: Tuple[int] = (16,),
+ conv_in: str = "Conv2d",
+ conv_out: str = "CasualConv3d",
+ attention: str = "AttnBlock",
+ resnet_blocks: Tuple[str] = (
+ "ResnetBlock2D",
+ "ResnetBlock2D",
+ "ResnetBlock2D",
+ "ResnetBlock3D",
+ ),
+ spatial_downsample: Tuple[str] = (
+ "Downsample",
+ "Downsample",
+ "Downsample",
+ "",
+ ),
+ temporal_downsample: Tuple[str] = ("", "", "TimeDownsampleRes2x", ""),
+ mid_resnet: str = "ResnetBlock3D",
+ dropout: float = 0.0,
+ resolution: int = 256,
+ num_res_blocks: int = 2,
+ double_z: bool = True,
+ ) -> None:
+ super().__init__()
+        assert len(resnet_blocks) == len(hidden_size_mult), f"resnet_blocks {resnet_blocks} and hidden_size_mult {hidden_size_mult} must have equal length"
+ # ---- Config ----
+ self.num_resolutions = len(hidden_size_mult)
+ self.resolution = resolution
+ self.num_res_blocks = num_res_blocks
+
+ # ---- In ----
+ self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1)
+
+ # ---- Downsample ----
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(hidden_size_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = hidden_size * in_ch_mult[i_level]
+ block_out = hidden_size * hidden_size_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ resolve_str_to_obj(resnet_blocks[i_level])(
+ in_channels=block_in,
+ out_channels=block_out,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(resolve_str_to_obj(attention)(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if spatial_downsample[i_level]:
+ down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in)
+ curr_res = curr_res // 2
+ if temporal_downsample[i_level]:
+ down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in)
+ self.down.append(down)
+
+ # ---- Mid ----
+ self.mid = nn.Module()
+ self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
+ self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+ # ---- Out ----
+ self.norm_out = Normalize(block_in)
+ self.conv_out = resolve_str_to_obj(conv_out)(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1])
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if hasattr(self.down[i_level], "downsample"):
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ if hasattr(self.down[i_level], "time_downsample"):
+ hs_down = self.down[i_level].time_downsample(hs[-1])
+ hs.append(hs_down)
+
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ z_channels: int,
+ hidden_size: int,
+ hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
+ attn_resolutions: Tuple[int] = (16,),
+ conv_in: str = "Conv2d",
+ conv_out: str = "CasualConv3d",
+ attention: str = "AttnBlock",
+ resnet_blocks: Tuple[str] = (
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ ),
+ spatial_upsample: Tuple[str] = (
+ "",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ ),
+ temporal_upsample: Tuple[str] = ("", "", "", "TimeUpsampleRes2x"),
+ mid_resnet: str = "ResnetBlock3D",
+ dropout: float = 0.0,
+ resolution: int = 256,
+ num_res_blocks: int = 2,
+ ):
+ super().__init__()
+ # ---- Config ----
+ self.num_resolutions = len(hidden_size_mult)
+ self.resolution = resolution
+ self.num_res_blocks = num_res_blocks
+
+ # ---- In ----
+ block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.conv_in = resolve_str_to_obj(conv_in)(z_channels, block_in, kernel_size=3, padding=1)
+
+ # ---- Mid ----
+ self.mid = nn.Module()
+ self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
+ self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+
+ # ---- Upsample ----
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = hidden_size * hidden_size_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ resolve_str_to_obj(resnet_blocks[i_level])(
+ in_channels=block_in,
+ out_channels=block_out,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(resolve_str_to_obj(attention)(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if spatial_upsample[i_level]:
+ up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in)
+ curr_res = curr_res * 2
+ if temporal_upsample[i_level]:
+ up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in)
+ self.up.insert(0, up)
+
+ # ---- Out ----
+ self.norm_out = Normalize(block_in)
+ self.conv_out = resolve_str_to_obj(conv_out)(block_in, 3, kernel_size=3, padding=1)
+
+ def forward(self, z):
+ h = self.conv_in(z)
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if hasattr(self.up[i_level], "upsample"):
+ h = self.up[i_level].upsample(h)
+ if hasattr(self.up[i_level], "time_upsample"):
+ h = self.up[i_level].time_upsample(h)
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class CausalVAEModel(VideoBaseAE_PL):
+ @register_to_config
+ def __init__(
+ self,
+ lr: float = 1e-5,
+ hidden_size: int = 128,
+ z_channels: int = 4,
+ hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
+ attn_resolutions: Tuple[int] = [],
+ dropout: float = 0.0,
+ resolution: int = 256,
+ double_z: bool = True,
+ embed_dim: int = 4,
+ num_res_blocks: int = 2,
+ loss_type: str = "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
+ loss_params: dict = {
+ "kl_weight": 0.000001,
+ "logvar_init": 0.0,
+ "disc_start": 2001,
+ "disc_weight": 0.5,
+ },
+ q_conv: str = "CausalConv3d",
+ encoder_conv_in: str = "CausalConv3d",
+ encoder_conv_out: str = "CausalConv3d",
+ encoder_attention: str = "AttnBlock3D",
+ encoder_resnet_blocks: Tuple[str] = (
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ ),
+ encoder_spatial_downsample: Tuple[str] = (
+ "SpatialDownsample2x",
+ "SpatialDownsample2x",
+ "SpatialDownsample2x",
+ "",
+ ),
+ encoder_temporal_downsample: Tuple[str] = (
+ "",
+ "TimeDownsample2x",
+ "TimeDownsample2x",
+ "",
+ ),
+ encoder_mid_resnet: str = "ResnetBlock3D",
+ decoder_conv_in: str = "CausalConv3d",
+ decoder_conv_out: str = "CausalConv3d",
+ decoder_attention: str = "AttnBlock3D",
+ decoder_resnet_blocks: Tuple[str] = (
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ ),
+ decoder_spatial_upsample: Tuple[str] = (
+ "",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ ),
+ decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsample2x", "TimeUpsample2x"),
+ decoder_mid_resnet: str = "ResnetBlock3D",
+ ) -> None:
+ super().__init__()
+ self.tile_sample_min_size = 256
+ self.tile_sample_min_size_t = 65
+ self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
+ t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
+ self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1
+ self.tile_overlap_factor = 0.25
+ self.use_tiling = False
+
+ self.learning_rate = lr
+ self.lr_g_factor = 1.0
+
+ self.encoder = Encoder(
+ z_channels=z_channels,
+ hidden_size=hidden_size,
+ hidden_size_mult=hidden_size_mult,
+ attn_resolutions=attn_resolutions,
+ conv_in=encoder_conv_in,
+ conv_out=encoder_conv_out,
+ attention=encoder_attention,
+ resnet_blocks=encoder_resnet_blocks,
+ spatial_downsample=encoder_spatial_downsample,
+ temporal_downsample=encoder_temporal_downsample,
+ mid_resnet=encoder_mid_resnet,
+ dropout=dropout,
+ resolution=resolution,
+ num_res_blocks=num_res_blocks,
+ double_z=double_z,
+ )
+
+ self.decoder = Decoder(
+ z_channels=z_channels,
+ hidden_size=hidden_size,
+ hidden_size_mult=hidden_size_mult,
+ attn_resolutions=attn_resolutions,
+ conv_in=decoder_conv_in,
+ conv_out=decoder_conv_out,
+ attention=decoder_attention,
+ resnet_blocks=decoder_resnet_blocks,
+ spatial_upsample=decoder_spatial_upsample,
+ temporal_upsample=decoder_temporal_upsample,
+ mid_resnet=decoder_mid_resnet,
+ dropout=dropout,
+ resolution=resolution,
+ num_res_blocks=num_res_blocks,
+ )
+
+ quant_conv_cls = resolve_str_to_obj(q_conv)
+ self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
+ self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
+
+ def encode(self, x):
+ if self.use_tiling and (
+ x.shape[-1] > self.tile_sample_min_size
+ or x.shape[-2] > self.tile_sample_min_size
+ or x.shape[-3] > self.tile_sample_min_size_t
+ ):
+ return self.tiled_encode(x)
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ if self.use_tiling and (
+ z.shape[-1] > self.tile_latent_min_size
+ or z.shape[-2] > self.tile_latent_min_size
+ or z.shape[-3] > self.tile_latent_min_size_t
+ ):
+ return self.tiled_decode(z)
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx):
+ if hasattr(self.loss, "discriminator"):
+ return self._training_step_gan(batch, batch_idx=batch_idx)
+ else:
+ return self._training_step(batch, batch_idx=batch_idx)
+
+ def _training_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, "video")
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ split="train",
+ )
+ self.log(
+ "aeloss",
+ aeloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def _training_step_gan(self, batch, batch_idx):
+ inputs = self.get_input(batch, "video")
+ reconstructions, posterior = self(inputs)
+ opt1, opt2 = self.optimizers()
+
+ # ---- AE Loss ----
+ aeloss, log_dict_ae = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+ self.log(
+ "aeloss",
+ aeloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ opt1.zero_grad()
+ self.manual_backward(aeloss)
+ self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm")
+ opt1.step()
+ # ---- GAN Loss ----
+ discloss, log_dict_disc = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+ self.log(
+ "discloss",
+ discloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ opt2.zero_grad()
+ self.manual_backward(discloss)
+ self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm")
+ opt2.step()
+ self.log_dict(
+ {**log_dict_ae, **log_dict_disc},
+ prog_bar=False,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ def configure_optimizers(self):
+ from itertools import chain
+
+ lr = self.learning_rate
+ modules_to_train = [
+ self.encoder.named_parameters(),
+ self.decoder.named_parameters(),
+ self.post_quant_conv.named_parameters(),
+ self.quant_conv.named_parameters(),
+ ]
+ params_with_time = []
+ params_without_time = []
+ for name, param in chain(*modules_to_train):
+ if "time" in name:
+ params_with_time.append(param)
+ else:
+ params_without_time.append(param)
+ optimizers = []
+ opt_ae = torch.optim.Adam(
+ [
+ {"params": params_with_time, "lr": lr},
+ {"params": params_without_time, "lr": lr},
+ ],
+ lr=lr,
+ betas=(0.5, 0.9),
+ )
+ optimizers.append(opt_ae)
+
+ if hasattr(self.loss, "discriminator"):
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
+ optimizers.append(opt_disc)
+
+ return optimizers, []
+
+ def get_last_layer(self):
+ if hasattr(self.decoder.conv_out, "conv"):
+ return self.decoder.conv_out.conv.weight
+ else:
+ return self.decoder.conv_out.weight
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x):
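+        # chunk the video along time with a one-frame overlap, encode each chunk with spatial tiling,
+        # and drop the duplicated first latent frame of every chunk after the first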
+ t = x.shape[2]
+ t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)]
+ if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
+ t_chunk_start_end = [[0, t]]
+ else:
+ t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
+ if t_chunk_start_end[-1][-1] > t:
+ t_chunk_start_end[-1][-1] = t
+ elif t_chunk_start_end[-1][-1] < t:
+ last_start_end = [t_chunk_idx[-1], t]
+ t_chunk_start_end.append(last_start_end)
+ moments = []
+ for idx, (start, end) in enumerate(t_chunk_start_end):
+ chunk_x = x[:, :, start:end]
+ if idx != 0:
+ moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:]
+ else:
+ moment = self.tiled_encode2d(chunk_x, return_moments=True)
+ moments.append(moment)
+ moments = torch.cat(moments, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def tiled_decode(self, x):
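+        # Mirror of tiled_encode: chunk the latents along time with a one-frame overlap,
+        # decode each chunk with spatial tiling, and drop the duplicated first frame of
+        # every chunk after the first before concatenating.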
+ t = x.shape[2]
+ t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)]
+ if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
+ t_chunk_start_end = [[0, t]]
+ else:
+ t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
+ if t_chunk_start_end[-1][-1] > t:
+ t_chunk_start_end[-1][-1] = t
+ elif t_chunk_start_end[-1][-1] < t:
+ last_start_end = [t_chunk_idx[-1], t]
+ t_chunk_start_end.append(last_start_end)
+ dec_ = []
+ for idx, (start, end) in enumerate(t_chunk_start_end):
+ chunk_x = x[:, :, start:end]
+ if idx != 0:
+ dec = self.tiled_decode2d(chunk_x)[:, :, 1:]
+ else:
+ dec = self.tiled_decode2d(chunk_x)
+ dec_.append(dec)
+ dec_ = torch.cat(dec_, dim=2)
+ return dec_
+
+ def tiled_encode2d(self, x, return_moments=False):
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the frames into overlapping tiles of tile_sample_min_size pixels and encode them separately.
+ rows = []
+ for i in range(0, x.shape[3], overlap_size):
+ row = []
+ for j in range(0, x.shape[4], overlap_size):
+ tile = x[
+ :,
+ :,
+ :,
+ i : i + self.tile_sample_min_size,
+ j : j + self.tile_sample_min_size,
+ ]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ moments = torch.cat(result_rows, dim=3)
+ posterior = DiagonalGaussianDistribution(moments)
+ if return_moments:
+ return moments
+ return posterior
+
+ def tiled_decode2d(self, z):
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping tiles of tile_latent_min_size latents and decode them separately.
+        # The tiles overlap so that seams between them can be blended away.
+ rows = []
+ for i in range(0, z.shape[3], overlap_size):
+ row = []
+ for j in range(0, z.shape[4], overlap_size):
+ tile = z[
+ :,
+ :,
+ :,
+ i : i + self.tile_latent_min_size,
+ j : j + self.tile_latent_min_size,
+ ]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)
+ return dec
+
+ def enable_tiling(self, use_tiling: bool = True):
+ self.use_tiling = use_tiling
+
+ def disable_tiling(self):
+ self.enable_tiling(False)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), remove_loss=False):
+ sd = torch.load(path, map_location="cpu")
+ # print("init from " + path)
+ if "state_dict" in sd:
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, "video")
+ latents = self.encode(inputs).sample()
+ video_recon = self.decode(latents)
+ for idx in range(len(video_recon)):
+ self.logger.log_video(f"recon {batch_idx} {idx}", [tensor_to_video(video_recon[idx])], fps=[10])
+
+
+class CausalVAEModelWrapper(nn.Module):
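+    # Thin inference wrapper: encode() scales latents by the Stable-Diffusion-style factor
+    # 0.18215 and decode() undoes the scaling, returning video as (b, t, c, h, w).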
+ def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):
+ super(CausalVAEModelWrapper, self).__init__()
+ # if os.path.exists(ckpt):
+ # self.vae = CausalVAEModel.load_from_checkpoint(ckpt)
+ self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)
+
+ def encode(self, x): # b c t h w
+ # x = self.vae.encode(x).sample()
+ x = self.vae.encode(x).sample().mul_(0.18215)
+ return x
+
+ def decode(self, x):
+ # x = self.vae.decode(x)
+ x = self.vae.decode(x / 0.18215)
+ x = rearrange(x, "b c t h w -> b t c h w").contiguous()
+ return x
+
+ def dtype(self):
+ return self.vae.dtype
+
+ #
+ # def device(self):
+ # return self.vae.device
+
+ def forward(self, x):
+ return self.decode(x)
+
+
+videobase_ae_stride = {
+ "CausalVAEModel_4x8x8": [4, 8, 8],
+}
+
+videobase_ae_channel = {
+ "CausalVAEModel_4x8x8": 4,
+}
+
+videobase_ae = {
+ "CausalVAEModel_4x8x8": CausalVAEModelWrapper,
+}
+
+
+ae_stride_config = {}
+ae_stride_config.update(videobase_ae_stride)
+
+ae_channel_config = {}
+ae_channel_config.update(videobase_ae_channel)
+
+
+def getae_wrapper(ae):
+    """Deprecated helper kept for backward compatibility; maps a model name to its VAE wrapper class."""
+    ae_cls = videobase_ae.get(ae, None)
+    assert ae_cls is not None, f"unknown autoencoder: {ae}"
+    return ae_cls
+
+
+def video_to_image(func):
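+    # Decorator that folds the time axis into the batch axis so 2D modules can be applied
+    # frame-by-frame to 5D (b, c, t, h, w) inputs; 4D inputs are passed through directly.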
+ def wrapper(self, x, *args, **kwargs):
+ if x.dim() == 5:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = func(self, x, *args, **kwargs)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+
+ return wrapper
+
+
+class Block(nn.Module):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+
+class LinearAttention(Block):
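+    # Linear attention: softmax is taken over key positions and a per-head context matrix is
+    # built first, so the cost grows linearly (not quadratically) with the number of pixels.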
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class LinAttnBlock(LinearAttention):
+    """Single-head LinearAttention wrapper matching the AttnBlock interface."""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock3D(Block):
+    """
+    Kept for compatibility with older checkpoints; use with caution. Known issue: q/k/v are
+    reshaped from (b, c, t, h, w) to (b*t, c, h*w) without first permuting the time axis next
+    to the batch axis, which mixes frames and channels; AttnBlock3DFix implements the
+    corrected reshape.
+    """
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, t, h, w = q.shape
+ q = q.reshape(b * t, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b * t, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b * t, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, t, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class AttnBlock3DFix(nn.Module):
+ """
+ Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
+ """
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
+ b, c, t, h, w = q.shape
+ q = q.permute(0, 2, 1, 3, 4)
+ q = q.reshape(b * t, c, h * w)
+ q = q.permute(0, 2, 1)
+
+ # k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
+ k = k.permute(0, 2, 1, 3, 4)
+ k = k.reshape(b * t, c, h * w)
+
+ # w: (b*t hw hw)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ # v: (b c t h w) -> (b t c h w) -> (bt c hw)
+ # w_: (bt hw hw) -> (bt hw hw)
+ v = v.permute(0, 2, 1, 3, 4)
+ v = v.reshape(b * t, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+
+ # h_: (b*t c hw) -> (b t c h w) -> (b c t h w)
+ h_ = h_.reshape(b, t, c, h, w)
+ h_ = h_.permute(0, 2, 1, 3, 4)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class AttnBlock(Block):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class TemporalAttnBlock(Block):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, t, h, w = q.shape
+ q = rearrange(q, "b c t h w -> (b h w) t c")
+ k = rearrange(k, "b c t h w -> (b h w) c t")
+ v = rearrange(v, "b c t h w -> (b h w) c t")
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla3D":
+ return AttnBlock3D(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Conv2d(nn.Conv2d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int]] = 3,
+ stride: Union[int, Tuple[int]] = 1,
+ padding: Union[str, int, Tuple[int]] = 0,
+ dilation: Union[int, Tuple[int]] = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros",
+ device=None,
+ dtype=None,
+ ) -> None:
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ device,
+ dtype,
+ )
+
+ @video_to_image
+ def forward(self, x):
+ return super().forward(x)
+
+
+class CausalConv3d(nn.Module):
+ def __init__(
+ self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.time_kernel_size = self.kernel_size[0]
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ stride = kwargs.pop("stride", 1)
+ padding = kwargs.pop("padding", 0)
+ padding = list(cast_tuple(padding, 3))
+ padding[0] = 0
+ stride = cast_tuple(stride, 3)
+ self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
+ self._init_weights(init_method)
+
+ def _init_weights(self, init_method):
+ if init_method == "avg":
+ assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
+ assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
+ weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
+
+ eyes = torch.concat(
+ [
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ ],
+ dim=-1,
+ )
+ weight[:, :, :, 0, 0] = eyes
+
+ self.conv.weight = nn.Parameter(
+ weight,
+ requires_grad=True,
+ )
+ elif init_method == "zero":
+ self.conv.weight = nn.Parameter(
+ torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
+ requires_grad=True,
+ )
+ if self.conv.bias is not None:
+ nn.init.constant_(self.conv.bias, 0)
+
+ def forward(self, x):
+        # Causal temporal padding: prepend (time_kernel_size - 1) copies of the first frame,
+        # e.g. a 1 + 16 frame input (1 image frame + 16 video frames) becomes 3 + 16 frames
+        # for a temporal kernel of size 3.
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w
+ x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
+ return self.conv(x)
+
+
+class GroupNorm(Block):
+ def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True)
+
+ def forward(self, x):
+ return self.norm(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class ActNorm(nn.Module):
+ def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
+ std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height * width * torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
+
+
+def nonlinearity(x):
+ return x * torch.sigmoid(x)
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
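+    # Move dimension src_dim of x to position dest_dim, keeping the relative order of the
+    # remaining dimensions (similar to torch.movedim).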
+ n_dims = len(x.shape)
+ if src_dim < 0:
+ src_dim = n_dims + src_dim
+ if dest_dim < 0:
+ dest_dim = n_dims + dest_dim
+ assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
+ dims = list(range(n_dims))
+ del dims[src_dim]
+ permutation = []
+ ctr = 0
+ for i in range(n_dims):
+ if i == dest_dim:
+ permutation.append(src_dim)
+ else:
+ permutation.append(dims[ctr])
+ ctr += 1
+ x = x.permute(permutation)
+ if make_contiguous:
+ x = x.contiguous()
+ return x
+
+
+class Codebook(nn.Module):
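+    """
+    Vector-quantization codebook with EMA updates. During training, codes that fall out of
+    use are re-initialized from randomly drawn encoder outputs (VQ-VAE style).
+    """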
+ def __init__(self, n_codes, embedding_dim):
+ super().__init__()
+ self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
+ self.register_buffer("N", torch.zeros(n_codes))
+ self.register_buffer("z_avg", self.embeddings.data.clone())
+
+ self.n_codes = n_codes
+ self.embedding_dim = embedding_dim
+ self._need_init = True
+
+ def _tile(self, x):
+ d, ew = x.shape
+ if d < self.n_codes:
+ n_repeats = (self.n_codes + d - 1) // d
+ std = 0.01 / np.sqrt(ew)
+ x = x.repeat(n_repeats, 1)
+ x = x + torch.randn_like(x) * std
+ return x
+
+ def _init_embeddings(self, z):
+ # z: [b, c, t, h, w]
+ self._need_init = False
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
+ y = self._tile(flat_inputs)
+
+        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
+ if dist.is_initialized():
+ dist.broadcast(_k_rand, 0)
+ self.embeddings.data.copy_(_k_rand)
+ self.z_avg.data.copy_(_k_rand)
+ self.N.data.copy_(torch.ones(self.n_codes))
+
+ def forward(self, z):
+ # z: [b, c, t, h, w]
+ if self._need_init and self.training:
+ self._init_embeddings(z)
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
+ distances = (
+ (flat_inputs**2).sum(dim=1, keepdim=True)
+ - 2 * flat_inputs @ self.embeddings.t()
+ + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
+ )
+
+ encoding_indices = torch.argmin(distances, dim=1)
+ encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
+ encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
+
+ embeddings = F.embedding(encoding_indices, self.embeddings)
+ embeddings = shift_dim(embeddings, -1, 1)
+
+ commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
+
+ # EMA codebook update
+ if self.training:
+ n_total = encode_onehot.sum(dim=0)
+ encode_sum = flat_inputs.t() @ encode_onehot
+ if dist.is_initialized():
+ dist.all_reduce(n_total)
+ dist.all_reduce(encode_sum)
+
+ self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
+ self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
+
+ n = self.N.sum()
+ weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
+ encode_normalized = self.z_avg / weights.unsqueeze(1)
+ self.embeddings.data.copy_(encode_normalized)
+
+ y = self._tile(flat_inputs)
+ _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
+ if dist.is_initialized():
+ dist.broadcast(_k_rand, 0)
+
+ usage = (self.N.view(self.n_codes, 1) >= 1).float()
+ self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
+
+ embeddings_st = (embeddings - z).detach() + z
+
+ avg_probs = torch.mean(encode_onehot, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ return dict(
+ embeddings=embeddings_st,
+ encodings=encoding_indices,
+ commitment_loss=commitment_loss,
+ perplexity=perplexity,
+ )
+
+ def dictionary_lookup(self, encodings):
+ embeddings = F.embedding(encodings, self.embeddings)
+ return embeddings
+
+
+class ResnetBlock2D(Block):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ x = x + h
+ return x
+
+
+class ResnetBlock3D(Block):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ else:
+ self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
+
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ return x + h
+
+
+class Upsample(Block):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.with_conv = True
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ @video_to_image
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(Block):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.with_conv = True
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class SpatialDownsample2x(Block):
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int]] = (3, 3),
+ stride: Union[int, Tuple[int]] = (2, 2),
+ ):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 2)
+ stride = cast_tuple(stride, 2)
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ self.kernel_size = kernel_size
+ self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1, 0, 0)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class SpatialUpsample2x(Block):
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int]] = (3, 3),
+ stride: Union[int, Tuple[int]] = (1, 1),
+ ):
+ super().__init__()
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ self.kernel_size = kernel_size
+ self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1)
+
+ def forward(self, x):
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> b (c t) h w")
+ x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
+ x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
+ x = self.conv(x)
+ return x
+
+
+class TimeDownsample2x(Block):
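+    # 2x temporal downsampling via average pooling over a causal window: the first frame is
+    # replicated as left padding so the pooling window never looks at future frames.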
+ def __init__(self, chan_in, chan_out, kernel_size: int = 3):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
+
+ def forward(self, x):
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size - 1, 1, 1))
+ x = torch.concatenate((first_frame_pad, x), dim=2)
+ return self.conv(x)
+
+
+class TimeUpsample2x(Block):
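+    # 2x temporal upsampling that keeps the first (image) frame as-is and trilinearly
+    # interpolates only the remaining frames; single-frame inputs are returned unchanged.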
+ def __init__(self, chan_in, chan_out):
+ super().__init__()
+
+ def forward(self, x):
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ return x
+
+
+class TimeDownsampleRes2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 2.0,
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
+ self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ alpha = torch.sigmoid(self.mix_factor)
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
+ x = torch.concatenate((first_frame_pad, x), dim=2)
+ return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)
+
+
+class TimeUpsampleRes2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 2.0,
+ ):
+ super().__init__()
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ alpha = torch.sigmoid(self.mix_factor)
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ return alpha * x + (1 - alpha) * self.conv(x)
+
+
+class TimeDownsampleResAdv2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 1.5,
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
+ self.attn = TemporalAttnBlock(in_channels)
+ self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
+ self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
+ x = torch.concatenate((first_frame_pad, x), dim=2)
+ alpha = torch.sigmoid(self.mix_factor)
+ return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x))))
+
+
+class TimeUpsampleResAdv2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 1.5,
+ ):
+ super().__init__()
+ self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
+ self.attn = TemporalAttnBlock(in_channels)
+ self.norm = Normalize(in_channels=in_channels)
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ alpha = torch.sigmoid(self.mix_factor)
+ return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x)))
diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v120.py b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v120.py
new file mode 100644
index 0000000..0bb4aa8
--- /dev/null
+++ b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v120.py
@@ -0,0 +1,1139 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import glob
+import os
+from copy import deepcopy
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+from torchvision.transforms import Lambda
+
+npu_config = None
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) or isinstance(t, list) else ((t,) * length)
+
+
+class Block(nn.Module):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+
+class CausalConv3d(nn.Module):
+ def __init__(
+ self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.time_kernel_size = self.kernel_size[0]
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ stride = kwargs.pop("stride", 1)
+ padding = kwargs.pop("padding", 0)
+ padding = list(cast_tuple(padding, 3))
+ padding[0] = 0
+ stride = cast_tuple(stride, 3)
+ self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
+ self.pad = nn.ReplicationPad2d((0, 0, self.time_kernel_size - 1, 0))
+ self._init_weights(init_method)
+
+ def _init_weights(self, init_method):
+ if init_method == "avg":
+ assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
+ assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
+ weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
+
+ eyes = torch.concat(
+ [
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ ],
+ dim=-1,
+ )
+ weight[:, :, :, 0, 0] = eyes
+
+ self.conv.weight = nn.Parameter(
+ weight,
+ requires_grad=True,
+ )
+ elif init_method == "zero":
+ self.conv.weight = nn.Parameter(
+ torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
+ requires_grad=True,
+ )
+ if self.conv.bias is not None:
+ nn.init.constant_(self.conv.bias, 0)
+
+ def forward(self, x):
+ if npu_config is not None and npu_config.on_npu:
+ x_dtype = x.dtype
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w
+ x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
+ return npu_config.run_conv3d(self.conv, x, x_dtype)
+ else:
+            # Causal temporal padding: prepend (time_kernel_size - 1) copies of the first
+            # frame, e.g. 1 + 16 input frames (1 image frame + 16 video frames) become 3 + 16.
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w
+ x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
+ return self.conv(x)
+
+
+def nonlinearity(x):
+ return x * torch.sigmoid(x)
+
+
+def video_to_image(func):
+ def wrapper(self, x, *args, **kwargs):
+ if x.dim() == 5:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = func(self, x, *args, **kwargs)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+
+ return wrapper
+
+
+class Conv2d(nn.Conv2d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int]] = 3,
+ stride: Union[int, Tuple[int]] = 1,
+ padding: Union[str, int, Tuple[int]] = 0,
+ dilation: Union[int, Tuple[int]] = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros",
+ device=None,
+ dtype=None,
+ ) -> None:
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ device,
+ dtype,
+ )
+
+ @video_to_image
+ def forward(self, x):
+ return super().forward(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class VideoBaseAE(ModelMixin, ConfigMixin):
+ config_name = "config.json"
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ def encode(self, x: torch.Tensor, *args, **kwargs):
+ pass
+
+ def decode(self, encoding: torch.Tensor, *args, **kwargs):
+ pass
+
+ @property
+ def num_training_steps(self) -> int:
+ """Total training steps inferred from datamodule and devices."""
+ if self.trainer.max_steps:
+ return self.trainer.max_steps
+
+ limit_batches = self.trainer.limit_train_batches
+ batches = len(self.train_dataloader())
+ batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)
+
+ num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
+ if self.trainer.tpu_cores:
+ num_devices = max(num_devices, self.trainer.tpu_cores)
+
+ effective_accum = self.trainer.accumulate_grad_batches * num_devices
+ return (batches // effective_accum) * self.trainer.max_epochs
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
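+        # Prefer a local *.ckpt in the given directory; otherwise download checkpoint.ckpt
+        # and config.json from the "vae" subfolder of the Hugging Face Hub repo.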
+ ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt"))
+ if not ckpt_files:
+ ckpt_file = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="checkpoint.ckpt")
+ config_file = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="config.json")
+ else:
+ ckpt_file = ckpt_files[-1]
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+
+ # Adapt to checkpoint
+ model = cls.from_config(config_file)
+ model.init_from_ckpt(ckpt_file)
+ return model
+
+
+class DiagonalGaussianDistribution(object):
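+    """Diagonal Gaussian posterior whose mean and log-variance are packed along the channel
+    dimension of `parameters`."""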
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+class ResnetBlock2D(Block):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ x = x + h
+ return x
+
+
+class ResnetBlock3D(Block):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ else:
+ self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
+
+ def forward(self, x):
+ h = x
+ if npu_config is None:
+ h = self.norm1(h)
+ else:
+ h = npu_config.run_group_norm(self.norm1, h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ if npu_config is None:
+ h = self.norm2(h)
+ else:
+ h = npu_config.run_group_norm(self.norm2, h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ return x + h
+
+
+class SpatialUpsample2x(Block):
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int]] = (3, 3),
+ stride: Union[int, Tuple[int]] = (1, 1),
+ unup=False,
+ ):
+ super().__init__()
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ self.kernel_size = kernel_size
+ self.unup = unup
+ self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1)
+
+ def forward(self, x):
+ if not self.unup:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> b (c t) h w")
+ x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
+ x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
+ x = self.conv(x)
+ return x
+
+
+class Spatial2xTime2x3DUpsample(Block):
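+    # Upsamples 2x spatially and 2x temporally, except for the first (image) frame, which is
+    # only upsampled spatially so the causal 1 + n frame layout is preserved.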
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x):
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 2, 2), mode="trilinear")
+ x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ else:
+ x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear")
+ return self.conv(x)
+
+
+class AttnBlock3DFix(nn.Module):
+ """
+ Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
+ """
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+
+ def forward(self, x):
+ h_ = x
+ if npu_config is None:
+ h_ = self.norm(h_)
+ else:
+ h_ = npu_config.run_group_norm(self.norm, h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
+ b, c, t, h, w = q.shape
+ q = q.permute(0, 2, 1, 3, 4)
+ q = q.reshape(b * t, c, h * w)
+ q = q.permute(0, 2, 1)
+
+ # k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
+ k = k.permute(0, 2, 1, 3, 4)
+ k = k.reshape(b * t, c, h * w)
+
+ # w: (b*t hw hw)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ # v: (b c t h w) -> (b t c h w) -> (bt c hw)
+ # w_: (bt hw hw) -> (bt hw hw)
+ v = v.permute(0, 2, 1, 3, 4)
+ v = v.reshape(b * t, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+
+ # h_: (b*t c hw) -> (b t c h w) -> (b c t h w)
+ h_ = h_.reshape(b, t, c, h, w)
+ h_ = h_.permute(0, 2, 1, 3, 4)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class Spatial2xTime2x3DDownsample(Block):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=2)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1, 0, 0)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Downsample(Block):
+ def __init__(self, in_channels, out_channels, undown=False):
+ super().__init__()
+ self.with_conv = True
+ self.undown = undown
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ if self.undown:
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ if self.with_conv:
+ if self.undown:
+ if npu_config is not None and npu_config.on_npu:
+ x_dtype = x.dtype
+ x = x.to(npu_config.replaced_type)
+ x = npu_config.run_conv3d(self.conv, x, x_dtype)
+ else:
+ x = self.conv(x)
+ else:
+ pad = (0, 1, 0, 1)
+ if npu_config is not None and npu_config.on_npu:
+ x_dtype = x.dtype
+ x = x.to(npu_config.replaced_type)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = npu_config.run_conv3d(self.conv, x, x_dtype)
+ else:
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock3D_GC(Block):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ else:
+ self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
+
+ def forward(self, x):
+ return checkpoint(self._forward, x, use_reentrant=True)
+
+ def _forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ return x + h
+
+
+def resolve_str_to_obj(str_val, append=True):
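+    # Resolve a layer class from its name via this module's globals; `append` is unused here.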
+ return globals()[str_val]
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ z_channels: int,
+ hidden_size: int,
+ hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
+ attn_resolutions: Tuple[int] = (16,),
+ conv_in="Conv2d",
+        conv_out="CausalConv3d",
+ attention="AttnBlock",
+ resnet_blocks=(
+ "ResnetBlock2D",
+ "ResnetBlock2D",
+ "ResnetBlock2D",
+ "ResnetBlock3D",
+ ),
+ spatial_downsample=(
+ "Downsample",
+ "Downsample",
+ "Downsample",
+ "",
+ ),
+ temporal_downsample=("", "", "TimeDownsampleRes2x", ""),
+ mid_resnet="ResnetBlock3D",
+ dropout: float = 0.0,
+ resolution: int = 256,
+ num_res_blocks: int = 2,
+ double_z: bool = True,
+ ) -> None:
+ super().__init__()
+        assert len(resnet_blocks) == len(hidden_size_mult), f"{hidden_size_mult} vs {resnet_blocks}"
+ # ---- Config ----
+ self.num_resolutions = len(hidden_size_mult)
+ self.resolution = resolution
+ self.num_res_blocks = num_res_blocks
+
+ # ---- In ----
+ self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1)
+
+ # ---- Downsample ----
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(hidden_size_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = hidden_size * in_ch_mult[i_level]
+ block_out = hidden_size * hidden_size_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ resolve_str_to_obj(resnet_blocks[i_level])(
+ in_channels=block_in,
+ out_channels=block_out,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(resolve_str_to_obj(attention)(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if spatial_downsample[i_level]:
+ down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in)
+ curr_res = curr_res // 2
+ if temporal_downsample[i_level]:
+ down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in)
+ self.down.append(down)
+
+ # ---- Mid ----
+ self.mid = nn.Module()
+ self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
+ self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+ # ---- Out ----
+ self.norm_out = Normalize(block_in)
+ self.conv_out = resolve_str_to_obj(conv_out)(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1])
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if hasattr(self.down[i_level], "downsample"):
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ if hasattr(self.down[i_level], "time_downsample"):
+ hs_down = self.down[i_level].time_downsample(hs[-1])
+ hs.append(hs_down)
+
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ if npu_config is None:
+ h = self.norm_out(h)
+ else:
+ h = npu_config.run_group_norm(self.norm_out, h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ z_channels: int,
+ hidden_size: int,
+ hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
+ attn_resolutions: Tuple[int] = (16,),
+ conv_in="Conv2d",
+        conv_out="CausalConv3d",
+ attention="AttnBlock",
+ resnet_blocks=(
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ ),
+ spatial_upsample=(
+ "",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ ),
+ temporal_upsample=("", "", "", "TimeUpsampleRes2x"),
+ mid_resnet="ResnetBlock3D",
+ dropout: float = 0.0,
+ resolution: int = 256,
+ num_res_blocks: int = 2,
+ ):
+ super().__init__()
+ # ---- Config ----
+ self.num_resolutions = len(hidden_size_mult)
+ self.resolution = resolution
+ self.num_res_blocks = num_res_blocks
+
+ # ---- In ----
+ block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.conv_in = resolve_str_to_obj(conv_in)(z_channels, block_in, kernel_size=3, padding=1)
+
+ # ---- Mid ----
+ self.mid = nn.Module()
+ self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
+ self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+
+ # ---- Upsample ----
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = hidden_size * hidden_size_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ resolve_str_to_obj(resnet_blocks[i_level])(
+ in_channels=block_in,
+ out_channels=block_out,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(resolve_str_to_obj(attention)(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if spatial_upsample[i_level]:
+ up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in)
+ curr_res = curr_res * 2
+ if temporal_upsample[i_level]:
+ up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in)
+ self.up.insert(0, up)
+
+ # ---- Out ----
+ self.norm_out = Normalize(block_in)
+ self.conv_out = resolve_str_to_obj(conv_out)(block_in, 3, kernel_size=3, padding=1)
+
+ def forward(self, z):
+ h = self.conv_in(z)
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if hasattr(self.up[i_level], "upsample"):
+ h = self.up[i_level].upsample(h)
+ if hasattr(self.up[i_level], "time_upsample"):
+ h = self.up[i_level].time_upsample(h)
+ if npu_config is None:
+ h = self.norm_out(h)
+ else:
+ h = npu_config.run_group_norm(self.norm_out, h)
+ h = nonlinearity(h)
+ if npu_config is None:
+ h = self.conv_out(h)
+ else:
+ h_dtype = h.dtype
+ h = npu_config.run_conv3d(self.conv_out, h, h_dtype)
+ return h
+
+
+class CausalVAEModel(VideoBaseAE):
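+    """
+    Causal 3D VAE adapted from Open-Sora-Plan: a CausalConv3d-based encoder/decoder pair with
+    a diagonal Gaussian latent, optional quant/post-quant 1x1 convolutions, and tiled
+    encode/decode for long or high-resolution videos.
+    """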
+ @register_to_config
+ def __init__(
+ self,
+ hidden_size: int = 128,
+ z_channels: int = 4,
+ hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
+ attn_resolutions: Tuple[int] = [],
+ dropout: float = 0.0,
+ resolution: int = 256,
+ double_z: bool = True,
+ embed_dim: int = 4,
+ num_res_blocks: int = 2,
+ q_conv: str = "CausalConv3d",
+ encoder_conv_in="CausalConv3d",
+ encoder_conv_out="CausalConv3d",
+ encoder_attention="AttnBlock3D",
+ encoder_resnet_blocks=(
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ ),
+ encoder_spatial_downsample=(
+ "SpatialDownsample2x",
+ "SpatialDownsample2x",
+ "SpatialDownsample2x",
+ "",
+ ),
+ encoder_temporal_downsample=(
+ "",
+ "TimeDownsample2x",
+ "TimeDownsample2x",
+ "",
+ ),
+ encoder_mid_resnet="ResnetBlock3D",
+ decoder_conv_in="CausalConv3d",
+ decoder_conv_out="CausalConv3d",
+ decoder_attention="AttnBlock3D",
+ decoder_resnet_blocks=(
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ ),
+ decoder_spatial_upsample=(
+ "",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ ),
+ decoder_temporal_upsample=("", "", "TimeUpsample2x", "TimeUpsample2x"),
+ decoder_mid_resnet="ResnetBlock3D",
+ use_quant_layer: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.tile_sample_min_size = 256
+ self.tile_sample_min_size_t = 33
+ self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
+
+ # t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
+ # self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** len(t_down_ratio))) + 1
+ self.tile_latent_min_size_t = 16
+ self.tile_overlap_factor = 0.125
+ self.use_tiling = False
+
+ self.use_quant_layer = use_quant_layer
+
+ self.encoder = Encoder(
+ z_channels=z_channels,
+ hidden_size=hidden_size,
+ hidden_size_mult=hidden_size_mult,
+ attn_resolutions=attn_resolutions,
+ conv_in=encoder_conv_in,
+ conv_out=encoder_conv_out,
+ attention=encoder_attention,
+ resnet_blocks=encoder_resnet_blocks,
+ spatial_downsample=encoder_spatial_downsample,
+ temporal_downsample=encoder_temporal_downsample,
+ mid_resnet=encoder_mid_resnet,
+ dropout=dropout,
+ resolution=resolution,
+ num_res_blocks=num_res_blocks,
+ double_z=double_z,
+ )
+
+ self.decoder = Decoder(
+ z_channels=z_channels,
+ hidden_size=hidden_size,
+ hidden_size_mult=hidden_size_mult,
+ attn_resolutions=attn_resolutions,
+ conv_in=decoder_conv_in,
+ conv_out=decoder_conv_out,
+ attention=decoder_attention,
+ resnet_blocks=decoder_resnet_blocks,
+ spatial_upsample=decoder_spatial_upsample,
+ temporal_upsample=decoder_temporal_upsample,
+ mid_resnet=decoder_mid_resnet,
+ dropout=dropout,
+ resolution=resolution,
+ num_res_blocks=num_res_blocks,
+ )
+ if self.use_quant_layer:
+ quant_conv_cls = resolve_str_to_obj(q_conv)
+ self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
+ self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
+
+ def get_encoder(self):
+ if self.use_quant_layer:
+ return [self.quant_conv, self.encoder]
+ return [self.encoder]
+
+ def get_decoder(self):
+ if self.use_quant_layer:
+ return [self.post_quant_conv, self.decoder]
+ return [self.decoder]
+
+ def encode(self, x):
+ if self.use_tiling and (
+ x.shape[-1] > self.tile_sample_min_size
+ or x.shape[-2] > self.tile_sample_min_size
+ or x.shape[-3] > self.tile_sample_min_size_t
+ ):
+ return self.tiled_encode(x)
+ h = self.encoder(x)
+ if self.use_quant_layer:
+ h = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(h)
+ return posterior
+
+ def decode(self, z):
+ if self.use_tiling and (
+ z.shape[-1] > self.tile_latent_min_size
+ or z.shape[-2] > self.tile_latent_min_size
+ or z.shape[-3] > self.tile_latent_min_size_t
+ ):
+ return self.tiled_decode(z)
+ if self.use_quant_layer:
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def on_train_start(self):
+        self.ema = deepcopy(self) if self.save_ema else None
+
+ def get_last_layer(self):
+ if hasattr(self.decoder.conv_out, "conv"):
+ return self.decoder.conv_out.conv.weight
+ else:
+ return self.decoder.conv_out.weight
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x):
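+ # split the video along time into chunks of tile_sample_min_size_t frames
+ # that overlap by one frame; each chunk is encoded with 2D spatial tiling and
+ # the overlapping leading latent frame is dropped from every chunk after the
+ # first before the moments are concatenated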
+ t = x.shape[2]
+ t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)]
+ if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
+ t_chunk_start_end = [[0, t]]
+ else:
+ t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
+ if t_chunk_start_end[-1][-1] > t:
+ t_chunk_start_end[-1][-1] = t
+ elif t_chunk_start_end[-1][-1] < t:
+ last_start_end = [t_chunk_idx[-1], t]
+ t_chunk_start_end.append(last_start_end)
+ moments = []
+ for idx, (start, end) in enumerate(t_chunk_start_end):
+ chunk_x = x[:, :, start:end]
+ if idx != 0:
+ moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:]
+ else:
+ moment = self.tiled_encode2d(chunk_x, return_moments=True)
+ moments.append(moment)
+ moments = torch.cat(moments, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def tiled_decode(self, x):
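+ # same chunking as tiled_encode, but over latent frames with step
+ # tile_latent_min_size_t - 1; the duplicated leading frame of every chunk
+ # after the first is dropped from the decoded output before concatenation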
+ t = x.shape[2]
+ t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)]
+ if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
+ t_chunk_start_end = [[0, t]]
+ else:
+ t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
+ if t_chunk_start_end[-1][-1] > t:
+ t_chunk_start_end[-1][-1] = t
+ elif t_chunk_start_end[-1][-1] < t:
+ last_start_end = [t_chunk_idx[-1], t]
+ t_chunk_start_end.append(last_start_end)
+ dec_ = []
+ for idx, (start, end) in enumerate(t_chunk_start_end):
+ chunk_x = x[:, :, start:end]
+ if idx != 0:
+ dec = self.tiled_decode2d(chunk_x)[:, :, 1:]
+ else:
+ dec = self.tiled_decode2d(chunk_x)
+ dec_.append(dec)
+ dec_ = torch.cat(dec_, dim=2)
+ return dec_
+
+ def tiled_encode2d(self, x, return_moments=False):
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split the sample into overlapping tiles of size tile_sample_min_size and encode them separately.
+ rows = []
+ for i in range(0, x.shape[3], overlap_size):
+ row = []
+ for j in range(0, x.shape[4], overlap_size):
+ tile = x[
+ :,
+ :,
+ :,
+ i : i + self.tile_sample_min_size,
+ j : j + self.tile_sample_min_size,
+ ]
+ tile = self.encoder(tile)
+ if self.use_quant_layer:
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ moments = torch.cat(result_rows, dim=3)
+ posterior = DiagonalGaussianDistribution(moments)
+ if return_moments:
+ return moments
+ return posterior
+
+ def tiled_decode2d(self, z):
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ # Split z into overlapping tiles of size tile_latent_min_size and decode them separately.
+ # The tiles overlap so that blending can hide seams between neighbouring tiles.
+ rows = []
+ for i in range(0, z.shape[3], overlap_size):
+ row = []
+ for j in range(0, z.shape[4], overlap_size):
+ tile = z[
+ :,
+ :,
+ :,
+ i : i + self.tile_latent_min_size,
+ j : j + self.tile_latent_min_size,
+ ]
+ if self.use_quant_layer:
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)
+ return dec
+
+ def enable_tiling(self, use_tiling: bool = True):
+ self.use_tiling = use_tiling
+
+ def disable_tiling(self):
+ self.enable_tiling(False)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")
+ # print("init from " + path)
+
+ if "ema_state_dict" in sd and len(sd["ema_state_dict"]) > 0 and os.environ.get("NOT_USE_EMA_MODEL", 0) == 0:
+ # print("Load from ema model!")
+ sd = sd["ema_state_dict"]
+ sd = {key.replace("module.", ""): value for key, value in sd.items()}
+ elif "state_dict" in sd:
+ # print("Load from normal model!")
+ if "gen_model" in sd["state_dict"]:
+ sd = sd["state_dict"]["gen_model"]
+ else:
+ sd = sd["state_dict"]
+
+ keys = list(sd.keys())
+
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+
+ miss, unexpected = self.load_state_dict(sd, strict=False)
+ assert len(miss) == 0, f"missing keys: {miss}"
+ if len(unexpected) > 0:
+ for i in unexpected:
+ assert "loss" in i, f"unexpected key: {i}"
+
+
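+# Lookup tables for the released CausalVAE variants. Each name encodes the latent
+# channel count (D4 / D8) and the temporal x spatial compression (e.g. 2x8x8):
+# ae_stride_config gives the (t, h, w) stride, ae_channel_config the latent
+# channels, and ae_norm / ae_denorm map pixel values between [0, 1] and [-1, 1].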
+ae_stride_config = {
+ "CausalVAEModel_D4_2x8x8": [2, 8, 8],
+ "CausalVAEModel_D8_2x8x8": [2, 8, 8],
+ "CausalVAEModel_D4_4x8x8": [4, 8, 8],
+ "CausalVAEModel_D8_4x8x8": [4, 8, 8],
+}
+
+
+ae_channel_config = {
+ "CausalVAEModel_D4_2x8x8": 4,
+ "CausalVAEModel_D8_2x8x8": 8,
+ "CausalVAEModel_D4_4x8x8": 4,
+ "CausalVAEModel_D8_4x8x8": 8,
+}
+
+
+ae_denorm = {
+ "CausalVAEModel_D4_2x8x8": lambda x: (x + 1.0) / 2.0,
+ "CausalVAEModel_D8_2x8x8": lambda x: (x + 1.0) / 2.0,
+ "CausalVAEModel_D4_4x8x8": lambda x: (x + 1.0) / 2.0,
+ "CausalVAEModel_D8_4x8x8": lambda x: (x + 1.0) / 2.0,
+}
+
+ae_norm = {
+ "CausalVAEModel_D4_2x8x8": Lambda(lambda x: 2.0 * x - 1.0),
+ "CausalVAEModel_D8_2x8x8": Lambda(lambda x: 2.0 * x - 1.0),
+ "CausalVAEModel_D4_4x8x8": Lambda(lambda x: 2.0 * x - 1.0),
+ "CausalVAEModel_D8_4x8x8": Lambda(lambda x: 2.0 * x - 1.0),
+}
+
+
+class CausalVAEModelWrapper(nn.Module):
+ def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs):
+ super(CausalVAEModelWrapper, self).__init__()
+ # if os.path.exists(ckpt):
+ # self.vae = CausalVAEModel.load_from_checkpoint(ckpt)
+ # hf_hub_download(model_path, subfolder="vae", filename="checkpoint.ckpt")
+ # cached_download(hf_hub_url(model_path, subfolder="vae", filename="checkpoint.ckpt"))
+ self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)
+ if use_ema:
+ self.vae.init_from_ema(model_path)
+ self.vae = self.vae.ema
+
+ def encode(self, x): # b c t h w
+ # x = self.vae.encode(x).sample()
+ x = self.vae.encode(x).sample().mul_(0.18215)
+ return x
+
+ def decode(self, x):
+ # x = self.vae.decode(x)
+ x = self.vae.decode(x / 0.18215)
+ x = rearrange(x, "b c t h w -> b t c h w").contiguous()
+ return x
+
+ def forward(self, x):
+ return self.decode(x)
+
+ def dtype(self):
+ return self.vae.dtype
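+
+
+# Usage sketch (illustrative only; the repo id and subfolder below are assumptions,
+# not values pinned by this file):
+# vae = CausalVAEModelWrapper("LanguageBind/Open-Sora-Plan-v1.1.0", subfolder="vae").to("cuda").eval()
+# latents = vae.encode(video)   # video: (b, c, t, h, w), values normalized to [-1, 1]
+# frames = vae.decode(latents)  # returns (b, t, c, h, w)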
diff --git a/videosys/models/modules/attentions.py b/videosys/models/modules/attentions.py
index 8e2c20c..4bbba72 100644
--- a/videosys/models/modules/attentions.py
+++ b/videosys/models/modules/attentions.py
@@ -1,12 +1,21 @@
+import inspect
from dataclasses import dataclass
-from typing import Iterable, List, Tuple
+from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
+from diffusers.models.attention import Attention
+from diffusers.models.attention_processor import AttnProcessor
+from einops import rearrange
+from torch import nn
+from torch.amp import autocast
-from videosys.models.modules.normalization import LlamaRMSNorm
+from videosys.core.comm import all_to_all_with_pad, get_pad, set_pad
+from videosys.core.pab_mgr import enable_pab, if_broadcast_cross, if_broadcast_spatial, if_broadcast_temporal
+from videosys.models.modules.normalization import LlamaRMSNorm, VchitectSpatialNorm
+from videosys.utils.logging import logger
class OpenSoraAttention(nn.Module):
@@ -203,3 +212,634 @@ class _SeqLenInfo:
seqstart=seqstart,
seqstart_py=seqstart_py,
)
+
+
+class VchitectAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. Unlike the diffusers `Attention` class, no default processor is
+ constructed when `None` is passed, so a processor (e.g. `VchitectAttnProcessor`) should be supplied.
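+
+ Example:
+ A minimal construction sketch; the dimensions below are illustrative assumptions rather
+ than values taken from a released Vchitect config:
+
+ attn = VchitectAttention(
+ query_dim=1152, heads=16, dim_head=72,
+ added_kv_proj_dim=1152, context_pre_only=False,
+ processor=VchitectAttnProcessor(),
+ )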
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional[AttnProcessor] = None,
+ out_dim: int = None,
+ context_pre_only: bool = None,
+ ):
+ super().__init__()
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.fused_projections = False
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+
+ # we make use of this private variable to know whether this class is loaded
+ # with a deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = VchitectSpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+ else:
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ self.to_q_cross = nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_q_temp = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k_temp = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+ self.to_v_temp = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+ else:
+ self.to_k_temp = None
+ self.to_v_temp = None
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ self.to_out_temporal = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+ nn.init.constant_(self.to_out_temporal.weight, 0)
+ nn.init.constant_(self.to_out_temporal.bias, 0)
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+ self.to_add_out_temporal = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+ nn.init.constant_(self.to_add_out_temporal.weight, 0)
+ nn.init.constant_(self.to_add_out_temporal.bias, 0)
+
+ self.to_out_context = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+ nn.init.constant_(self.to_out_context.weight, 0)
+ nn.init.constant_(self.to_out_context.bias, 0)
+
+ # set attention processor
+ self.set_processor(processor)
+
+ # parallel
+ self.parallel_manager = None
+
+ # pab
+ self.spatial_count = 0
+ self.last_spatial = None
+ self.temporal_count = 0
+ self.last_temporal = None
+ self.cross_count = 0
+ self.last_cross = None
+
+ def set_processor(self, processor: "AttnProcessor") -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+ self.processor = processor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[torch.Tensor] = None,
+ full_seqlen: Optional[int] = None,
+ Frame: Optional[int] = None,
+ timestep: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ freqs_cis=freqs_cis,
+ full_seqlen=full_seqlen,
+ Frame=Frame,
+ timestep=timestep,
+ **cross_attention_kwargs,
+ )
+
+ @torch.no_grad()
+ def fuse_projections(self, fuse=True):
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if not self.is_cross_attention:
+ # fetch weight matrices.
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ # create a new single projection layer and copy over the weights.
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)
+
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)
+
+ self.fused_projections = fuse
+
+
+class VchitectAttnProcessor:
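+ # Runs the three attention branches of a Vchitect block: spatial attention
+ # within each frame, temporal attention across frames (with RoPE), and cross
+ # attention to the text context. When PAB is enabled, each branch may reuse
+ # its most recent output at broadcast timesteps instead of recomputing it.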
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor):
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+ @autocast("cuda", enabled=False)
+ def apply_rotary_emb(
+ self,
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ):
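+ # treat the last dimension of q/k as interleaved (real, imag) pairs, rotate
+ # each pair by the complex phases in freqs_cis, then flatten back and cast to
+ # the original dtype; computation runs in float32 with autocast disabled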
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+ def spatial_attn(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ batch_size,
+ head_dim,
+ ):
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+ xq, xk = query.to(value.dtype), key.to(value.dtype)
+
+ query_spatial, key_spatial, value_spatial = xq, xk, value
+ query_spatial = query_spatial.transpose(1, 2)
+ key_spatial = key_spatial.transpose(1, 2)
+ value_spatial = value_spatial.transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query_spatial, key_spatial, value_spatial, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ return hidden_states
+
+ def temporal_attention(
+ self,
+ attn,
+ hidden_states,
+ residual,
+ batch_size,
+ batchsize,
+ Frame,
+ head_dim,
+ freqs_cis,
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ):
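+ # temporal branch: tokens are regrouped from (batch*frames, space) to
+ # (batch*space, frames) so attention mixes the same spatial location across
+ # frames; rotary embeddings encode the frame index, and the zero-initialised
+ # to_out_temporal lets this branch start with no contribution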
+ query_t = attn.to_q_temp(hidden_states)
+ key_t = attn.to_k_temp(hidden_states)
+ value_t = attn.to_v_temp(hidden_states)
+
+ query_t = torch.cat([query_t, encoder_hidden_states_query_proj], dim=1)
+ key_t = torch.cat([key_t, encoder_hidden_states_key_proj], dim=1)
+ value_t = torch.cat([value_t, encoder_hidden_states_value_proj], dim=1)
+
+ query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
+ key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
+ value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
+
+ query_t, key_t = query_t.to(value_t.dtype), key_t.to(value_t.dtype)
+
+ if attn.parallel_manager.sp_size > 1:
+ func = lambda x: self.dynamic_switch(attn, x, batchsize, to_spatial_shard=True)
+ query_t, key_t, value_t = map(func, [query_t, key_t, value_t])
+
+ func = lambda x: rearrange(x, "(B T) S H C -> (B S) T H C", T=Frame, B=batchsize)
+ xq_gather_temporal, xk_gather_temporal, xv_gather_temporal = map(func, [query_t, key_t, value_t])
+
+ freqs_cis_temporal = freqs_cis[: xq_gather_temporal.shape[1], :]
+ xq_gather_temporal, xk_gather_temporal = self.apply_rotary_emb(
+ xq_gather_temporal, xk_gather_temporal, freqs_cis=freqs_cis_temporal
+ )
+
+ xq_gather_temporal = xq_gather_temporal.transpose(1, 2)
+ xk_gather_temporal = xk_gather_temporal.transpose(1, 2)
+ xv_gather_temporal = xv_gather_temporal.transpose(1, 2)
+
+ batch_size_temp = xv_gather_temporal.shape[0]
+ hidden_states_temp = F.scaled_dot_product_attention(
+ xq_gather_temporal, xk_gather_temporal, xv_gather_temporal, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_temp = hidden_states_temp.transpose(1, 2).reshape(batch_size_temp, -1, attn.heads * head_dim)
+ hidden_states_temp = hidden_states_temp.to(value_t.dtype)
+ hidden_states_temp = rearrange(hidden_states_temp, "(B S) T C -> (B T) S C", T=Frame, B=batchsize)
+ if attn.parallel_manager.sp_size > 1:
+ hidden_states_temp = self.dynamic_switch(attn, hidden_states_temp, batchsize, to_spatial_shard=False)
+
+ hidden_states_temporal, encoder_hidden_states_temporal = (
+ hidden_states_temp[:, : residual.shape[1]],
+ hidden_states_temp[:, residual.shape[1] :],
+ )
+ hidden_states_temporal = attn.to_out_temporal(hidden_states_temporal)
+ return hidden_states_temporal, encoder_hidden_states_temporal
+
+ def cross_attention(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ batch_size,
+ head_dim,
+ cur_frame,
+ batchsize,
+ ):
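+ # cross-attention branch: queries come from every (possibly sharded) frame,
+ # keys/values come from a single copy of the text projections, and frames are
+ # folded into the sequence so all frames of a sample attend to the same prompt;
+ # the zero-initialised to_out_context projects the result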
+ query_cross = attn.to_q_cross(hidden_states)
+ query_cross = torch.cat([query_cross, encoder_hidden_states_query_proj], dim=1)
+
+ key_y = encoder_hidden_states_key_proj[0].unsqueeze(0)
+ value_y = encoder_hidden_states_value_proj[0].unsqueeze(0)
+
+ query_y = query_cross.view(batch_size, -1, attn.heads, head_dim)
+ key_y = key_y.view(batchsize, -1, attn.heads, head_dim)
+ value_y = value_y.view(batchsize, -1, attn.heads, head_dim)
+
+ query_y = rearrange(query_y, "(B T) S H C -> B (S T) H C", T=cur_frame, B=batchsize)
+
+ query_y = query_y.transpose(1, 2)
+ key_y = key_y.transpose(1, 2)
+ value_y = value_y.transpose(1, 2)
+
+ cross_output = F.scaled_dot_product_attention(query_y, key_y, value_y, dropout_p=0.0, is_causal=False)
+ cross_output = cross_output.transpose(1, 2).reshape(batchsize, -1, attn.heads * head_dim)
+ cross_output = cross_output.to(query_cross.dtype)
+
+ cross_output = rearrange(cross_output, "B (S T) C -> (B T) S C", T=cur_frame, B=batchsize)
+ cross_output = attn.to_out_context(cross_output)
+ return cross_output
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ freqs_cis: Optional[torch.Tensor] = None,
+ full_seqlen: Optional[int] = None,
+ Frame: Optional[int] = None,
+ timestep: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ batch_size = encoder_hidden_states.shape[0]
+ inner_dim = encoder_hidden_states_key_proj.shape[-1]
+ head_dim = inner_dim // attn.heads
+ batchsize = full_seqlen // Frame
+ # equals Frame without sharding; with sequence parallelism this is the per-rank (sharded) frame count
+ cur_frame = batch_size // batchsize
+
+ # temporal attention
+ if enable_pab():
+ broadcast_temporal, attn.temporal_count = if_broadcast_temporal(int(timestep[0]), attn.temporal_count)
+ if enable_pab() and broadcast_temporal:
+ hidden_states_temporal, encoder_hidden_states_temporal = attn.last_temporal
+ else:
+ hidden_states_temporal, encoder_hidden_states_temporal = self.temporal_attention(
+ attn,
+ hidden_states,
+ residual,
+ batch_size,
+ batchsize,
+ Frame,
+ head_dim,
+ freqs_cis,
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ )
+ if enable_pab():
+ attn.last_temporal = (hidden_states_temporal, encoder_hidden_states_temporal)
+
+ # cross attn
+ if enable_pab():
+ broadcast_cross, attn.cross_count = if_broadcast_cross(int(timestep[0]), attn.cross_count)
+ if enable_pab() and broadcast_cross:
+ cross_output = attn.last_cross
+ else:
+ cross_output = self.cross_attention(
+ attn,
+ hidden_states,
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ batch_size,
+ head_dim,
+ cur_frame,
+ batchsize,
+ )
+ if enable_pab():
+ attn.last_cross = cross_output
+
+ # spatial attn
+ if enable_pab():
+ broadcast_spatial, attn.spatial_count = if_broadcast_spatial(int(timestep[0]), attn.spatial_count)
+ if enable_pab() and broadcast_spatial:
+ hidden_states = attn.last_spatial
+ else:
+ hidden_states = self.spatial_attn(
+ attn,
+ hidden_states,
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ batch_size,
+ head_dim,
+ )
+ if enable_pab():
+ attn.last_spatial = hidden_states
+
+ # process attention outputs.
+ hidden_states = hidden_states * 1.1 + cross_output
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if cur_frame == 1:
+ hidden_states_temporal = hidden_states_temporal * 0
+ hidden_states = hidden_states + hidden_states_temporal
+
+ # encoder
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+ encoder_hidden_states_temporal = attn.to_add_out_temporal(encoder_hidden_states_temporal)
+ if cur_frame == 1:
+ encoder_hidden_states_temporal = encoder_hidden_states_temporal * 0
+ encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_temporal
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
+
+ def dynamic_switch(self, attn, x, batchsize, to_spatial_shard: bool):
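+ # switch between temporal-sharded and spatial-sharded layouts via all_to_all;
+ # per-axis padding is tracked with set_pad/get_pad so uneven sequence lengths
+ # can still be scattered and gathered across the sequence-parallel group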
+ if to_spatial_shard:
+ scatter_dim, gather_dim = 2, 1
+ set_pad("spatial", x.shape[1], attn.parallel_manager.sp_group)
+ scatter_pad = get_pad("spatial")
+ gather_pad = get_pad("temporal")
+ else:
+ scatter_dim, gather_dim = 1, 2
+ scatter_pad = get_pad("temporal")
+ gather_pad = get_pad("spatial")
+
+ x = rearrange(x, "(B T) S ... -> B T S ...", B=batchsize)
+ x = all_to_all_with_pad(
+ x,
+ attn.parallel_manager.sp_group,
+ scatter_dim=scatter_dim,
+ gather_dim=gather_dim,
+ scatter_pad=scatter_pad,
+ gather_pad=gather_pad,
+ )
+ x = rearrange(x, "B T ... -> (B T) ...")
+ return x
diff --git a/videosys/models/modules/normalization.py b/videosys/models/modules/normalization.py
index 7985e56..4df9ae3 100644
--- a/videosys/models/modules/normalization.py
+++ b/videosys/models/modules/normalization.py
@@ -2,6 +2,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
+import torch.nn.functional as F
class LlamaRMSNorm(nn.Module):
@@ -100,3 +101,32 @@ class AdaLayerNorm(nn.Module):
x = self.norm(x) * (1 + scale) + shift
return x
+
+
+class VchitectSpatialNorm(nn.Module):
+ """
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+
+ Args:
+ f_channels (`int`):
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
+ zq_channels (`int`):
+ The number of channels for the quantized vector as described in the paper.
+ """
+
+ def __init__(
+ self,
+ f_channels: int,
+ zq_channels: int,
+ ):
+ super().__init__()
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
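+ # upsample the conditioning zq to f's spatial size with nearest-neighbour
+ # interpolation, then modulate the group-normalised f with a learned
+ # per-pixel scale (conv_y) and shift (conv_b) computed from zq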
+ f_size = f.shape[-2:]
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
diff --git a/videosys/models/transformers/cogvideox_transformer_3d.py b/videosys/models/transformers/cogvideox_transformer_3d.py
index 4b86d70..e568e06 100644
--- a/videosys/models/transformers/cogvideox_transformer_3d.py
+++ b/videosys/models/transformers/cogvideox_transformer_3d.py
@@ -22,15 +22,9 @@ from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn
-from videosys.core.comm import all_to_all_comm, gather_sequence, get_spatial_pad, set_spatial_pad, split_sequence
+from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence
from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
-from videosys.core.parallel_mgr import (
- enable_sequence_parallel,
- get_cfg_parallel_group,
- get_cfg_parallel_size,
- get_sequence_parallel_group,
- get_sequence_parallel_size,
-)
+from videosys.core.parallel_mgr import ParallelManager
from videosys.models.modules.embeddings import apply_rotary_emb
from videosys.utils.utils import batch_func
@@ -48,6 +42,49 @@ class CogVideoXAttnProcessor2_0:
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ def _remove_extra_encoder(self, hidden_states, text_seq_length, attn):
+ # current layout is [text, 1/n seq, text, 1/n seq, ...]
+ # we want to keep a single text segment and drop the duplicated ones: [text, seq]
+ sp_size = attn.parallel_manager.sp_size
+ split_seq = hidden_states.split(hidden_states.size(2) // sp_size, dim=2)
+ encoder_hidden_states = split_seq[0][:, :, :text_seq_length]
+ new_seq = [encoder_hidden_states]
+ for i in range(sp_size):
+ new_seq.append(split_seq[i][:, :, text_seq_length:])
+ hidden_states = torch.cat(new_seq, dim=2)
+
+ # remove the padding that was added for all_to_all;
+ # if the padding were removed before the split above,
+ # the chunk size would be wrong
+ pad = get_pad("pad")
+ if pad > 0:
+ hidden_states = hidden_states.narrow(2, 0, hidden_states.size(2) - pad)
+ return hidden_states
+
+ def _add_extra_encoder(self, hidden_states, text_seq_length, attn):
+ # add padding before the split and the later all_to_all;
+ # if the padding were added after the split,
+ # the chunk size would be wrong
+ pad = get_pad("pad")
+ if pad > 0:
+ pad_shape = list(hidden_states.shape)
+ pad_shape[1] = pad
+ pad_tensor = torch.zeros(pad_shape, device=hidden_states.device, dtype=hidden_states.dtype)
+ hidden_states = torch.cat([hidden_states, pad_tensor], dim=1)
+
+ # current layout is [text, seq]
+ # we want to add the extra encoder info [text, 1/n seq, text, 1/n seq, ...]
+ sp_size = attn.parallel_manager.sp_size
+ encoder = hidden_states[:, :text_seq_length]
+ seq = hidden_states[:, text_seq_length:]
+ seq = seq.split(seq.size(1) // sp_size, dim=1)
+ new_seq = []
+ for i in range(sp_size):
+ new_seq.append(encoder)
+ new_seq.append(seq[i])
+ hidden_states = torch.cat(new_seq, dim=1)
+ return hidden_states
+
def __call__(
self,
attn: Attention,
@@ -72,13 +109,15 @@ class CogVideoXAttnProcessor2_0:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
- if enable_sequence_parallel():
+ if attn.parallel_manager.sp_size > 1:
assert (
- attn.heads % get_sequence_parallel_size() == 0
- ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {get_sequence_parallel_size()}"
- attn_heads = attn.heads // get_sequence_parallel_size()
+ attn.heads % attn.parallel_manager.sp_size == 0
+ ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
+ attn_heads = attn.heads // attn.parallel_manager.sp_size
+ # normally padding is applied around every all_to_all, but for a more convenient
+ # implementation in CogVideoX we fold the padding into the encoder add/remove steps
query, key, value = map(
- lambda x: all_to_all_comm(x, get_sequence_parallel_group(), scatter_dim=2, gather_dim=1),
+ lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
[query, key, value],
)
else:
@@ -96,6 +135,13 @@ class CogVideoXAttnProcessor2_0:
if attn.norm_k is not None:
key = attn.norm_k(key)
+ if attn.parallel_manager.sp_size > 1:
+ # remove extra encoder for attention
+ query, key, value = map(
+ lambda x: self._remove_extra_encoder(x, text_seq_length, attn),
+ [query, key, value],
+ )
+
# Apply RoPE if needed
if image_rotary_emb is not None:
emb_len = image_rotary_emb[0].shape[0]
@@ -113,77 +159,10 @@ class CogVideoXAttnProcessor2_0:
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
- if enable_sequence_parallel():
- hidden_states = all_to_all_comm(hidden_states, get_sequence_parallel_group(), scatter_dim=1, gather_dim=2)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states, hidden_states = hidden_states.split(
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
- )
- return hidden_states, encoder_hidden_states
-
-
-class FusedCogVideoXAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # Apply RoPE if needed
- if image_rotary_emb is not None:
- query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
- if not attn.is_cross_attention:
- key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ if attn.parallel_manager.sp_size > 1:
+ # add extra encoder for all_to_all
+ hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn)
+ hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@@ -266,6 +245,9 @@ class CogVideoXBlock(nn.Module):
processor=CogVideoXAttnProcessor2_0(),
)
+ # parallel
+ self.attn1.parallel_manager = None
+
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
@@ -300,7 +282,7 @@ class CogVideoXBlock(nn.Module):
# attention
if enable_pab():
- broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
+ broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count)
if enable_pab() and broadcast_attn:
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
else:
@@ -474,6 +456,23 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
+ # parallel
+ self.parallel_manager = None
+
+ def enable_parallel(self, dp_size, sp_size, enable_cp):
+ # update cfg parallel
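+ # if cfg parallelism is requested and sp_size is even, two ranks are carved
+ # out of the sequence-parallel group to run the two cfg branches in parallel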
+ if enable_cp and sp_size % 2 == 0:
+ sp_size = sp_size // 2
+ cp_size = 2
+ else:
+ cp_size = 1
+
+ self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size)
+
+ for _, module in self.named_modules():
+ if hasattr(module, "parallel_manager"):
+ module.parallel_manager = self.parallel_manager
+
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@@ -485,9 +484,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
- all_timesteps=None,
):
- if get_cfg_parallel_size() > 1:
+ if self.parallel_manager.cp_size > 1:
(
hidden_states,
encoder_hidden_states,
@@ -495,7 +493,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
timestep_cond,
image_rotary_emb,
) = batch_func(
- partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+ partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
encoder_hidden_states,
timestep,
@@ -530,9 +528,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
- if enable_sequence_parallel():
- set_spatial_pad(hidden_states.shape[1])
- hidden_states = split_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
+ if self.parallel_manager.sp_size > 1:
+ set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
+ hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
@@ -562,8 +560,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
timestep=timesteps if enable_pab() else None,
)
- if enable_sequence_parallel():
- hidden_states = gather_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
+ if self.parallel_manager.sp_size > 1:
+ hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
@@ -583,8 +581,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
- if get_cfg_parallel_size() > 1:
- output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+ if self.parallel_manager.cp_size > 1:
+ output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
diff --git a/videosys/models/transformers/latte_transformer_3d.py b/videosys/models/transformers/latte_transformer_3d.py
index 3a00759..ef28720 100644
--- a/videosys/models/transformers/latte_transformer_3d.py
+++ b/videosys/models/transformers/latte_transformer_3d.py
@@ -33,15 +33,7 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange, repeat
from torch import nn
-from videosys.core.comm import (
- all_to_all_with_pad,
- gather_sequence,
- get_spatial_pad,
- get_temporal_pad,
- set_spatial_pad,
- set_temporal_pad,
- split_sequence,
-)
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.core.pab_mgr import (
enable_pab,
get_mlp_output,
@@ -51,12 +43,7 @@ from videosys.core.pab_mgr import (
if_broadcast_temporal,
save_mlp_output,
)
-from videosys.core.parallel_mgr import (
- enable_sequence_parallel,
- get_cfg_parallel_group,
- get_cfg_parallel_size,
- get_sequence_parallel_group,
-)
+from videosys.core.parallel_mgr import ParallelManager
from videosys.utils.utils import batch_func
@@ -388,9 +375,7 @@ class BasicTransformerBlock(nn.Module):
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
if enable_pab():
- broadcast_spatial, self.spatial_count = if_broadcast_spatial(
- int(org_timestep[0]), self.spatial_count, self.block_idx
- )
+ broadcast_spatial, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count)
if enable_pab() and broadcast_spatial:
attn_output = self.spatial_last
@@ -681,6 +666,9 @@ class BasicTransformerBlock_(nn.Module):
self.block_idx = block_idx
self.count = 0
+ # parallel
+ self.parallel_manager: ParallelManager = None
+
def set_last_out(self, last_out: torch.Tensor):
self.last_out = last_out
@@ -743,7 +731,7 @@ class BasicTransformerBlock_(nn.Module):
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
- if enable_sequence_parallel():
+ if self.parallel_manager.sp_size > 1:
norm_hidden_states = self.dynamic_switch(norm_hidden_states, to_spatial_shard=True)
attn_output = self.attn1(
@@ -753,7 +741,7 @@ class BasicTransformerBlock_(nn.Module):
**cross_attention_kwargs,
)
- if enable_sequence_parallel():
+ if self.parallel_manager.sp_size > 1:
attn_output = self.dynamic_switch(attn_output, to_spatial_shard=False)
if self.use_ada_layer_norm_zero:
@@ -838,15 +826,15 @@ class BasicTransformerBlock_(nn.Module):
def dynamic_switch(self, x, to_spatial_shard: bool):
if to_spatial_shard:
scatter_dim, gather_dim = 0, 1
- scatter_pad = get_spatial_pad()
- gather_pad = get_temporal_pad()
+ scatter_pad = get_pad("spatial")
+ gather_pad = get_pad("temporal")
else:
scatter_dim, gather_dim = 1, 0
- scatter_pad = get_temporal_pad()
- gather_pad = get_spatial_pad()
+ scatter_pad = get_pad("temporal")
+ gather_pad = get_pad("spatial")
x = all_to_all_with_pad(
x,
- get_sequence_parallel_group(),
+ self.parallel_manager.sp_group,
scatter_dim=scatter_dim,
gather_dim=gather_dim,
scatter_pad=scatter_pad,
@@ -1133,6 +1121,23 @@ class LatteT2V(ModelMixin, ConfigMixin):
temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size
self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+ # parallel
+ self.parallel_manager = None
+
+ def enable_parallel(self, dp_size, sp_size, enable_cp):
+ # update cfg parallel
+ if enable_cp and sp_size % 2 == 0:
+ sp_size = sp_size // 2
+ cp_size = 2
+ else:
+ cp_size = 1
+
+ self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
+
+ for _, module in self.named_modules():
+ if hasattr(module, "parallel_manager"):
+ module.parallel_manager = self.parallel_manager
+
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@@ -1191,7 +1196,7 @@ class LatteT2V(ModelMixin, ConfigMixin):
"""
# 0. Split batch for data parallelism
- if get_cfg_parallel_size() > 1:
+ if self.parallel_manager.cp_size > 1:
(
hidden_states,
timestep,
@@ -1201,7 +1206,7 @@ class LatteT2V(ModelMixin, ConfigMixin):
attention_mask,
encoder_attention_mask,
) = batch_func(
- partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+ partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
timestep,
encoder_hidden_states,
@@ -1292,14 +1297,14 @@ class LatteT2V(ModelMixin, ConfigMixin):
timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
- if enable_sequence_parallel():
- set_temporal_pad(frame + use_image_num)
- set_spatial_pad(num_patches)
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+ set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
- self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
@@ -1420,7 +1425,7 @@ class LatteT2V(ModelMixin, ConfigMixin):
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
- if enable_sequence_parallel():
+ if self.parallel_manager.sp_size > 1:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
if self.is_input_patches:
@@ -1452,8 +1457,8 @@ class LatteT2V(ModelMixin, ConfigMixin):
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
- if get_cfg_parallel_size() > 1:
- output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+ if self.parallel_manager.cp_size > 1:
+ output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
@@ -1466,12 +1471,12 @@ class LatteT2V(ModelMixin, ConfigMixin):
def split_from_second_dim(self, x, batch_size):
x = x.view(batch_size, -1, *x.shape[1:])
- x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+ x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
x = x.reshape(-1, *x.shape[2:])
return x
def gather_from_second_dim(self, x, batch_size):
x = x.view(batch_size, -1, *x.shape[1:])
- x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
+ x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
x = x.reshape(-1, *x.shape[2:])
return x
diff --git a/videosys/models/transformers/open_sora_plan_v110_transformer_3d.py b/videosys/models/transformers/open_sora_plan_v110_transformer_3d.py
new file mode 100644
index 0000000..9839677
--- /dev/null
+++ b/videosys/models/transformers/open_sora_plan_v110_transformer_3d.py
@@ -0,0 +1,2826 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+
+import json
+import os
+from dataclasses import dataclass
+from functools import partial
+from importlib import import_module
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+from diffusers.models.attention_processor import (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ AttnProcessor,
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionAttnProcessor2_0,
+ CustomDiffusionXFormersAttnProcessor,
+ LoRAAttnAddedKVProcessor,
+ LoRAAttnProcessor,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ SlicedAttnAddedKVProcessor,
+ SlicedAttnProcessor,
+ SpatialNorm,
+ XFormersAttnAddedKVProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from einops import rearrange, repeat
+from torch import nn
+
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
+from videosys.core.pab_mgr import (
+ enable_pab,
+ get_mlp_output,
+ if_broadcast_cross,
+ if_broadcast_mlp,
+ if_broadcast_spatial,
+ if_broadcast_temporal,
+ save_mlp_output,
+)
+from videosys.core.parallel_mgr import ParallelManager
+from videosys.utils.logging import logger
+from videosys.utils.utils import batch_func
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+SPATIAL_LIST = []
+TEMPROAL_LIST = []
+CROSS_LIST = []
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
+ """
+ grid_size: int of the grid height and width
+ return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed(embed_dim, length, interpolation_scale=1.0, base_size=16):
+ pos = torch.arange(0, length).unsqueeze(1) / interpolation_scale
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
+ return pos_embed
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded, size (M,)
+ out: (M, D)
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+class RoPE2D(torch.nn.Module):
+ def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+ self.scaling_factor = scaling_factor
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D, seq_len, device, dtype) not in self.cache:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
+ return self.cache[D, seq_len, device, dtype]
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim == 2
+
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x nheads x ntokens x dim
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
+ output:
+ * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
+ """
+ assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two"
+ D = tokens.size(3) // 2
+ assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2
+ cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
+ # split features into two along the feature dimension, and apply rope1d on each half
+ y, x = tokens.chunk(2, dim=-1)
+ y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
+ x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
+ tokens = torch.cat((y, x), dim=-1)
+ return tokens
+
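+# Hedged usage sketch for RoPE2D (shapes only; the sizes below are arbitrary):
+#   rope = RoPE2D()
+#   tokens = torch.randn(2, 8, 64, 32)                # (batch, heads, tokens, dim)
+#   pos = PositionGetter2D()(2, 8, 8, tokens.device)  # (batch, 64, 2) y/x indices
+#   tokens = rope(tokens, pos)                        # rotated, same shape as input
+# The first half of the feature dim is rotated with the y index and the second half
+# with the x index, each via the standard rotate-half formulation.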
+
+class LinearScalingRoPE2D(RoPE2D):
+ """Code from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148"""
+
+ def forward(self, tokens, positions):
+ # difference to the original RoPE: a scaling factor is applied to the position ids
+ dtype = positions.dtype
+ positions = positions.float() / self.scaling_factor
+ positions = positions.to(dtype)
+ tokens = super().forward(tokens, positions)
+ return tokens
+
+
+class RoPE1D(torch.nn.Module):
+ def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+ self.scaling_factor = scaling_factor
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D, seq_len, device, dtype) not in self.cache:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
+ return self.cache[D, seq_len, device, dtype]
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim == 2
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x nheads x ntokens x dim
+ * positions: batch_size x ntokens (t position of each token)
+ output:
+ * tokens after applying RoPE1D (batch_size x nheads x ntokens x dim)
+ """
+ D = tokens.size(3)
+ assert positions.ndim == 2 # Batch, Seq
+ cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
+ tokens = self.apply_rope1d(tokens, positions, cos, sin)
+ return tokens
+
+
+class LinearScalingRoPE1D(RoPE1D):
+ """Code from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148"""
+
+ def forward(self, tokens, positions):
+ # difference to the original RoPE: a scaling factor is applied to the position ids
+ dtype = positions.dtype
+ positions = positions.float() / self.scaling_factor
+ positions = positions.to(dtype)
+ tokens = super().forward(tokens, positions)
+ return tokens
+
+
+class PositionGetter2D(object):
+ """return positions of patches"""
+
+ def __init__(self):
+ self.cache_positions = {}
+
+ def __call__(self, b, h, w, device):
+ if (h, w) not in self.cache_positions:
+ x = torch.arange(w, device=device)
+ y = torch.arange(h, device=device)
+ self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h*w, 2)
+ pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
+ return pos
+
+
+class PositionGetter1D(object):
+ """return positions of patches"""
+
+ def __init__(self):
+ self.cache_positions = {}
+
+ def __call__(self, b, l, device):
+ if l not in self.cache_positions:
+ x = torch.arange(l, device=device)
+ self.cache_positions[l] = x # (l, )
+ pos = self.cache_positions[l].view(1, l).expand(b, -1).clone()
+ return pos
+
+
+class CombinedTimestepSizeEmbeddings(nn.Module):
+ """
+ For PixArt-Alpha.
+
+ Reference:
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
+ """
+
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.outdim = size_emb_dim
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_additional_conditions = use_additional_conditions
+ if use_additional_conditions:
+ self.use_additional_conditions = True
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+ def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
+ if size.ndim == 1:
+ size = size[:, None]
+
+ if size.shape[0] != batch_size:
+ size = size.repeat(batch_size // size.shape[0], 1)
+ if size.shape[0] != batch_size:
+ raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
+
+ current_batch_size, dims = size.shape[0], size.shape[1]
+ size = size.reshape(-1)
+ size_freq = self.additional_condition_proj(size).to(size.dtype)
+
+ size_emb = embedder(size_freq)
+ size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
+ return size_emb
+
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ if self.use_additional_conditions:
+ resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
+ aspect_ratio = self.apply_condition(
+ aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
+ )
+ conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+class CaptionProjection(nn.Module):
+ """
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_features, hidden_size, num_tokens=120):
+ super().__init__()
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+ self.act_1 = nn.GELU(approximate="tanh")
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
+ self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
+
+ def forward(self, caption, force_drop_ids=None):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=1,
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size = patch_size
+ # See:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
+ self.height, self.width = height // patch_size, width // patch_size
+
+ self.base_size = height // patch_size
+ self.interpolation_scale = interpolation_scale
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ def forward(self, latent):
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+
+ latent = self.proj(latent)
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ # Interpolate positional embeddings if needed.
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
+ if self.height != height or self.width != width:
+ # raise ValueError
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed)
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+ return (latent + pos_embed).to(latent.dtype)
+
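+# PatchEmbed sketch (illustrative sizes, not tied to any checkpoint): a 4-channel
+# 32x32 latent with 2x2 patches becomes (B, 16*16, embed_dim) tokens, and the
+# sin/cos table is rebuilt on the fly whenever the input resolution differs from
+# the one given at construction time:
+#   embed = PatchEmbed(height=32, width=32, patch_size=2, in_channels=4, embed_dim=768)
+#   tokens = embed(torch.randn(1, 4, 32, 32))  # -> (1, 256, 768)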
+
+@maybe_allow_in_graph
+class Attention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+ `AttnProcessor` otherwise.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor"] = None,
+ attention_mode: str = "xformers",
+ use_rope: bool = False,
+ rope_scaling: Optional[Dict] = None,
+ compress_kv_factor: Optional[Tuple] = None,
+ ):
+ super().__init__()
+ self.inner_dim = dim_head * heads
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.use_rope = use_rope
+ self.rope_scaling = rope_scaling
+ self.compress_kv_factor = compress_kv_factor
+
+ # we make use of this private variable to know whether this class is loaded
+ # with a deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ if USE_PEFT_BACKEND:
+ linear_cls = nn.Linear
+ else:
+ linear_cls = LoRACompatibleLinear
+
+ assert not (
+ self.use_rope and (self.compress_kv_factor is not None)
+ ), "Can not both enable compressing kv and using rope"
+ if self.compress_kv_factor is not None:
+ self._init_compress()
+
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ if processor is None:
+ processor = (
+ AttnProcessor2_0(
+ self.inner_dim,
+ attention_mode,
+ use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=compress_kv_factor,
+ )
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+ else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ r"""
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
+ """
+ is_lora = hasattr(self, "processor")
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
+ )
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ SlicedAttnAddedKVProcessor,
+ XFormersAttnAddedKVProcessor,
+ LoRAAttnAddedKVProcessor,
+ ),
+ )
+
+ if use_memory_efficient_attention_xformers:
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
+ raise NotImplementedError(
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
+ )
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ (
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers"
+ ),
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+
+ if is_lora:
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
+ processor = LoRAXFormersAttnProcessor(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ rank=self.processor.rank,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ processor.to(self.processor.to_q_lora.up.weight.device)
+ elif is_custom_diffusion:
+ processor = CustomDiffusionXFormersAttnProcessor(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ elif is_added_kv_processor:
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+ # which uses this type of cross attention ONLY because the attention mask of format
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
+ # throw warning
+ logger.info(
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+ )
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+ else:
+ processor = XFormersAttnProcessor(attention_op=attention_op)
+ else:
+ if is_lora:
+ attn_processor_class = (
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+ )
+ processor = attn_processor_class(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ rank=self.processor.rank,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ processor.to(self.processor.to_q_lora.up.weight.device)
+ elif is_custom_diffusion:
+ attn_processor_class = (
+ CustomDiffusionAttnProcessor2_0
+ if hasattr(F, "scaled_dot_product_attention")
+ else CustomDiffusionAttnProcessor
+ )
+ processor = attn_processor_class(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0()
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+ else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ r"""
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
+ """
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ if slice_size is not None and self.added_kv_proj_dim is not None:
+ processor = SlicedAttnAddedKVProcessor(slice_size)
+ elif slice_size is not None:
+ processor = SlicedAttnProcessor(slice_size)
+ elif self.added_kv_proj_dim is not None:
+ processor = AttnAddedKVProcessor()
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ _remove_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to remove LoRA layers from the model.
+ """
+ if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
+ deprecate(
+ "set_processor to offload LoRA",
+ "0.26.0",
+ "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
+ )
+ # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
+ # We need to remove all LoRA layers
+ # Don't forget to remove ALL `_remove_lora` from the codebase
+ for module in self.modules():
+ if hasattr(module, "set_lora_layer"):
+ module.set_lora_layer(None)
+
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False):
+ r"""
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
+ # with PEFT is completed.
+ is_lora_activated = {
+ name: module.lora_layer is not None
+ for name, module in self.named_modules()
+ if hasattr(module, "lora_layer")
+ }
+
+ # 1. if no layer has a LoRA activated we can return the processor as usual
+ if not any(is_lora_activated.values()):
+ return self.processor
+
+ # `add_k_proj` and `add_v_proj` may not have LoRA applied, so drop them from the check
+ is_lora_activated.pop("add_k_proj", None)
+ is_lora_activated.pop("add_v_proj", None)
+ # 2. otherwise it is not possible that only some layers have LoRA activated
+ if not all(is_lora_activated.values()):
+ raise ValueError(
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
+ )
+
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
+ non_lora_processor_cls_name = self.processor.__class__.__name__
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
+
+ hidden_size = self.inner_dim
+
+ # now create a LoRA attention processor from the LoRA layers
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
+ kwargs = {
+ "cross_attention_dim": self.cross_attention_dim,
+ "rank": self.to_q.lora_layer.rank,
+ "network_alpha": self.to_q.lora_layer.network_alpha,
+ "q_rank": self.to_q.lora_layer.rank,
+ "q_hidden_size": self.to_q.lora_layer.out_features,
+ "k_rank": self.to_k.lora_layer.rank,
+ "k_hidden_size": self.to_k.lora_layer.out_features,
+ "v_rank": self.to_v.lora_layer.rank,
+ "v_hidden_size": self.to_v.lora_layer.out_features,
+ "out_rank": self.to_out[0].lora_layer.rank,
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
+ }
+
+ if hasattr(self.processor, "attention_op"):
+ kwargs["attention_op"] = self.processor.attention_op
+
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
+ lora_processor = lora_processor_cls(
+ hidden_size,
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
+ rank=self.to_q.lora_layer.rank,
+ network_alpha=self.to_q.lora_layer.network_alpha,
+ )
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
+
+ # only save if used
+ if self.add_k_proj.lora_layer is not None:
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
+ else:
+ lora_processor.add_k_proj_lora = None
+ lora_processor.add_v_proj_lora = None
+ else:
+ raise ValueError(f"{lora_processor_cls} does not exist.")
+
+ return lora_processor
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+ is the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads` is
+ the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
+ return tensor
+
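+ # get_attention_scores below computes softmax(beta * mask + self.scale * Q @ K^T)
+ # in a single torch.baddbmm call (beta is 0 when no mask is given), optionally
+ # upcasting the matmul and/or the softmax to float32 for numerical stability.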
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
+ ) -> torch.Tensor:
+ r"""
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ r"""
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ The attention mask to prepare.
+ target_length (`int`):
+ The target length of the attention mask. This is the length of the attention mask after padding.
+ batch_size (`int`):
+ The batch size, which is used to repeat the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`):
+ The output dimension of the attention mask. Can be either `3` or `4`.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ r"""
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+ `Attention` class.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+ def _init_compress(self):
+ if len(self.compress_kv_factor) == 2:
+ self.sr = nn.Conv2d(
+ self.inner_dim,
+ self.inner_dim,
+ groups=self.inner_dim,
+ kernel_size=self.compress_kv_factor,
+ stride=self.compress_kv_factor,
+ )
+ self.sr.weight.data.fill_(1 / self.compress_kv_factor[0] ** 2)
+ elif len(self.compress_kv_factor) == 1:
+ self.kernel_size = self.compress_kv_factor[0]
+ self.sr = nn.Conv1d(
+ self.inner_dim,
+ self.inner_dim,
+ groups=self.inner_dim,
+ kernel_size=self.compress_kv_factor[0],
+ stride=self.compress_kv_factor[0],
+ )
+ self.sr.weight.data.fill_(1 / self.compress_kv_factor[0])
+ self.sr.bias.data.zero_()
+ self.norm = nn.LayerNorm(self.inner_dim)
+
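+# Hedged self-attention usage sketch (module-level comment; sizes are illustrative):
+#   attn = Attention(query_dim=1152, heads=16, dim_head=72)
+#   out = attn(torch.randn(2, 256, 1152))  # self-attention, output (2, 256, 1152)
+# With no explicit processor, a torch build that provides scaled_dot_product_attention,
+# and the default scale_qk, the constructor picks the AttnProcessor2_0 defined below.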
+
+class AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self, dim=1152, attention_mode="xformers", use_rope=False, rope_scaling=None, compress_kv_factor=None):
+ self.dim = dim
+ self.attention_mode = attention_mode
+ self.use_rope = use_rope
+ self.rope_scaling = rope_scaling
+ self.compress_kv_factor = compress_kv_factor
+ if self.use_rope:
+ self._init_rope()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def _init_rope(self):
+ if self.rope_scaling is None:
+ self.rope2d = RoPE2D()
+ self.rope1d = RoPE1D()
+ else:
+ scaling_type = self.rope_scaling["type"]
+ scaling_factor_2d = self.rope_scaling["factor_2d"]
+ scaling_factor_1d = self.rope_scaling["factor_1d"]
+ if scaling_type == "linear":
+ self.rope2d = LinearScalingRoPE2D(scaling_factor=scaling_factor_2d)
+ self.rope1d = LinearScalingRoPE1D(scaling_factor=scaling_factor_1d)
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ position_q: Optional[torch.LongTensor] = None,
+ position_k: Optional[torch.LongTensor] = None,
+ last_shape: Tuple[int] = None,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ if self.compress_kv_factor is not None:
+ batch_size = hidden_states.shape[0]
+ if len(last_shape) == 2:
+ encoder_hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, self.dim, *last_shape)
+ encoder_hidden_states = (
+ attn.sr(encoder_hidden_states).reshape(batch_size, self.dim, -1).permute(0, 2, 1)
+ )
+ elif len(last_shape) == 1:
+ encoder_hidden_states = hidden_states.permute(0, 2, 1)
+ if last_shape[0] % 2 == 1:
+ first_frame_pad = encoder_hidden_states[:, :, :1].repeat((1, 1, attn.kernel_size - 1))
+ encoder_hidden_states = torch.concatenate((first_frame_pad, encoder_hidden_states), dim=2)
+ encoder_hidden_states = attn.sr(encoder_hidden_states).permute(0, 2, 1)
+ else:
+ raise NotImplementedError(f"NotImplementedError with last_shape {last_shape}")
+
+ encoder_hidden_states = attn.norm(encoder_hidden_states)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+ query = attn.to_q(hidden_states, *args)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if self.use_rope:
+ # requires the shape (batch_size x nheads x ntokens x dim)
+ if position_q.ndim == 3:
+ query = self.rope2d(query, position_q)
+ elif position_q.ndim == 2:
+ query = self.rope1d(query, position_q)
+ else:
+ raise NotImplementedError
+ if position_k.ndim == 3:
+ key = self.rope2d(key, position_k)
+ elif position_k.ndim == 2:
+ key = self.rope1d(key, position_k)
+ else:
+ raise NotImplementedError
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states, *args)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we concatenate the visual features and the object features
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(linear_cls(inner_dim, dim_out))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+ for module in self.net:
+ if isinstance(module, compatible_cls):
+ hidden_states = module(hidden_states, scale)
+ else:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
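+# FeedForward sketch (illustrative): with the default "geglu" activation the hidden
+# width is dim * mult and the block maps (B, T, dim) -> (B, T, dim_out or dim):
+#   ff = FeedForward(dim=1152, mult=4)
+#   y = ff(torch.randn(2, 256, 1152))  # -> (2, 256, 1152)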
+
+@maybe_allow_in_graph
+class BasicTransformerBlock_(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ attention_mode: str = "xformers",
+ use_rope: bool = False,
+ rope_scaling: Optional[Dict] = None,
+ compress_kv_factor: Optional[Tuple] = None,
+ block_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ attention_mode=attention_mode,
+ use_rope=use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=compress_kv_factor,
+ )
+
+ # # 2. Cross-Attn
+ # if cross_attention_dim is not None or double_self_attention:
+ # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # # the second cross attention block.
+ # self.norm2 = (
+ # AdaLayerNorm(dim, num_embeds_ada_norm)
+ # if self.use_ada_layer_norm
+ # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ # )
+ # self.attn2 = Attention(
+ # query_dim=dim,
+ # cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ # heads=num_attention_heads,
+ # dim_head=attention_head_dim,
+ # dropout=dropout,
+ # bias=attention_bias,
+ # upcast_attention=upcast_attention,
+ # ) # is self-attn if encoder_hidden_states is none
+ # else:
+ # self.norm2 = None
+ # self.attn2 = None
+
+ # 3. Feed-forward
+ # if not self.use_ada_layer_norm_single:
+ # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # pab
+ self.last_out = None
+ self.count = 0
+ self.block_idx = block_idx
+ self.temp_mlp_count = 0
+
+ # parallel
+ self.parallel_manager: ParallelManager = None
+
+ def set_last_out(self, last_out: torch.Tensor):
+ self.last_out = last_out
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ position_q: Optional[torch.LongTensor] = None,
+ position_k: Optional[torch.LongTensor] = None,
+ frame: int = None,
+ org_timestep: Optional[torch.LongTensor] = None,
+ all_timesteps=None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+ # 1. Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ broadcast_temporal, self.count = if_broadcast_temporal(int(org_timestep[0]), self.count)
+ if broadcast_temporal:
+ attn_output = self.last_out
+ assert self.use_ada_layer_norm_single
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ else:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ if self.parallel_manager.sp_size > 1:
+ norm_hidden_states = self.dynamic_switch(norm_hidden_states, to_spatial_shard=True)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ position_q=position_q,
+ position_k=position_k,
+ last_shape=frame,
+ **cross_attention_kwargs,
+ )
+
+ if self.parallel_manager.sp_size > 1:
+ attn_output = self.dynamic_switch(attn_output, to_spatial_shard=False)
+
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
+ if enable_pab():
+ self.set_last_out(attn_output)
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # # 3. Cross-Attention
+ # if self.attn2 is not None:
+ # if self.use_ada_layer_norm:
+ # norm_hidden_states = self.norm2(hidden_states, timestep)
+ # elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+ # norm_hidden_states = self.norm2(hidden_states)
+ # elif self.use_ada_layer_norm_single:
+ # # For PixArt norm2 isn't applied here:
+ # # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ # norm_hidden_states = hidden_states
+ # else:
+ # raise ValueError("Incorrect norm")
+
+ # if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
+ # norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # attn_output = self.attn2(
+ # norm_hidden_states,
+ # encoder_hidden_states=encoder_hidden_states,
+ # attention_mask=encoder_attention_mask,
+ # **cross_attention_kwargs,
+ # )
+ # hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ # if not self.use_ada_layer_norm_single:
+ # norm_hidden_states = self.norm3(hidden_states)
+
+ if enable_pab():
+ broadcast_mlp, self.temp_mlp_count, broadcast_next, broadcast_range = if_broadcast_mlp(
+ int(org_timestep[0]),
+ self.temp_mlp_count,
+ self.block_idx,
+ all_timesteps.tolist(),
+ is_temporal=True,
+ )
+
+ if enable_pab() and broadcast_mlp:
+ ff_output = get_mlp_output(
+ broadcast_range,
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ is_temporal=True,
+ )
+ else:
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.use_ada_layer_norm_single:
+ # norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = self.norm3(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice, scale=lora_scale)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
+
+ if enable_pab() and broadcast_next:
+ save_mlp_output(
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ ff_output=ff_output,
+ is_temporal=True,
+ )
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
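+ # dynamic_switch reshards activations across the sequence-parallel group with a padded all-to-all:
+ # with to_spatial_shard=True each rank trades its temporal shard for a spatial shard (so temporal
+ # attention can see every frame for its subset of tokens), and the call with to_spatial_shard=False
+ # restores the original layout after attention.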
+ def dynamic_switch(self, x, to_spatial_shard: bool):
+ if to_spatial_shard:
+ scatter_dim, gather_dim = 0, 1
+ scatter_pad = get_pad("spatial")
+ gather_pad = get_pad("temporal")
+ else:
+ scatter_dim, gather_dim = 1, 0
+ scatter_pad = get_pad("temporal")
+ gather_pad = get_pad("spatial")
+ x = all_to_all_with_pad(
+ x,
+ self.parallel_manager.sp_group,
+ scatter_dim=scatter_dim,
+ gather_dim=gather_dim,
+ scatter_pad=scatter_pad,
+ gather_pad=gather_pad,
+ )
+ return x
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (`int`, *optional*):
+ The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool`, *optional*, defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ attention_mode: str = "xformers",
+ use_rope: bool = False,
+ rope_scaling: Optional[Dict] = None,
+ compress_kv_factor: Optional[Tuple] = None,
+ block_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ attention_mode=attention_mode,
+ use_rope=use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=compress_kv_factor,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ )
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ attention_mode=attention_mode, # only xformers support attention_mask
+ use_rope=False, # do not use rotary position embeddings in cross attention
+ compress_kv_factor=None,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ )
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # pab
+ self.cross_last = None
+ self.cross_count = 0
+ self.spatial_last = None
+ self.spatial_count = 0
+ self.block_idx = block_idx
+ self.spatial_mlp_count = 0
+
+ def set_cross_last(self, last_out: torch.Tensor):
+ self.cross_last = last_out
+
+ def set_spatial_last(self, last_out: torch.Tensor):
+ self.spatial_last = last_out
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ position_q: Optional[torch.LongTensor] = None,
+ position_k: Optional[torch.LongTensor] = None,
+ hw: Tuple[int, int] = None,
+ org_timestep: Optional[torch.LongTensor] = None,
+ all_timesteps=None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ # 1. Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ broadcast_spatial, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count)
+ if broadcast_spatial:
+ attn_output = self.spatial_last
+ assert self.use_ada_layer_norm_single
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ else:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ position_q=position_q,
+ position_k=position_k,
+ last_shape=hw,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
+ if enable_pab():
+ self.set_spatial_last(attn_output)
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 3. Cross-Attention
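+ # With PAB enabled, cross-attention can also be skipped: if_broadcast_cross reuses the cached
+ # cross-attention output (self.cross_last) as the residual instead of recomputing attn2.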
+ if self.attn2 is not None:
+ broadcast_cross, self.cross_count = if_broadcast_cross(int(org_timestep[0]), self.cross_count)
+ if broadcast_cross:
+ hidden_states = hidden_states + self.cross_last
+ else:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_q=None, # cross attn do not need relative position
+ position_k=None,
+ last_shape=None,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ if enable_pab():
+ self.set_cross_last(attn_output)
+
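+ # PAB for the feed-forward: if_broadcast_mlp decides, per timestep and block, whether to reuse a
+ # previously stored MLP output (get_mlp_output) or to compute it and, when broadcast_next is set,
+ # cache it via save_mlp_output for the upcoming timesteps in broadcast_range.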
+ if enable_pab():
+ broadcast_mlp, self.spatial_mlp_count, broadcast_next, broadcast_range = if_broadcast_mlp(
+ int(org_timestep[0]),
+ self.spatial_mlp_count,
+ self.block_idx,
+ all_timesteps.tolist(),
+ is_temporal=False,
+ )
+
+ if enable_pab() and broadcast_mlp:
+ ff_output = get_mlp_output(
+ broadcast_range,
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ is_temporal=False,
+ )
+ else:
+ # 4. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
+
+ if enable_pab() and broadcast_next:
+ save_mlp_output(
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ ff_output=ff_output,
+ is_temporal=False,
+ )
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+class AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+ """
+
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.emb = CombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
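+ # self.linear maps the pooled timestep embedding to 6 * embedding_dim; the transformer blocks add
+ # their per-block scale_shift_table and chunk the result into
+ # (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp).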
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ batch_size: int = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # No modulation happening here.
+ embedded_timestep = self.emb(
+ timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None
+ )
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ """
+ The output of [`LatteT2V`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+ distributions for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+class LatteT2V(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ """
+ A 2D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to, but not more than, `num_embeds_ada_norm` steps.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ patch_size_t: int = 1,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ attention_type: str = "default",
+ caption_channels: int = None,
+ video_length: int = 16,
+ attention_mode: str = "flash",
+ use_rope: bool = False,
+ model_max_length: int = 300,
+ rope_scaling_type: str = "linear",
+ compress_kv_factor: int = 1,
+ interpolation_scale_1d: float = None,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+ self.video_length = video_length
+ self.use_rope = use_rope
+ self.model_max_length = model_max_length
+ self.compress_kv_factor = compress_kv_factor
+ self.num_layers = num_layers
+ self.config.hidden_size = model_max_length
+
+ assert not (self.compress_kv_factor != 1 and use_rope), "Cannot enable kv compression and RoPE at the same time"
+
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+ self.is_input_vectorized = num_vector_embeds is not None
+ # self.is_input_patches = in_channels is not None and patch_size is not None
+ self.is_input_patches = True
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+ norm_type = "ada_norm"
+
+ # 2. Define input layers
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = sample_size[0]
+ self.width = sample_size[1]
+
+ self.patch_size = patch_size
+ interpolation_scale_2d = self.config.sample_size[0] // 64 # => 64 (= 512 pixart) has interpolation scale 1
+ interpolation_scale_2d = max(interpolation_scale_2d, 1)
+ self.pos_embed = PatchEmbed(
+ height=sample_size[0],
+ width=sample_size[1],
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale_2d,
+ )
+
+ # define temporal positional embedding
+ if interpolation_scale_1d is None:
+ if self.config.video_length % 2 == 1:
+ interpolation_scale_1d = (
+ self.config.video_length - 1
+ ) // 16 # => 16 (= 16 Latte) has interpolation scale 1
+ else:
+ interpolation_scale_1d = self.config.video_length // 16 # => 16 (= 16 Latte) has interpolation scale 1
+ # interpolation_scale_1d = self.config.video_length // 5 #
+ interpolation_scale_1d = max(interpolation_scale_1d, 1)
+ temp_pos_embed = get_1d_sincos_pos_embed(
+ inner_dim, video_length, interpolation_scale=interpolation_scale_1d
+ ) # 1152 hidden size
+ self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+
+ rope_scaling = None
+ if self.use_rope:
+ self.position_getter_2d = PositionGetter2D()
+ self.position_getter_1d = PositionGetter1D()
+ rope_scaling = dict(
+ type=rope_scaling_type, factor_2d=interpolation_scale_2d, factor_1d=interpolation_scale_1d
+ )
+
+ # 3. Define transformers blocks, spatial attention
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ attention_mode=attention_mode,
+ use_rope=use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=(compress_kv_factor, compress_kv_factor)
+ if d >= num_layers // 2 and compress_kv_factor != 1
+ else None, # follow pixart-sigma, apply in second-half layers
+ block_idx=d,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # Define temporal transformers blocks
+ self.temporal_transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock_( # one attention
+ inner_dim,
+ num_attention_heads, # num_attention_heads
+ attention_head_dim, # attention_head_dim 72
+ dropout=dropout,
+ cross_attention_dim=None,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=False,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ attention_mode=attention_mode,
+ use_rope=use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=(compress_kv_factor,)
+ if d >= num_layers // 2 and compress_kv_factor != 1
+ else None, # follow pixart-sigma, apply in second-half layers
+ block_idx=d,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ if self.is_input_continuous:
+ # TODO: should use out_channels for continuous projections
+ if use_linear_projection:
+ self.proj_out = linear_cls(inner_dim, in_channels)
+ else:
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+ elif self.is_input_patches and norm_type != "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+ elif self.is_input_patches and norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+ # 5. PixArt-Alpha blocks.
+ self.adaln_single = None
+ self.use_additional_conditions = False
+ if norm_type == "ada_norm_single":
+ # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+ # additional conditions until we find better name
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+
+ self.caption_projection = None
+ if caption_channels is not None:
+ self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+ self.gradient_checkpointing = False
+
+ # parallel
+ self.parallel_manager: ParallelManager = None
+
+ def enable_parallel(self, dp_size, sp_size, enable_cp):
+ # update cfg parallel
+ if enable_cp and sp_size % 2 == 0:
+ sp_size = sp_size // 2
+ cp_size = 2
+ else:
+ cp_size = 1
+
+ self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
+
+ for _, module in self.named_modules():
+ if hasattr(module, "parallel_manager"):
+ module.parallel_manager = self.parallel_manager
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ def make_position(self, b, t, use_image_num, h, w, device):
+ pos_hw = self.position_getter_2d(b * (t + use_image_num), h, w, device) # fake_b = b*(t+use_image_num)
+ pos_t = self.position_getter_1d(b * h * w, t, device) # fake_b = b*h*w
+ return pos_hw, pos_t
+
+ def make_attn_mask(self, attention_mask, frame, dtype):
+ attention_mask = rearrange(attention_mask, "b t h w -> (b t) 1 (h w)")
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(dtype)) * -10000.0
+ attention_mask = attention_mask.to(self.dtype)
+ return attention_mask
+
+ def vae_to_diff_mask(self, attention_mask, use_image_num):
+ dtype = attention_mask.dtype
+ # b, t+use_image_num, h, w, assume t as channel
+ # this version does not use 3d patch embedding
+ attention_mask = F.max_pool2d(
+ attention_mask, kernel_size=(self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size)
+ )
+ attention_mask = attention_mask.bool().to(dtype)
+ return attention_mask
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: Optional[torch.LongTensor] = None,
+ all_timesteps=None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_image_num: int = 0,
+ enable_temporal_attentions: bool = True,
+ return_dict: bool = True,
+ ):
+ """
+ The [`LatteT2V`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerZeroNorm`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`Transformer3DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ # 0. Split batch
+ if self.parallel_manager.cp_size > 1:
+ (
+ hidden_states,
+ timestep,
+ encoder_hidden_states,
+ class_labels,
+ attention_mask,
+ encoder_attention_mask,
+ ) = batch_func(
+ partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
+ hidden_states,
+ timestep,
+ encoder_hidden_states,
+ class_labels,
+ attention_mask,
+ encoder_attention_mask,
+ )
+ input_batch_size, c, frame, h, w = hidden_states.shape
+ frame = frame - use_image_num # 20-4=16
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
+ org_timestep = timestep
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (input_batch_size, frame + use_image_num, h, w), device=hidden_states.device, dtype=hidden_states.dtype
+ )
+ attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num)
+ dtype = attention_mask.dtype
+ attention_mask_compress = F.max_pool2d(
+ attention_mask.float(), kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor
+ )
+ attention_mask_compress = attention_mask_compress.to(dtype)
+
+ attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype)
+ attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype)
+
+ # 1 + 4, 1 -> video condition, 4 -> image condition
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+ encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
+ encoder_attention_mask = encoder_attention_mask.to(self.dtype)
+ elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
+ encoder_attention_mask_video = repeat(
+ encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
+ ).contiguous()
+ encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
+ encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
+ encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
+ encoder_attention_mask = encoder_attention_mask.to(self.dtype)
+
+ # Retrieve lora scale.
+ cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 1. Input
+ if self.is_input_patches: # here
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ hw = (height, width)
+ num_patches = height * width
+
+ hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # already adds positional embeddings
+
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+ )
+ # batch_size = hidden_states.shape[0]
+ batch_size = input_batch_size
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ # 2. Blocks
+ if self.caption_projection is not None:
+ batch_size = hidden_states.shape[0]
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype)) # 3 120 1152
+
+ if use_image_num != 0 and self.training:
+ encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
+ encoder_hidden_states_video = repeat(
+ encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
+ ).contiguous()
+ encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
+ encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
+ else:
+ encoder_hidden_states_spatial = repeat(
+ encoder_hidden_states, "b 1 t d -> (b f) t d", f=frame
+ ).contiguous()
+
+ # prepare timesteps for spatial and temporal block
+ timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
+ timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
+
+ pos_hw, pos_t = None, None
+ if self.use_rope:
+ pos_hw, pos_t = self.make_position(
+ input_batch_size, frame, use_image_num, height, width, hidden_states.device
+ )
+
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
+ set_pad("spatial", num_patches, self.parallel_manager.sp_group)
+ hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
+ encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
+ timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
+ attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
+ attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
+ temp_pos_embed = split_sequence(
+ self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
+ )
+ else:
+ temp_pos_embed = self.temp_pos_embed
+
+ for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
+ if self.training and self.gradient_checkpointing:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ spatial_block,
+ hidden_states,
+ attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
+ encoder_hidden_states_spatial,
+ encoder_attention_mask,
+ timestep_spatial,
+ cross_attention_kwargs,
+ class_labels,
+ pos_hw,
+ pos_hw,
+ hw,
+ use_reentrant=False,
+ )
+
+ if enable_temporal_attentions:
+ hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
+
+ if use_image_num != 0: # image-video joint training
+ hidden_states_video = hidden_states[:, :frame, ...]
+ hidden_states_image = hidden_states[:, frame:, ...]
+
+ # if i == 0 and not self.use_rope:
+ if i == 0:
+ hidden_states_video = hidden_states_video + temp_pos_embed
+
+ hidden_states_video = torch.utils.checkpoint.checkpoint(
+ temp_block,
+ hidden_states_video,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ pos_t,
+ pos_t,
+ (frame,),
+ use_reentrant=False,
+ )
+
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+
+ else:
+ # if i == 0 and not self.use_rope:
+ if i == 0:
+ hidden_states = hidden_states + temp_pos_embed
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ temp_block,
+ hidden_states,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ pos_t,
+ pos_t,
+ (frame,),
+ use_reentrant=False,
+ )
+
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+ else:
+ hidden_states = spatial_block(
+ hidden_states,
+ attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
+ encoder_hidden_states_spatial,
+ encoder_attention_mask,
+ timestep_spatial,
+ cross_attention_kwargs,
+ class_labels,
+ pos_hw,
+ pos_hw,
+ hw,
+ org_timestep,
+ all_timesteps=all_timesteps,
+ )
+
+ if enable_temporal_attentions:
+ # b c f h w, f = 16 + 4
+ hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
+
+ if use_image_num != 0 and self.training:
+ hidden_states_video = hidden_states[:, :frame, ...]
+ hidden_states_image = hidden_states[:, frame:, ...]
+
+ # if i == 0 and not self.use_rope:
+ # hidden_states_video = hidden_states_video + temp_pos_embed
+
+ hidden_states_video = temp_block(
+ hidden_states_video,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ pos_t,
+ pos_t,
+ (frame,),
+ org_timestep,
+ )
+
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+
+ else:
+ # if i == 0 and not self.use_rope:
+ if i == 0:
+ hidden_states = hidden_states + temp_pos_embed
+ hidden_states = temp_block(
+ hidden_states,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ pos_t,
+ pos_t,
+ (frame,),
+ org_timestep,
+ all_timesteps=all_timesteps,
+ )
+
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+
+ if self.parallel_manager.sp_size > 1:
+ hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
+
+ if self.is_input_patches:
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
+ embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+
+ # unpatchify
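+ # At this point hidden_states is (b * (frame + use_image_num), num_patches, patch_size**2 * out_channels);
+ # the reshape/einsum below folds the patch grid back into pixels, yielding
+ # (b * f, out_channels, height * patch_size, width * patch_size) before the final rearrange to b c f h w.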
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+ output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
+
+ # 3. Gather batch for data parallelism
+ if self.parallel_manager.cp_size > 1:
+ output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+ config_file = os.path.join(pretrained_model_path, "config.json")
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ model = cls.from_config(config, **kwargs)
+ return model
+
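+ # Sequence-parallel helpers: both methods view a (b * f, ...) tensor as (b, f, ...), then split or
+ # gather the second dimension across the sequence-parallel group, padding/unpadding with get_pad("temporal").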
+ def split_from_second_dim(self, x, batch_size):
+ x = x.view(batch_size, -1, *x.shape[1:])
+ x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
+ x = x.reshape(-1, *x.shape[2:])
+ return x
+
+ def gather_from_second_dim(self, x, batch_size):
+ x = x.view(batch_size, -1, *x.shape[1:])
+ x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
+ x = x.reshape(-1, *x.shape[2:])
+ return x
+
+
+# depth = num_layers * 2
+def LatteT2V_XL_122(**kwargs):
+ return LatteT2V(
+ num_layers=28,
+ attention_head_dim=72,
+ num_attention_heads=16,
+ patch_size_t=1,
+ patch_size=2,
+ norm_type="ada_norm_single",
+ caption_channels=4096,
+ cross_attention_dim=1152,
+ **kwargs,
+ )
+
+
+def LatteT2V_D64_XL_122(**kwargs):
+ return LatteT2V(
+ num_layers=28,
+ attention_head_dim=64,
+ num_attention_heads=18,
+ patch_size_t=1,
+ patch_size=2,
+ norm_type="ada_norm_single",
+ caption_channels=4096,
+ cross_attention_dim=1152,
+ **kwargs,
+ )
+
+
+Latte_models = {
+ "LatteT2V-XL/122": LatteT2V_XL_122,
+ "LatteT2V-D64-XL/122": LatteT2V_D64_XL_122,
+}
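+
+# A minimal usage sketch (hypothetical argument values; real configs ship with the released checkpoints):
+#   model = Latte_models["LatteT2V-XL/122"](
+#       in_channels=4, out_channels=8, sample_size=(64, 64), video_length=16, attention_mode="flash"
+#   )
+#   out = model(latents, timestep=t, encoder_hidden_states=text_emb, return_dict=True).sample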
diff --git a/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py b/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py
new file mode 100644
index 0000000..761d0d9
--- /dev/null
+++ b/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py
@@ -0,0 +1,2183 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import collections
+import re
+from typing import Any, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import FeedForward, GatedSelfAttentionDense
+from diffusers.models.attention_processor import Attention as Attention_
+from diffusers.models.embeddings import PixArtAlphaTextProjection, SinusoidalPositionalEmbedding
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
+from diffusers.utils import deprecate, is_torch_version
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from einops import rearrange, repeat
+from torch import nn
+from torch.nn import functional as F
+
+from videosys.core.comm import all_to_all_comm, gather_sequence, split_sequence
+from videosys.core.pab_mgr import enable_pab, if_broadcast_cross, if_broadcast_spatial
+from videosys.core.parallel_mgr import ParallelManager
+from videosys.core.pipeline import VideoSysPipelineOutput
+
+torch_npu = None
+npu_config = None
+set_run_dtype = None
+
+
+class PositionGetter3D(object):
+ """return positions of patches"""
+
+ def __init__(
+ self,
+ ):
+ self.cache_positions = {}
+
+ def __call__(self, b, t, h, w, device):
+ if (b, t, h, w) not in self.cache_positions:
+ x = torch.arange(w, device=device)
+ y = torch.arange(h, device=device)
+ z = torch.arange(t, device=device)
+ pos = torch.cartesian_prod(z, y, x)
+ pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone()
+ poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous())
+ max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max()))
+
+ self.cache_positions[b, t, h, w] = (poses, max_poses)
+ pos = self.cache_positions[b, t, h, w]
+
+ return pos
+
+
+class RoPE3D(torch.nn.Module):
+ def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+ self.interpolation_scale_t = interpolation_scale_thw[0]
+ self.interpolation_scale_h = interpolation_scale_thw[1]
+ self.interpolation_scale_w = interpolation_scale_thw[2]
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1):
+ if (D, seq_len, device, dtype) not in self.cache:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
+ return self.cache[D, seq_len, device, dtype]
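+ # The cos/sin tables are cached per (dim, seq_len, device, dtype); positions are divided by the
+ # interpolation scale before building the frequencies. rotate_half / apply_rope1d below apply the
+ # standard rotary formula tokens * cos + rotate_half(tokens) * sin along a single axis.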
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim == 2
+ # for (batch_size x ntokens x nheads x dim)
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x nheads x ntokens x dim
+ * positions: batch_size x ntokens x 3 (t, y and x position of each token)
+ output:
+ * tokens after applying RoPE3D (batch_size x nheads x ntokens x dim)
+ """
+ assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three"
+ D = tokens.size(3) // 3
+ poses, max_poses = positions
+ assert len(poses) == 3 and poses[0].ndim == 2 # Batch, Seq, 3
+ cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t)
+ cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h)
+ cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w)
+ # split features into three along the feature dimension, and apply rope1d on each half
+ t, y, x = tokens.chunk(3, dim=-1)
+ t = self.apply_rope1d(t, poses[0], cos_t, sin_t)
+ y = self.apply_rope1d(y, poses[1], cos_y, sin_y)
+ x = self.apply_rope1d(x, poses[2], cos_x, sin_x)
+ tokens = torch.cat((t, y, x), dim=-1)
+ return tokens
+
+
+def get_3d_sincos_pos_embed(
+ embed_dim,
+ grid_size,
+ cls_token=False,
+ extra_tokens=0,
+ interpolation_scale=1.0,
+ base_size=16,
+):
+ """
+ grid_size: tuple of (t, h, w) grid sizes. return: pos_embed: [t*h*w, embed_dim] or
+ [1+t*h*w, embed_dim] (w/ or w/o cls_token)
+ """
+ # if isinstance(grid_size, int):
+ # grid_size = (grid_size, grid_size)
+ grid_t = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale[0]
+ grid_h = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale[1]
+ grid_w = np.arange(grid_size[2], dtype=np.float32) / (grid_size[2] / base_size[2]) / interpolation_scale[2]
+ grid = np.meshgrid(grid_w, grid_h, grid_t) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([3, 1, grid_size[2], grid_size[1], grid_size[0]])
+ pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
+ # import ipdb;ipdb.set_trace()
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 3 != 0:
+ raise ValueError("embed_dim must be divisible by 3")
+
+ # import ipdb;ipdb.set_trace()
+ # use 1/3 of dimensions to encode grid_t/h/w
+ emb_t = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (T*H*W, D/3)
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (T*H*W, D/3)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (T*H*W, D/3)
+
+ emb = np.concatenate([emb_t, emb_h, emb_w], axis=1) # (T*H*W, D)
+ return emb
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim,
+ grid_size,
+ cls_token=False,
+ extra_tokens=0,
+ interpolation_scale=1.0,
+ base_size=16,
+):
+ """
+ grid_size: (height, width) of the grid. return: pos_embed: [height*width, embed_dim] or
+ [1+height*width, embed_dim] (w/ or w/o cls_token)
+ """
+ # if isinstance(grid_size, int):
+ # grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale[0]
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale[1]
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of the dimensions to encode grid_h/w
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed(
+ embed_dim,
+ grid_size,
+ cls_token=False,
+ extra_tokens=0,
+ interpolation_scale=1.0,
+ base_size=16,
+):
+ """
+ grid_size: int of the grid return: pos_embed: [grid_size, embed_dim] or
+ [1+grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ # if isinstance(grid_size, int):
+ # grid_size = (grid_size, grid_size)
+
+ grid = np.arange(grid_size, dtype=np.float32) / (grid_size / base_size) / interpolation_scale
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) # (H*W, D/2)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
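+ # Illustrative note (not part of the original code): e.g. with embed_dim=4, omega = [1.0, 0.01],
+ # so a position p is encoded as [sin(p), sin(0.01 * p), cos(p), cos(0.01 * p)].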
+
+
+class PatchEmbed2D(nn.Module):
+ """2D Image to Patch Embedding but with 3D position embedding"""
+
+ def __init__(
+ self,
+ num_frames=1,
+ height=224,
+ width=224,
+ patch_size_t=1,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=(1, 1),
+ interpolation_scale_t=1,
+ use_abs_pos=True,
+ ):
+ super().__init__()
+ # assert num_frames == 1
+ self.use_abs_pos = use_abs_pos
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size_t = patch_size_t
+ self.patch_size = patch_size
+ # See:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
+
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = (height // patch_size, width // patch_size)
+ self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1])
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t
+ self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t
+ self.interpolation_scale_t = interpolation_scale_t
+ temp_pos_embed = get_1d_sincos_pos_embed(
+ embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t
+ )
+ self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+ # self.temp_embed_gate = nn.Parameter(torch.tensor([0.0]))
+
+ def forward(self, latent, num_frames):
+ b, _, _, _, _ = latent.shape
+ video_latent, image_latent = None, None
+ # b c 1 h w
+ # assert latent.shape[-3] == 1 and num_frames == 1
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+ latent = rearrange(latent, "b c t h w -> (b t) c h w")
+ latent = self.proj(latent)
+
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C
+ if self.layer_norm:
+ latent = self.norm(latent)
+
+ if self.use_abs_pos:
+ # Interpolate positional embeddings if needed.
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
+ if self.height != height or self.width != width:
+ # raise NotImplementedError
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed)
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ if self.num_frames != num_frames:
+ raise NotImplementedError
+ else:
+ temp_pos_embed = self.temp_pos_embed
+
+ latent = (latent + pos_embed).to(latent.dtype)
+
+ latent = rearrange(latent, "(b t) n c -> b t n c", b=b)
+ video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:]
+
+ if self.use_abs_pos:
+ # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh()
+ temp_pos_embed = temp_pos_embed.unsqueeze(2)
+ video_latent = (
+ (video_latent + temp_pos_embed).to(video_latent.dtype)
+ if video_latent is not None and video_latent.numel() > 0
+ else None
+ )
+ image_latent = (
+ (image_latent + temp_pos_embed[:, :1]).to(image_latent.dtype)
+ if image_latent is not None and image_latent.numel() > 0
+ else None
+ )
+
+ video_latent = (
+ rearrange(video_latent, "b t n c -> b (t n) c")
+ if video_latent is not None and video_latent.numel() > 0
+ else None
+ )
+ image_latent = (
+ rearrange(image_latent, "b t n c -> (b t) n c")
+ if image_latent is not None and image_latent.numel() > 0
+ else None
+ )
+
+ if num_frames == 1 and image_latent is None:
+ image_latent = video_latent
+ video_latent = None
+ # print('video_latent is None, image_latent is None', video_latent is None, image_latent is None)
+ return video_latent, image_latent
+
+
+class OverlapPatchEmbed3D(nn.Module):
+ """2D Image to Patch Embedding but with 3D position embedding"""
+
+ def __init__(
+ self,
+ num_frames=1,
+ height=224,
+ width=224,
+ patch_size_t=1,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=(1, 1),
+ interpolation_scale_t=1,
+ use_abs_pos=True,
+ ):
+ super().__init__()
+ # assert patch_size_t == 1 and patch_size == 1
+ self.use_abs_pos = use_abs_pos
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv3d(
+ in_channels,
+ embed_dim,
+ kernel_size=(patch_size_t, patch_size, patch_size),
+ stride=(patch_size_t, patch_size, patch_size),
+ bias=bias,
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size_t = patch_size_t
+ self.patch_size = patch_size
+ # See:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
+
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = (height // patch_size, width // patch_size)
+ self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1])
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t
+ self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t
+ self.interpolation_scale_t = interpolation_scale_t
+ temp_pos_embed = get_1d_sincos_pos_embed(
+ embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t
+ )
+ self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+ # self.temp_embed_gate = nn.Parameter(torch.tensor([0.0]))
+
+ def forward(self, latent, num_frames):
+ b, _, _, _, _ = latent.shape
+ video_latent, image_latent = None, None
+ # b c 1 h w
+ # assert latent.shape[-3] == 1 and num_frames == 1
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+ # latent = rearrange(latent, 'b c t h w -> (b t) c h w')
+ latent = self.proj(latent)
+
+ if self.flatten:
+ # latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C
+ latent = rearrange(latent, "b c t h w -> (b t) (h w) c ")
+ if self.layer_norm:
+ latent = self.norm(latent)
+
+ if self.use_abs_pos:
+ # Interpolate positional embeddings if needed.
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
+ if self.height != height or self.width != width:
+ # raise NotImplementedError
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed)
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ if self.num_frames != num_frames:
+ # import ipdb;ipdb.set_trace()
+ # raise NotImplementedError
+ temp_pos_embed = get_1d_sincos_pos_embed(
+ embed_dim=self.temp_pos_embed.shape[-1],
+ grid_size=num_frames,
+ base_size=self.base_size_t,
+ interpolation_scale=self.interpolation_scale_t,
+ )
+ temp_pos_embed = torch.from_numpy(temp_pos_embed)
+ temp_pos_embed = temp_pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ temp_pos_embed = self.temp_pos_embed
+
+ latent = (latent + pos_embed).to(latent.dtype)
+
+ latent = rearrange(latent, "(b t) n c -> b t n c", b=b)
+ video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:]
+
+ if self.use_abs_pos:
+ # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh()
+ temp_pos_embed = temp_pos_embed.unsqueeze(2)
+ video_latent = (
+ (video_latent + temp_pos_embed).to(video_latent.dtype)
+ if video_latent is not None and video_latent.numel() > 0
+ else None
+ )
+ image_latent = (
+ (image_latent + temp_pos_embed[:, :1]).to(image_latent.dtype)
+ if image_latent is not None and image_latent.numel() > 0
+ else None
+ )
+
+ video_latent = (
+ rearrange(video_latent, "b t n c -> b (t n) c")
+ if video_latent is not None and video_latent.numel() > 0
+ else None
+ )
+ image_latent = (
+ rearrange(image_latent, "b t n c -> (b t) n c")
+ if image_latent is not None and image_latent.numel() > 0
+ else None
+ )
+
+ if num_frames == 1 and image_latent is None:
+ image_latent = video_latent
+ video_latent = None
+ return video_latent, image_latent
+
+
+class OverlapPatchEmbed2D(nn.Module):
+ """2D Image to Patch Embedding but with 3D position embedding"""
+
+ def __init__(
+ self,
+ num_frames=1,
+ height=224,
+ width=224,
+ patch_size_t=1,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=(1, 1),
+ interpolation_scale_t=1,
+ use_abs_pos=True,
+ ):
+ super().__init__()
+ assert patch_size_t == 1
+ self.use_abs_pos = use_abs_pos
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size_t = patch_size_t
+ self.patch_size = patch_size
+ # See:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
+
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = (height // patch_size, width // patch_size)
+ self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1])
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t
+ self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t
+ self.interpolation_scale_t = interpolation_scale_t
+ temp_pos_embed = get_1d_sincos_pos_embed(
+ embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t
+ )
+ self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+ # self.temp_embed_gate = nn.Parameter(torch.tensor([0.0]))
+
+ def forward(self, latent, num_frames):
+ b, _, _, _, _ = latent.shape
+ video_latent, image_latent = None, None
+ # b c 1 h w
+ # assert latent.shape[-3] == 1 and num_frames == 1
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+ latent = rearrange(latent, "b c t h w -> (b t) c h w")
+ latent = self.proj(latent)
+
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C
+ if self.layer_norm:
+ latent = self.norm(latent)
+
+ if self.use_abs_pos:
+ # Interpolate positional embeddings if needed.
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
+ if self.height != height or self.width != width:
+ # raise NotImplementedError
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed)
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ if self.num_frames != num_frames:
+ # import ipdb;ipdb.set_trace()
+ # raise NotImplementedError
+ temp_pos_embed = get_1d_sincos_pos_embed(
+ embed_dim=self.temp_pos_embed.shape[-1],
+ grid_size=num_frames,
+ base_size=self.base_size_t,
+ interpolation_scale=self.interpolation_scale_t,
+ )
+ temp_pos_embed = torch.from_numpy(temp_pos_embed)
+ temp_pos_embed = temp_pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ temp_pos_embed = self.temp_pos_embed
+
+ latent = (latent + pos_embed).to(latent.dtype)
+
+ latent = rearrange(latent, "(b t) n c -> b t n c", b=b)
+ video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:]
+
+ if self.use_abs_pos:
+ # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh()
+ temp_pos_embed = temp_pos_embed.unsqueeze(2)
+ video_latent = (
+ (video_latent + temp_pos_embed).to(video_latent.dtype)
+ if video_latent is not None and video_latent.numel() > 0
+ else None
+ )
+ image_latent = (
+ (image_latent + temp_pos_embed[:, :1]).to(image_latent.dtype)
+ if image_latent is not None and image_latent.numel() > 0
+ else None
+ )
+
+ video_latent = (
+ rearrange(video_latent, "b t n c -> b (t n) c")
+ if video_latent is not None and video_latent.numel() > 0
+ else None
+ )
+ image_latent = (
+ rearrange(image_latent, "b t n c -> (b t) n c")
+ if image_latent is not None and image_latent.numel() > 0
+ else None
+ )
+
+ if num_frames == 1 and image_latent is None:
+ image_latent = video_latent
+ video_latent = None
+ return video_latent, image_latent
+
+
+class Attention(Attention_):
+ def __init__(self, downsampler, attention_mode, use_rope, interpolation_scale_thw, **kwags):
+ processor = AttnProcessor2_0(
+ attention_mode=attention_mode, use_rope=use_rope, interpolation_scale_thw=interpolation_scale_thw
+ )
+ super().__init__(processor=processor, **kwags)
+ self.downsampler = None
+ if downsampler: # downsampler k155_s122
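+ # Parse the downsampler spec string, e.g. "k155_s122": the digits after "k" give the
+ # per-axis kernel size ([1, 5, 5]) and the digits after "s" give the per-axis
+ # down factors ([1, 2, 2]); 2 digits select DownSampler2d, 3 digits DownSampler3d.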
+ downsampler_ker_size = list(re.search(r"k(\d{2,3})", downsampler).group(1))  # e.g. "155"
+ down_factor = list(re.search(r"s(\d{2,3})", downsampler).group(1))  # e.g. "122"
+ downsampler_ker_size = [int(i) for i in downsampler_ker_size]
+ downsampler_padding = [(i - 1) // 2 for i in downsampler_ker_size]
+ down_factor = [int(i) for i in down_factor]
+
+ if len(downsampler_ker_size) == 2:
+ self.downsampler = DownSampler2d(
+ kwags["query_dim"],
+ kwags["query_dim"],
+ kernel_size=downsampler_ker_size,
+ stride=1,
+ padding=downsampler_padding,
+ groups=kwags["query_dim"],
+ down_factor=down_factor,
+ down_shortcut=True,
+ )
+ elif len(downsampler_ker_size) == 3:
+ self.downsampler = DownSampler3d(
+ kwags["query_dim"],
+ kwags["query_dim"],
+ kernel_size=downsampler_ker_size,
+ stride=1,
+ padding=downsampler_padding,
+ groups=kwags["query_dim"],
+ down_factor=down_factor,
+ down_shortcut=True,
+ )
+
+ # parallel
+ self.parallel_manager: ParallelManager = None
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ r"""
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ The attention mask to prepare.
+ target_length (`int`):
+ The target length of the attention mask. This is the length of the attention mask after padding.
+ batch_size (`int`):
+ The batch size, which is used to repeat the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`):
+ The output dimension of the attention mask. Can be either `3` or `4`.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+
+class DownSampler3d(nn.Module):
+ def __init__(self, *args, **kwargs):
+ """Required kwargs: down_factor, downsampler"""
+ super().__init__()
+ self.down_factor = kwargs.pop("down_factor")
+ self.down_shortcut = kwargs.pop("down_shortcut")
+ self.layer = nn.Conv3d(*args, **kwargs)
+
+ def forward(self, x, attention_mask, t, h, w):
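+ # x: (b, t*h*w, d) tokens. After the depthwise Conv3d, the down factors (dt, dh, dw)
+ # are folded into the batch dimension so attention runs on the smaller
+ # (t/dt, h/dh, w/dw) grid; `reverse` undoes this regrouping.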
+ x = rearrange(x, "b (t h w) d -> b d t h w", t=t, h=h, w=w)
+ if npu_config is None:
+ x = self.layer(x) + (x if self.down_shortcut else 0)
+ else:
+ x_dtype = x.dtype
+ x = npu_config.run_conv3d(self.layer, x, x_dtype) + (x if self.down_shortcut else 0)
+
+ self.t = t // self.down_factor[0]
+ self.h = h // self.down_factor[1]
+ self.w = w // self.down_factor[2]
+ x = rearrange(
+ x,
+ "b d (t dt) (h dh) (w dw) -> (b dt dh dw) (t h w) d",
+ t=t // self.down_factor[0],
+ h=h // self.down_factor[1],
+ w=w // self.down_factor[2],
+ dt=self.down_factor[0],
+ dh=self.down_factor[1],
+ dw=self.down_factor[2],
+ )
+
+ attention_mask = rearrange(attention_mask, "b 1 (t h w) -> b 1 t h w", t=t, h=h, w=w)
+ attention_mask = rearrange(
+ attention_mask,
+ "b 1 (t dt) (h dh) (w dw) -> (b dt dh dw) 1 (t h w)",
+ t=t // self.down_factor[0],
+ h=h // self.down_factor[1],
+ w=w // self.down_factor[2],
+ dt=self.down_factor[0],
+ dh=self.down_factor[1],
+ dw=self.down_factor[2],
+ )
+ return x, attention_mask
+
+ def reverse(self, x, t, h, w):
+ x = rearrange(
+ x,
+ "(b dt dh dw) (t h w) d -> b (t dt h dh w dw) d",
+ t=t,
+ h=h,
+ w=w,
+ dt=self.down_factor[0],
+ dh=self.down_factor[1],
+ dw=self.down_factor[2],
+ )
+ return x
+
+
+class DownSampler2d(nn.Module):
+ def __init__(self, *args, **kwargs):
+ """Required kwargs: down_factor, downsampler"""
+ super().__init__()
+ self.down_factor = kwargs.pop("down_factor")
+ self.down_shortcut = kwargs.pop("down_shortcut")
+ self.layer = nn.Conv2d(*args, **kwargs)
+
+ def forward(self, x, attention_mask, t, h, w):
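+ # 2D counterpart of DownSampler3d: frames are folded into the batch dimension and
+ # only the spatial down factors (dh, dw) are moved into the batch for attention.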
+ x = rearrange(x, "b (t h w) d -> (b t) d h w", t=t, h=h, w=w)
+ x = self.layer(x) + (x if self.down_shortcut else 0)
+
+ self.t = 1
+ self.h = h // self.down_factor[0]
+ self.w = w // self.down_factor[1]
+
+ x = rearrange(
+ x,
+ "b d (h dh) (w dw) -> (b dh dw) (h w) d",
+ h=h // self.down_factor[0],
+ w=w // self.down_factor[1],
+ dh=self.down_factor[0],
+ dw=self.down_factor[1],
+ )
+
+ attention_mask = rearrange(attention_mask, "b 1 (t h w) -> (b t) 1 h w", h=h, w=w)
+ attention_mask = rearrange(
+ attention_mask,
+ "b 1 (h dh) (w dw) -> (b dh dw) 1 (h w)",
+ h=h // self.down_factor[0],
+ w=w // self.down_factor[1],
+ dh=self.down_factor[0],
+ dw=self.down_factor[1],
+ )
+ return x, attention_mask
+
+ def reverse(self, x, t, h, w):
+ x = rearrange(
+ x, "(b t dh dw) (h w) d -> b (t h dh w dw) d", t=t, h=h, w=w, dh=self.down_factor[0], dw=self.down_factor[1]
+ )
+ return x
+
+
+class AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self, attention_mode="xformers", use_rope=False, interpolation_scale_thw=(1, 1, 1)):
+ self.use_rope = use_rope
+ self.interpolation_scale_thw = interpolation_scale_thw
+ if self.use_rope:
+ self._init_rope(interpolation_scale_thw)
+ self.attention_mode = attention_mode
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def _init_rope(self, interpolation_scale_thw):
+ self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw)
+ self.position_getter = PositionGetter3D()
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ frame: int = 8,
+ height: int = 16,
+ width: int = 16,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ if attn.downsampler is not None:
+ hidden_states, attention_mask = attn.downsampler(hidden_states, attention_mask, t=frame, h=height, w=width)
+ frame, height, width = attn.downsampler.t, attn.downsampler.h, attn.downsampler.w
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
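+ # Sequence parallelism: for self-attention (query and key have the same length), an
+ # all-to-all exchanges data so each rank holds a subset of heads but the full token
+ # sequence; a second all-to-all after attention restores the original layout.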
+ if attn.parallel_manager.sp_size > 1 and query.shape[2] == key.shape[2]:
+ func = lambda x: all_to_all_comm(
+ x, process_group=attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2
+ )
+ query, key, value = map(func, [query, key, value])
+
+ if self.use_rope:
+ # require the shape of (batch_size x nheads x ntokens x dim)
+ pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device)
+ query = self.rope(query, pos_thw)
+ key = self.rope(key, pos_thw)
+
+ # 0, -10000 ->(bool) False, True ->(any) True ->(not) False
+ # 0, 0 ->(bool) False, False ->(any) False ->(not) True
+ if attention_mask is None or not torch.any(attention_mask.bool()): # 0 mean visible
+ attention_mask = None
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ if attn.parallel_manager.sp_size > 1 and query.shape[2] == key.shape[2]:
+ hidden_states = all_to_all_comm(
+ hidden_states, process_group=attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ if attn.downsampler is not None:
+ hidden_states = attn.downsampler.reverse(hidden_states, t=frame, h=height, w=width)
+ return hidden_states
+
+
+class FeedForward_Conv3d(nn.Module):
+ def __init__(self, downsampler, dim, hidden_features, bias=True):
+ super(FeedForward_Conv3d, self).__init__()
+
+ self.bias = bias
+
+ self.project_in = nn.Linear(dim, hidden_features, bias=bias)
+
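+ # Three depthwise Conv3d branches (5x5x5, 3x3x3, 1x1x1 kernels) whose outputs are summed
+ # with the activated input in `forward`, giving a multi-scale token mixer between the
+ # two linear projections.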
+ self.dwconv = nn.ModuleList(
+ [
+ nn.Conv3d(
+ hidden_features,
+ hidden_features,
+ kernel_size=(5, 5, 5),
+ stride=1,
+ padding=(2, 2, 2),
+ dilation=1,
+ groups=hidden_features,
+ bias=bias,
+ ),
+ nn.Conv3d(
+ hidden_features,
+ hidden_features,
+ kernel_size=(3, 3, 3),
+ stride=1,
+ padding=(1, 1, 1),
+ dilation=1,
+ groups=hidden_features,
+ bias=bias,
+ ),
+ nn.Conv3d(
+ hidden_features,
+ hidden_features,
+ kernel_size=(1, 1, 1),
+ stride=1,
+ padding=(0, 0, 0),
+ dilation=1,
+ groups=hidden_features,
+ bias=bias,
+ ),
+ ]
+ )
+
+ self.project_out = nn.Linear(hidden_features, dim, bias=bias)
+
+ def forward(self, x, t, h, w):
+ # import ipdb;ipdb.set_trace()
+ if npu_config is None:
+ x = self.project_in(x)
+ x = rearrange(x, "b (t h w) d -> b d t h w", t=t, h=h, w=w)
+ x = F.gelu(x)
+ out = x
+ for module in self.dwconv:
+ out = out + module(x)
+ out = rearrange(out, "b d t h w -> b (t h w) d", t=t, h=h, w=w)
+ x = self.project_out(out)
+ else:
+ x_dtype = x.dtype
+ x = npu_config.run_conv3d(self.project_in, x, npu_config.replaced_type)
+ x = rearrange(x, "b (t h w) d -> b d t h w", t=t, h=h, w=w)
+ x = F.gelu(x)
+ out = x
+ for module in self.dwconv:
+ out = out + npu_config.run_conv3d(module, x, npu_config.replaced_type)
+ out = rearrange(out, "b d t h w -> b (t h w) d", t=t, h=h, w=w)
+ x = npu_config.run_conv3d(self.project_out, out, x_dtype)
+ return x
+
+
+class FeedForward_Conv2d(nn.Module):
+ def __init__(self, downsampler, dim, hidden_features, bias=True):
+ super(FeedForward_Conv2d, self).__init__()
+
+ self.bias = bias
+
+ self.project_in = nn.Linear(dim, hidden_features, bias=bias)
+
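+ # 2D counterpart of FeedForward_Conv3d: depthwise 5x5 / 3x3 / 1x1 branches applied
+ # per frame and summed with the activated input.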
+ self.dwconv = nn.ModuleList(
+ [
+ nn.Conv2d(
+ hidden_features,
+ hidden_features,
+ kernel_size=(5, 5),
+ stride=1,
+ padding=(2, 2),
+ dilation=1,
+ groups=hidden_features,
+ bias=bias,
+ ),
+ nn.Conv2d(
+ hidden_features,
+ hidden_features,
+ kernel_size=(3, 3),
+ stride=1,
+ padding=(1, 1),
+ dilation=1,
+ groups=hidden_features,
+ bias=bias,
+ ),
+ nn.Conv2d(
+ hidden_features,
+ hidden_features,
+ kernel_size=(1, 1),
+ stride=1,
+ padding=(0, 0),
+ dilation=1,
+ groups=hidden_features,
+ bias=bias,
+ ),
+ ]
+ )
+
+ self.project_out = nn.Linear(hidden_features, dim, bias=bias)
+
+ def forward(self, x, t, h, w):
+ # import ipdb;ipdb.set_trace()
+ x = self.project_in(x)
+ x = rearrange(x, "b (t h w) d -> (b t) d h w", t=t, h=h, w=w)
+ x = F.gelu(x)
+ out = x
+ for module in self.dwconv:
+ out = out + module(x)
+ out = rearrange(out, "(b t) d h w -> b (t h w) d", t=t, h=h, w=w)
+ x = self.project_out(out)
+ return x
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (`int`, *optional*):
+ The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool`, *optional*, defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
+ ada_norm_bias: Optional[int] = None,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ attention_mode: str = "xformers",
+ downsampler: str = None,
+ use_rope: bool = False,
+ interpolation_scale_thw: Tuple[int] = (1, 1, 1),
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.downsampler = downsampler
+
+ # We keep these boolean flags for backward-compatibility.
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ self.norm_type = norm_type
+ self.num_embeds_ada_norm = num_embeds_ada_norm
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if norm_type == "ada_norm":
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_zero":
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm1 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ attention_mode=attention_mode,
+ downsampler=downsampler,
+ use_rope=use_rope,
+ interpolation_scale_thw=interpolation_scale_thw,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
+ # the second cross attention block.
+ if norm_type == "ada_norm":
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm2 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ attention_mode=attention_mode,
+ downsampler=False,
+ use_rope=False,
+ interpolation_scale_thw=interpolation_scale_thw,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if norm_type == "ada_norm_continuous":
+ self.norm3 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "layer_norm",
+ )
+
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+ elif norm_type == "layer_norm_i2vgen":
+ self.norm3 = None
+
+ if downsampler:
+ downsampler_ker_size = list(re.search(r"k(\d{2,3})", downsampler).group(1)) # 122
+ # if len(downsampler_ker_size) == 3:
+ # self.ff = FeedForward_Conv3d(
+ # downsampler,
+ # dim,
+ # 2 * dim,
+ # bias=ff_bias,
+ # )
+ # elif len(downsampler_ker_size) == 2:
+ self.ff = FeedForward_Conv2d(
+ downsampler,
+ dim,
+ 2 * dim,
+ bias=ff_bias,
+ )
+ else:
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if norm_type == "ada_norm_single":
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # PAB: cached attention outputs and per-block broadcast counters
+ self.spatial_last = None
+ self.spatial_count = 0
+ self.cross_last = None
+ self.cross_count = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ frame: int = None,
+ height: int = None,
+ width: int = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ org_timestep: Optional[torch.LongTensor] = None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ # import ipdb;ipdb.set_trace()
+ if self.norm_type == "ada_norm_single":
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ # 1. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
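+ # PAB: on timesteps flagged for broadcast, reuse the cached self-attention output from
+ # the previous step instead of recomputing it.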
+ broadcast, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count)
+ if broadcast:
+ attn_output = self.spatial_last
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ frame=frame,
+ height=height,
+ width=width,
+ **cross_attention_kwargs,
+ )
+
+ if enable_pab():
+ self.spatial_last = attn_output
+
+ if self.norm_type == "ada_norm_zero":
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.norm_type == "ada_norm_single":
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 1.2 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
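+ # PAB: likewise reuse the cached cross-attention output on broadcast timesteps.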
+ broadcast, self.cross_count = if_broadcast_cross(int(org_timestep[0]), self.cross_count)
+ if broadcast:
+ attn_output = self.cross_last
+
+ else:
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.norm_type == "ada_norm_single":
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ if enable_pab():
+ self.cross_last = attn_output
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ # i2vgen doesn't have this norm 🤷
+ if self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif not self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ # if self._chunk_size is not None:
+ # # "feed_forward_chunk_size" can be used to save memory
+ # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ # else:
+
+ if self.downsampler:
+ ff_output = self.ff(norm_hidden_states, t=frame, h=height, w=width)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.norm_type == "ada_norm_single":
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+def to_2tuple(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return (x, x)
+
+
+class OpenSoraT2V(ModelMixin, ConfigMixin):
+ """
+ A 2D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ sample_size_t: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ patch_size_t: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ attention_type: str = "default",
+ caption_channels: int = None,
+ interpolation_scale_h: float = None,
+ interpolation_scale_w: float = None,
+ interpolation_scale_t: float = None,
+ use_additional_conditions: Optional[bool] = None,
+ attention_mode: str = "xformers",
+ downsampler: str = None,
+ use_rope: bool = False,
+ use_stable_fp32: bool = False,
+ ):
+ super().__init__()
+
+ # Validate inputs.
+ if patch_size is not None:
+ if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
+ raise NotImplementedError(
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+ )
+ elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+ )
+
+ # Set some common variables used across the board.
+ self.use_rope = use_rope
+ self.use_linear_projection = use_linear_projection
+ self.interpolation_scale_t = interpolation_scale_t
+ self.interpolation_scale_h = interpolation_scale_h
+ self.interpolation_scale_w = interpolation_scale_w
+ self.downsampler = downsampler
+ self.caption_channels = caption_channels
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.gradient_checkpointing = False
+ self.config.hidden_size = self.inner_dim
+ use_additional_conditions = False
+ # if use_additional_conditions is None:
+ # if norm_type == "ada_norm_single" and sample_size == 128:
+ # use_additional_conditions = True
+ # else:
+ # use_additional_conditions = False
+ self.use_additional_conditions = use_additional_conditions
+
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ assert in_channels is not None and patch_size is not None
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+ norm_type = "ada_norm"
+
+ # 2. Initialize the right blocks.
+ # Initialize the output blocks and other projection blocks when necessary.
+ self._init_patched_inputs(norm_type=norm_type)
+
+ # parallel
+ self.parallel_manager: ParallelManager = None
+
+ def _init_patched_inputs(self, norm_type):
+ assert self.config.sample_size_t is not None, "OpenSoraT2V over patched input must provide sample_size_t"
+ assert self.config.sample_size is not None, "OpenSoraT2V over patched input must provide sample_size"
+ # assert not (self.config.sample_size_t == 1 and self.config.patch_size_t == 2), "Image do not need patchfy in t-dim"
+
+ self.num_frames = self.config.sample_size_t
+ self.config.sample_size = to_2tuple(self.config.sample_size)
+ self.height = self.config.sample_size[0]
+ self.width = self.config.sample_size[1]
+ self.patch_size_t = self.config.patch_size_t
+ self.patch_size = self.config.patch_size
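+ # Default interpolation scales are derived from the configured latent size relative to
+ # the constants used here (16 frames, 30x40 spatial latents) and can be overridden via
+ # the interpolation_scale_* config entries.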
+ interpolation_scale_t = (
+ ((self.config.sample_size_t - 1) // 16 + 1)
+ if self.config.sample_size_t % 2 == 1
+ else self.config.sample_size_t / 16
+ )
+ interpolation_scale_t = (
+ self.config.interpolation_scale_t
+ if self.config.interpolation_scale_t is not None
+ else interpolation_scale_t
+ )
+ interpolation_scale = (
+ self.config.interpolation_scale_h
+ if self.config.interpolation_scale_h is not None
+ else self.config.sample_size[0] / 30,
+ self.config.interpolation_scale_w
+ if self.config.interpolation_scale_w is not None
+ else self.config.sample_size[1] / 40,
+ )
+ if self.config.downsampler is not None and len(self.config.downsampler) == 9:
+ self.pos_embed = OverlapPatchEmbed3D(
+ num_frames=self.config.sample_size_t,
+ height=self.config.sample_size[0],
+ width=self.config.sample_size[1],
+ patch_size_t=self.config.patch_size_t,
+ patch_size=self.config.patch_size,
+ in_channels=self.in_channels,
+ embed_dim=self.inner_dim,
+ interpolation_scale=interpolation_scale,
+ interpolation_scale_t=interpolation_scale_t,
+ use_abs_pos=not self.config.use_rope,
+ )
+ elif self.config.downsampler is not None and len(self.config.downsampler) == 7:
+ self.pos_embed = OverlapPatchEmbed2D(
+ num_frames=self.config.sample_size_t,
+ height=self.config.sample_size[0],
+ width=self.config.sample_size[1],
+ patch_size_t=self.config.patch_size_t,
+ patch_size=self.config.patch_size,
+ in_channels=self.in_channels,
+ embed_dim=self.inner_dim,
+ interpolation_scale=interpolation_scale,
+ interpolation_scale_t=interpolation_scale_t,
+ use_abs_pos=not self.config.use_rope,
+ )
+
+ else:
+ self.pos_embed = PatchEmbed2D(
+ num_frames=self.config.sample_size_t,
+ height=self.config.sample_size[0],
+ width=self.config.sample_size[1],
+ patch_size_t=self.config.patch_size_t,
+ patch_size=self.config.patch_size,
+ in_channels=self.in_channels,
+ embed_dim=self.inner_dim,
+ interpolation_scale=interpolation_scale,
+ interpolation_scale_t=interpolation_scale_t,
+ use_abs_pos=not self.config.use_rope,
+ )
+ interpolation_scale_thw = (interpolation_scale_t, *interpolation_scale)
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ self.inner_dim,
+ self.config.num_attention_heads,
+ self.config.attention_head_dim,
+ dropout=self.config.dropout,
+ cross_attention_dim=self.config.cross_attention_dim,
+ activation_fn=self.config.activation_fn,
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+ attention_bias=self.config.attention_bias,
+ only_cross_attention=self.config.only_cross_attention,
+ double_self_attention=self.config.double_self_attention,
+ upcast_attention=self.config.upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
+ norm_eps=self.config.norm_eps,
+ attention_type=self.config.attention_type,
+ attention_mode=self.config.attention_mode,
+ downsampler=self.config.downsampler,
+ use_rope=self.config.use_rope,
+ interpolation_scale_thw=interpolation_scale_thw,
+ )
+ for _ in range(self.config.num_layers)
+ ]
+ )
+
+ if self.config.norm_type != "ada_norm_single":
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+ self.proj_out_2 = nn.Linear(
+ self.inner_dim,
+ self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels,
+ )
+ elif self.config.norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+ self.proj_out = nn.Linear(
+ self.inner_dim,
+ self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels,
+ )
+
+ # PixArt-Alpha blocks.
+ self.adaln_single = None
+ if self.config.norm_type == "ada_norm_single":
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+ # additional conditions until we find better name
+ self.adaln_single = AdaLayerNormSingle(
+ self.inner_dim, use_additional_conditions=self.use_additional_conditions
+ )
+
+ self.caption_projection = None
+ if self.caption_channels is not None:
+ self.caption_projection = PixArtAlphaTextProjection(
+ in_features=self.caption_channels, hidden_size=self.inner_dim
+ )
+
+ def enable_parallel(self, dp_size, sp_size, enable_cp):
+ # cfg parallel: when enabled, carve a group of size 2 out of the sp ranks (sp_size is halved)
+ if enable_cp and sp_size % 2 == 0:
+ sp_size = sp_size // 2
+ cp_size = 2
+ else:
+ cp_size = 1
+
+ self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
+
+ for _, module in self.named_modules():
+ if hasattr(module, "parallel_manager"):
+ module.parallel_manager = self.parallel_manager
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: Optional[torch.LongTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_image_num: Optional[int] = 0,
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ """
+ The [`Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerZeroNorm`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ batch_size, c, frame, h, w = hidden_states.shape
+ # print('hidden_states.shape', hidden_states.shape)
+ frame = frame - use_image_num # 21-4=17
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ print.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ attention_mask_vid, attention_mask_img = None, None
+ if attention_mask is not None and attention_mask.ndim == 4:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ # b, frame+use_image_num, h, w -> a video with images
+ # b, 1, h, w -> only images
+ attention_mask = attention_mask.to(self.dtype)
+ attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w
+ attention_mask_img = attention_mask[:, frame:] # b, use_image_num, h, w
+
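+ # Down-sample the pixel-space masks to patch resolution with max pooling (replicating the
+ # first video frame so the temporal length matches the patchified frame count), then
+ # flatten them to one value per token before converting to an additive bias.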
+ if attention_mask_vid.numel() > 0:
+ attention_mask_vid_first_frame = attention_mask_vid[:, :1].repeat(1, self.patch_size_t - 1, 1, 1)
+ attention_mask_vid = torch.cat([attention_mask_vid_first_frame, attention_mask_vid], dim=1)
+ attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w
+ attention_mask_vid = F.max_pool3d(
+ attention_mask_vid,
+ kernel_size=(self.patch_size_t, self.patch_size, self.patch_size),
+ stride=(self.patch_size_t, self.patch_size, self.patch_size),
+ )
+ attention_mask_vid = rearrange(attention_mask_vid, "b 1 t h w -> (b 1) 1 (t h w)")
+ if attention_mask_img.numel() > 0:
+ attention_mask_img = F.max_pool2d(
+ attention_mask_img,
+ kernel_size=(self.patch_size, self.patch_size),
+ stride=(self.patch_size, self.patch_size),
+ )
+ attention_mask_img = rearrange(attention_mask_img, "b i h w -> (b i) 1 (h w)")
+
+ attention_mask_vid = (
+ (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None
+ )
+ attention_mask_img = (
+ (1 - attention_mask_img.bool().to(self.dtype)) * -10000.0 if attention_mask_img.numel() > 0 else None
+ )
+
+ if frame == 1 and use_image_num == 0:
+ attention_mask_img = attention_mask_vid
+ attention_mask_vid = None
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:
+ # b, 1+use_image_num, l -> a video with images
+ # b, 1, l -> only images
+ encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
+ in_t = encoder_attention_mask.shape[1]
+ encoder_attention_mask_vid = encoder_attention_mask[:, : in_t - use_image_num] # b, 1, l
+ encoder_attention_mask_vid = (
+ rearrange(encoder_attention_mask_vid, "b 1 l -> (b 1) 1 l")
+ if encoder_attention_mask_vid.numel() > 0
+ else None
+ )
+
+ encoder_attention_mask_img = encoder_attention_mask[:, in_t - use_image_num :] # b, use_image_num, l
+ encoder_attention_mask_img = (
+ rearrange(encoder_attention_mask_img, "b i l -> (b i) 1 l")
+ if encoder_attention_mask_img.numel() > 0
+ else None
+ )
+
+ if frame == 1 and use_image_num == 0:
+ encoder_attention_mask_img = encoder_attention_mask_vid
+ encoder_attention_mask_vid = None
+
+ if npu_config is not None and attention_mask_vid is not None:
+ attention_mask_vid = npu_config.get_attention_mask(attention_mask_vid, attention_mask_vid.shape[-1])
+ encoder_attention_mask_vid = npu_config.get_attention_mask(
+ encoder_attention_mask_vid, attention_mask_vid.shape[-2]
+ )
+ if npu_config is not None and attention_mask_img is not None:
+ attention_mask_img = npu_config.get_attention_mask(attention_mask_img, attention_mask_img.shape[-1])
+ encoder_attention_mask_img = npu_config.get_attention_mask(
+ encoder_attention_mask_img, attention_mask_img.shape[-2]
+ )
+
+ # 1. Input
+ frame = ((frame - 1) // self.patch_size_t + 1) if frame % 2 == 1 else frame // self.patch_size_t  # temporal patchify
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ (
+ hidden_states_vid,
+ hidden_states_img,
+ encoder_hidden_states_vid,
+ encoder_hidden_states_img,
+ timestep_vid,
+ timestep_img,
+ embedded_timestep_vid,
+ embedded_timestep_img,
+ ) = self._operate_on_patched_inputs(
+ hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, frame, use_image_num
+ )
+ # 2. Blocks
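+ # with sequence parallelism enabled, the video token sequence is sharded across ranks before the blocks and gathered back afterwards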
+ if self.parallel_manager.sp_size > 1:
+ if hidden_states_vid is not None:
+ hidden_states_vid = split_sequence(
+ hidden_states_vid, dim=1, process_group=self.parallel_manager.sp_group, grad_scale="down"
+ )
+
+ for block in self.transformer_blocks:
+ if self.training and self.gradient_checkpointing:
+
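+ # wrap the block so torch.utils.checkpoint can re-run it with positional arguments during recomputation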
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ if hidden_states_vid is not None:
+ hidden_states_vid = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states_vid,
+ attention_mask_vid,
+ encoder_hidden_states_vid,
+ encoder_attention_mask_vid,
+ timestep_vid,
+ cross_attention_kwargs,
+ class_labels,
+ frame,
+ height,
+ width,
+ **ckpt_kwargs,
+ )
+ if hidden_states_img is not None:
+ hidden_states_img = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states_img,
+ attention_mask_img,
+ encoder_hidden_states_img,
+ encoder_attention_mask_img,
+ timestep_img,
+ cross_attention_kwargs,
+ class_labels,
+ 1,
+ height,
+ width,
+ **ckpt_kwargs,
+ )
+ else:
+ if hidden_states_vid is not None:
+ hidden_states_vid = block(
+ hidden_states_vid,
+ attention_mask=attention_mask_vid,
+ encoder_hidden_states=encoder_hidden_states_vid,
+ encoder_attention_mask=encoder_attention_mask_vid,
+ timestep=timestep_vid,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ frame=frame,
+ height=height,
+ width=width,
+ org_timestep=timestep,
+ )
+ if hidden_states_img is not None:
+ hidden_states_img = block(
+ hidden_states_img,
+ attention_mask=attention_mask_img,
+ encoder_hidden_states=encoder_hidden_states_img,
+ encoder_attention_mask=encoder_attention_mask_img,
+ timestep=timestep_img,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ frame=1,
+ height=height,
+ width=width,
+ org_timestep=timestep,
+ )
+
+ if self.parallel_manager.sp_size > 1:
+ if hidden_states_vid is not None:
+ hidden_states_vid = gather_sequence(
+ hidden_states_vid, dim=1, process_group=self.parallel_manager.sp_group, grad_scale="up"
+ )
+
+ # 3. Output
+ output_vid, output_img = None, None
+ if hidden_states_vid is not None:
+ output_vid = self._get_output_for_patched_inputs(
+ hidden_states=hidden_states_vid,
+ timestep=timestep_vid,
+ class_labels=class_labels,
+ embedded_timestep=embedded_timestep_vid,
+ num_frames=frame,
+ height=height,
+ width=width,
+ ) # b c t h w
+ if hidden_states_img is not None:
+ output_img = self._get_output_for_patched_inputs(
+ hidden_states=hidden_states_img,
+ timestep=timestep_img,
+ class_labels=class_labels,
+ embedded_timestep=embedded_timestep_img,
+ num_frames=1,
+ height=height,
+ width=width,
+ ) # b c 1 h w
+ if use_image_num != 0:
+ output_img = rearrange(output_img, "(b i) c 1 h w -> b c i h w", i=use_image_num)
+
+ if output_vid is not None and output_img is not None:
+ output = torch.cat([output_vid, output_img], dim=2)
+ elif output_vid is not None:
+ output = output_vid
+ elif output_img is not None:
+ output = output_img
+
+ if not return_dict:
+ return (output,)
+
+ return VideoSysPipelineOutput(video=output)
+
+ def _operate_on_patched_inputs(
+ self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, frame, use_image_num
+ ):
+ # batch_size = hidden_states.shape[0]
+ hidden_states_vid, hidden_states_img = self.pos_embed(hidden_states.to(self.dtype), frame)
+ timestep_vid, timestep_img = None, None
+ embedded_timestep_vid, embedded_timestep_img = None, None
+ encoder_hidden_states_vid, encoder_hidden_states_img = None, None
+
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+ )
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype
+ ) # b 6d, b d
+ if hidden_states_vid is None:
+ timestep_img = timestep
+ embedded_timestep_img = embedded_timestep
+ else:
+ timestep_vid = timestep
+ embedded_timestep_vid = embedded_timestep
+ if hidden_states_img is not None:
+ timestep_img = repeat(timestep, "b d -> (b i) d", i=use_image_num).contiguous()
+ embedded_timestep_img = repeat(embedded_timestep, "b d -> (b i) d", i=use_image_num).contiguous()
+
+ if self.caption_projection is not None:
+ encoder_hidden_states = self.caption_projection(
+ encoder_hidden_states
+ ) # b, 1+use_image_num, l, d or b, 1, l, d
+ if hidden_states_vid is None:
+ encoder_hidden_states_img = rearrange(encoder_hidden_states, "b 1 l d -> (b 1) l d")
+ else:
+ encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], "b 1 l d -> (b 1) l d")
+ if hidden_states_img is not None:
+ encoder_hidden_states_img = rearrange(encoder_hidden_states[:, 1:], "b i l d -> (b i) l d")
+
+ return (
+ hidden_states_vid,
+ hidden_states_img,
+ encoder_hidden_states_vid,
+ encoder_hidden_states_img,
+ timestep_vid,
+ timestep_img,
+ embedded_timestep_vid,
+ embedded_timestep_img,
+ )
+
+ def _get_output_for_patched_inputs(
+ self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None
+ ):
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=self.dtype)
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
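+ # ada_norm_single: per-sample shift/scale come from the learned scale_shift_table plus the embedded timestep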
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # unpatchify
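+ # fold the (t, h, w) patch grid and the per-patch (pt, p, p, c) pixels back into a dense (c, t*pt, h*p, w*p) video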
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(
+ -1,
+ num_frames,
+ height,
+ width,
+ self.patch_size_t,
+ self.patch_size,
+ self.patch_size,
+ self.out_channels,
+ )
+ )
+ hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(
+ -1,
+ self.out_channels,
+ num_frames * self.patch_size_t,
+ height * self.patch_size,
+ width * self.patch_size,
+ )
+ )
+ # if output.shape[2] % 2 == 0:
+ # output = output[:, :, 1:]
+ return output
+
+
+def OpenSoraT2V_S_122(**kwargs):
+ return OpenSoraT2V(
+ num_layers=28,
+ attention_head_dim=96,
+ num_attention_heads=16,
+ patch_size_t=1,
+ patch_size=2,
+ norm_type="ada_norm_single",
+ caption_channels=4096,
+ cross_attention_dim=1536,
+ **kwargs,
+ )
+
+
+def OpenSoraT2V_B_122(**kwargs):
+ return OpenSoraT2V(
+ num_layers=32,
+ attention_head_dim=96,
+ num_attention_heads=16,
+ patch_size_t=1,
+ patch_size=2,
+ norm_type="ada_norm_single",
+ caption_channels=4096,
+ cross_attention_dim=1920,
+ **kwargs,
+ )
+
+
+def OpenSoraT2V_L_122(**kwargs):
+ return OpenSoraT2V(
+ num_layers=40,
+ attention_head_dim=128,
+ num_attention_heads=16,
+ patch_size_t=1,
+ patch_size=2,
+ norm_type="ada_norm_single",
+ caption_channels=4096,
+ cross_attention_dim=2048,
+ **kwargs,
+ )
+
+
+def OpenSoraT2V_ROPE_L_122(**kwargs):
+ return OpenSoraT2V(
+ num_layers=32,
+ attention_head_dim=96,
+ num_attention_heads=24,
+ patch_size_t=1,
+ patch_size=2,
+ norm_type="ada_norm_single",
+ caption_channels=4096,
+ cross_attention_dim=2304,
+ **kwargs,
+ )
+
+
+OpenSora_models = {
+ "OpenSoraT2V-S/122": OpenSoraT2V_S_122, # 1.1B
+ "OpenSoraT2V-B/122": OpenSoraT2V_B_122,
+ "OpenSoraT2V-L/122": OpenSoraT2V_L_122,
+ "OpenSoraT2V-ROPE-L/122": OpenSoraT2V_ROPE_L_122,
+}
+
+OpenSora_models_class = {
+ "OpenSoraT2V-S/122": OpenSoraT2V,
+ "OpenSoraT2V-B/122": OpenSoraT2V,
+ "OpenSoraT2V-L/122": OpenSoraT2V,
+ "OpenSoraT2V-ROPE-L/122": OpenSoraT2V,
+}
diff --git a/videosys/models/transformers/open_sora_transformer_3d.py b/videosys/models/transformers/open_sora_transformer_3d.py
index 4a9c213..27aec68 100644
--- a/videosys/models/transformers/open_sora_transformer_3d.py
+++ b/videosys/models/transformers/open_sora_transformer_3d.py
@@ -20,15 +20,7 @@ from timm.models.vision_transformer import Mlp
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from transformers import PretrainedConfig, PreTrainedModel
-from videosys.core.comm import (
- all_to_all_with_pad,
- gather_sequence,
- get_spatial_pad,
- get_temporal_pad,
- set_spatial_pad,
- set_temporal_pad,
- split_sequence,
-)
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.core.pab_mgr import (
enable_pab,
get_mlp_output,
@@ -38,12 +30,7 @@ from videosys.core.pab_mgr import (
if_broadcast_temporal,
save_mlp_output,
)
-from videosys.core.parallel_mgr import (
- enable_sequence_parallel,
- get_cfg_parallel_size,
- get_data_parallel_group,
- get_sequence_parallel_group,
-)
+from videosys.core.parallel_mgr import ParallelManager
from videosys.models.modules.activations import approx_gelu
from videosys.models.modules.attentions import OpenSoraAttention, OpenSoraMultiHeadCrossAttention
from videosys.models.modules.embeddings import (
@@ -143,6 +130,9 @@ class STDiT3Block(nn.Module):
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
+ # parallel
+ self.parallel_manager: ParallelManager = None
+
# pab
self.block_idx = block_idx
self.attn_count = 0
@@ -188,9 +178,7 @@ class STDiT3Block(nn.Module):
if self.temporal:
broadcast_attn, self.attn_count = if_broadcast_temporal(int(timestep[0]), self.attn_count)
else:
- broadcast_attn, self.attn_count = if_broadcast_spatial(
- int(timestep[0]), self.attn_count, self.block_idx
- )
+ broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count)
if enable_pab() and broadcast_attn:
x_m_s = self.last_attn
@@ -203,12 +191,12 @@ class STDiT3Block(nn.Module):
# attention
if self.temporal:
- if enable_sequence_parallel():
+ if self.parallel_manager.sp_size > 1:
x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=True)
x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
x_m = self.attn(x_m)
x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S)
- if enable_sequence_parallel():
+ if self.parallel_manager.sp_size > 1:
x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=False)
else:
x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
@@ -287,17 +275,17 @@ class STDiT3Block(nn.Module):
def dynamic_switch(self, x, s, t, to_spatial_shard: bool):
if to_spatial_shard:
scatter_dim, gather_dim = 2, 1
- scatter_pad = get_spatial_pad()
- gather_pad = get_temporal_pad()
+ scatter_pad = get_pad("spatial")
+ gather_pad = get_pad("temporal")
else:
scatter_dim, gather_dim = 1, 2
- scatter_pad = get_temporal_pad()
- gather_pad = get_spatial_pad()
+ scatter_pad = get_pad("temporal")
+ gather_pad = get_pad("spatial")
x = rearrange(x, "b (t s) d -> b t s d", t=t, s=s)
x = all_to_all_with_pad(
x,
- get_sequence_parallel_group(),
+ self.parallel_manager.sp_group,
scatter_dim=scatter_dim,
gather_dim=gather_dim,
scatter_pad=scatter_pad,
@@ -449,6 +437,24 @@ class STDiT3(PreTrainedModel):
for param in self.y_embedder.parameters():
param.requires_grad = False
+ # parallel
+ self.parallel_manager: ParallelManager = None
+
+ def enable_parallel(self, dp_size, sp_size, enable_cp):
+ # update cfg parallel
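+ # when cfg parallelism is requested and sp_size is even, the guidance batch is split across a 2-rank group and the sequence-parallel size is halved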
+ if enable_cp and sp_size % 2 == 0:
+ sp_size = sp_size // 2
+ cp_size = 2
+ else:
+ cp_size = 1
+
+ self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
+
+ for name, module in self.named_modules():
+ if "spatial_blocks" in name or "temporal_blocks" in name:
+ if hasattr(module, "parallel_manager"):
+ module.parallel_manager = self.parallel_manager
+
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
@@ -501,9 +507,14 @@ class STDiT3(PreTrainedModel):
self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
):
# === Split batch ===
- if get_cfg_parallel_size() > 1:
+ if self.parallel_manager.cp_size > 1:
x, timestep, y, x_mask, mask = batch_func(
- partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
+ partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
+ x,
+ timestep,
+ y,
+ x_mask,
+ mask,
)
dtype = self.x_embedder.proj.weight.dtype
@@ -547,14 +558,14 @@ class STDiT3(PreTrainedModel):
x = x + pos_emb
# shard over the sequence dim if sp is enabled
- if enable_sequence_parallel():
- set_temporal_pad(T)
- set_spatial_pad(S)
- x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", T, self.parallel_manager.sp_group)
+ set_pad("spatial", S, self.parallel_manager.sp_group)
+ x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
T = x.shape[1]
x_mask_org = x_mask
x_mask = split_sequence(
- x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
@@ -589,9 +600,9 @@ class STDiT3(PreTrainedModel):
all_timesteps=all_timesteps,
)
- if enable_sequence_parallel():
+ if self.parallel_manager.sp_size > 1:
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
- x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
+ x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
T, S = x.shape[1], x.shape[2]
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
x_mask = x_mask_org
@@ -604,8 +615,8 @@ class STDiT3(PreTrainedModel):
x = x.to(torch.float32)
# === Gather Output ===
- if get_cfg_parallel_size() > 1:
- x = gather_sequence(x, get_data_parallel_group(), dim=0)
+ if self.parallel_manager.cp_size > 1:
+ x = gather_sequence(x, self.parallel_manager.cp_group, dim=0)
return x
diff --git a/videosys/models/transformers/vchitect_transformer_3d.py b/videosys/models/transformers/vchitect_transformer_3d.py
new file mode 100644
index 0000000..166d654
--- /dev/null
+++ b/videosys/models/transformers/vchitect_transformer_3d.py
@@ -0,0 +1,644 @@
+# Adapted from Vchitect
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Vchitect: https://github.com/Vchitect/Vchitect-2.0
+# diffusers: https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
+from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
+from diffusers.utils import USE_PEFT_BACKEND, deprecate, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from einops import rearrange
+from torch import nn
+
+from videosys.core.comm import gather_from_second_dim, set_pad, split_from_second_dim
+from videosys.core.parallel_mgr import ParallelManager
+from videosys.models.modules.attentions import VchitectAttention, VchitectAttnProcessor
+
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
+ # "feed_forward_chunk_size" can be used to save memory
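+ # chunking the feed-forward along chunk_dim lowers peak activation memory; the chunk outputs are concatenated back together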
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+ ff_output = torch.cat(
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+ dim=chunk_dim,
+ )
+ return ff_output
+
+
+@maybe_allow_in_graph
+class JointTransformerBlock(nn.Module):
+ r"""
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+ processing of `context` conditions.
+ """
+
+ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+ super().__init__()
+
+ self.context_pre_only = context_pre_only
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
+
+ self.norm1 = AdaLayerNormZero(dim)
+
+ if context_norm_type == "ada_norm_continous":
+ self.norm1_context = AdaLayerNormContinuous(
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+ )
+ elif context_norm_type == "ada_norm_zero":
+ self.norm1_context = AdaLayerNormZero(dim)
+ else:
+ raise ValueError(
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
+ )
+ processor = VchitectAttnProcessor()
+ self.attn = VchitectAttention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim // num_attention_heads,
+ heads=num_attention_heads,
+ out_dim=attention_head_dim,
+ context_pre_only=context_pre_only,
+ bias=True,
+ processor=processor,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ if not context_pre_only:
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+ else:
+ self.norm2_context = None
+ self.ff_context = None
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ freqs_cis: torch.Tensor,
+ full_seqlen: int,
+ Frame: int,
+ timestep: int,
+ ):
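+ # AdaLayerNormZero returns the normalized hidden states plus per-sample gate/shift/scale values for the attention and MLP branches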
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ if self.context_pre_only:
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+ else:
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # Attention.
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ freqs_cis=freqs_cis,
+ full_seqlen=full_seqlen,
+ Frame=Frame,
+ timestep=timestep,
+ )
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ if self.context_pre_only:
+ encoder_hidden_states = None
+ else:
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ context_ff_output = _chunked_feed_forward(
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+ )
+ else:
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return encoder_hidden_states, hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ inner_dim=None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim, bias=bias)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+class VchitectXLTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ """
+ The MMDiT-style Transformer used by Vchitect-2.0, adapted from the transformer introduced in Stable Diffusion 3.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ sample_size (`int`): The width of the latent images. This is fixed during training since
+ it is used to learn a number of position embeddings.
+ patch_size (`int`): Patch size to turn the input data into small patches.
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+ out_channels (`int`, defaults to 16): Number of output channels.
+
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 128,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ num_layers: int = 18,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 18,
+ joint_attention_dim: int = 4096,
+ caption_projection_dim: int = 1152,
+ pooled_projection_dim: int = 2048,
+ out_channels: int = 16,
+ pos_embed_max_size: int = 96,
+ rope_scaling_factor: float = 1.0,
+ ):
+ super().__init__()
+ default_out_channels = in_channels
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+ self.pos_embed = PatchEmbed(
+ height=self.config.sample_size,
+ width=self.config.sample_size,
+ patch_size=self.config.patch_size,
+ in_channels=self.config.in_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
+ )
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+ )
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+ # `attention_head_dim` is doubled to account for the mixing.
+ # It needs to be crafted when we get the actual checkpoints.
+ self.transformer_blocks = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.inner_dim,
+ context_pre_only=i == num_layers - 1,
+ )
+ for i in range(self.config.num_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ # Video param
+ # self.scatter_dim_zero = Identity()
+ self.freqs_cis = VchitectXLTransformerModel.precompute_freqs_cis(
+ self.inner_dim // self.config.num_attention_heads,
+ 1000000,
+ theta=1e6,
+ rope_scaling_factor=rope_scaling_factor, # todo max pos embeds
+ )
+
+ # self.vid_token = nn.Parameter(torch.empty(self.inner_dim))
+
+ # parallel
+ self.parallel_manager = None
+
+ def enable_parallel(self, dp_size, sp_size, enable_cp):
+ # update cfg parallel
+ if enable_cp and sp_size % 2 == 0:
+ sp_size = sp_size // 2
+ cp_size = 2
+ else:
+ cp_size = 1
+
+ self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
+
+ for _, module in self.named_modules():
+ if hasattr(module, "parallel_manager"):
+ module.parallel_manager = self.parallel_manager
+
+ @staticmethod
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, rope_scaling_factor: float = 1.0):
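+ # standard RoPE table: outer product of (optionally rescaled) positions and inverse frequencies, stored as complex exponentials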
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device, dtype=torch.float)
+ t = t / rope_scaling_factor
+ freqs = torch.outer(t, freqs).float()
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
+ return freqs_cis
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the feed-forward layers to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, VchitectAttnProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str, module: torch.nn.Module, processors: Dict[str, VchitectAttnProcessor]
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[VchitectAttnProcessor, Dict[str, VchitectAttnProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ This API is 🧪 experimental.
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, VchitectAttention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ This API is 🧪 experimental.
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def patchify_and_embed(self, x):
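+ # fold the frame dimension into the batch so the 2D PatchEmbed is applied to every frame independently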
+ B, F, C, H, W = x.size()
+ x = rearrange(x, "b f c h w -> (b f) c h w")
+ x = self.pos_embed(x) # [B L D]
+ return x, F, [(H, W)] * B
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`VchitectXLTransformerModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ height, width = hidden_states.shape[-2:]
+
+ batch_size = hidden_states.shape[0]
+ hidden_states, F_num, _ = self.patchify_and_embed(
+ hidden_states
+ ) # takes care of adding positional embeddings too.
+ full_seq = batch_size * F_num
+
+ self.freqs_cis = self.freqs_cis.to(hidden_states.device)
+ freqs_cis = self.freqs_cis
+ # seq_length = hidden_states.size(1)
+ # freqs_cis = self.freqs_cis[:hidden_states.size(1)*F_num]
+ temb = self.time_text_embed(timestep, pooled_projections)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
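+ # under sequence parallelism, the per-frame tokens are sharded across ranks and the timestep embedding is repeated to match the local shard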
+ if self.parallel_manager.sp_size > 1:
+ set_pad("temporal", F_num, self.parallel_manager.sp_group)
+ hidden_states = split_from_second_dim(hidden_states, batch_size, self.parallel_manager.sp_group)
+ cur_temb = temb.repeat(hidden_states.shape[0] // batch_size, 1)
+ else:
+ cur_temb = temb.repeat(F_num, 1)
+
+ for block_idx, block in enumerate(self.transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=cur_temb,
+ freqs_cis=freqs_cis,
+ full_seqlen=full_seq,
+ Frame=F_num,
+ timestep=timestep,
+ )
+
+ if self.parallel_manager.sp_size > 1:
+ hidden_states = gather_from_second_dim(hidden_states, batch_size, self.parallel_manager.sp_group)
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # hidden_states = hidden_states[:, :-1] #Drop the video token
+
+ # unpatchify
+ patch_size = self.config.patch_size
+ height = height // patch_size
+ width = width // patch_size
+
+ hidden_states = hidden_states.reshape(
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ )
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+ def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
+ return list(self.transformer_blocks)
+
+ @classmethod
+ def from_pretrained_temporal(cls, pretrained_model_path, torch_dtype, logger, subfolder=None, tp_size=1):
+ import json
+ import os
+
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+ config_file = os.path.join(pretrained_model_path, "config.json")
+
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ config["tp_size"] = tp_size
+ from safetensors.torch import load_file
+
+ model = cls.from_config(config)
+ # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+
+ model_files = [
+ os.path.join(pretrained_model_path, "diffusion_pytorch_model.bin"),
+ os.path.join(pretrained_model_path, "diffusion_pytorch_model.safetensors"),
+ ]
+
+ model_file = None
+
+ for fp in model_files:
+ if os.path.exists(fp):
+ model_file = fp
+
+ if model_file is None:
+ raise RuntimeError(f"no diffusion_pytorch_model weights found under {pretrained_model_path}")
+
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+
+ state_dict = load_file(model_file, device="cpu")
+ m, u = model.load_state_dict(state_dict, strict=False)
+ model = model.to(torch_dtype)
+
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
+ total_params = [p.numel() for n, p in model.named_parameters()]
+
+ if logger is not None:
+ logger.info(f"model_file: {model_file}")
+ logger.info(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ logger.info(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
+ logger.info(f"### Total Parameters: {sum(total_params) / 1e6} M")
+
+ return model
diff --git a/videosys/pipelines/cogvideox/pipeline_cogvideox.py b/videosys/pipelines/cogvideox/pipeline_cogvideox.py
index 7eaee3a..e8bd151 100644
--- a/videosys/pipelines/cogvideox/pipeline_cogvideox.py
+++ b/videosys/pipelines/cogvideox/pipeline_cogvideox.py
@@ -13,6 +13,7 @@ import math
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
+import torch.distributed as dist
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
@@ -26,19 +27,17 @@ from videosys.models.transformers.cogvideox_transformer_3d import CogVideoXTrans
from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler
from videosys.utils.logging import logger
-from videosys.utils.utils import save_video
+from videosys.utils.utils import save_video, set_seed
class CogVideoXPABConfig(PABConfig):
def __init__(
self,
- steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
):
super().__init__(
- steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
@@ -114,7 +113,7 @@ class CogVideoXConfig:
class CogVideoXPipeline(VideoSysPipeline):
- _optional_components = ["text_encoder", "tokenizer"]
+ _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
@@ -158,9 +157,6 @@ class CogVideoXPipeline(VideoSysPipeline):
subfolder="scheduler",
)
- # set eval and device
- self.set_eval_and_device(self._device, vae, transformer)
-
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
@@ -169,7 +165,7 @@ class CogVideoXPipeline(VideoSysPipeline):
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
- self.set_eval_and_device(self._device, text_encoder)
+ self.set_eval_and_device(self._device, text_encoder, vae, transformer)
# vae tiling
if config.vae_tiling:
@@ -187,6 +183,31 @@ class CogVideoXPipeline(VideoSysPipeline):
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ # parallel
+ self._set_parallel()
+
+ def _set_seed(self, seed):
+ if dist.get_world_size() == 1:
+ set_seed(seed)
+ else:
+ set_seed(seed, self.transformer.parallel_manager.dp_rank)
+
+ def _set_parallel(
+ self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
+ ):
+ # init sequence parallel
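+ # by default all ranks form a single sequence-parallel group; otherwise ranks are grouped into dp_size data-parallel replicas of sp_size ranks each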
+ if sp_size is None:
+ sp_size = dist.get_world_size()
+ dp_size = 1
+ else:
+ assert (
+ dist.get_world_size() % sp_size == 0
+ ), f"world_size {dist.get_world_size()} must be divisible by sp_size"
+ dp_size = dist.get_world_size() // sp_size
+
+ # transformer parallel
+ self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
+
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
@@ -474,6 +495,7 @@ class CogVideoXPipeline(VideoSysPipeline):
num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
+ seed: int = -1,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
@@ -489,7 +511,6 @@ class CogVideoXPipeline(VideoSysPipeline):
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
- verbose=True,
) -> Union[VideoSysPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -572,6 +593,7 @@ class CogVideoXPipeline(VideoSysPipeline):
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
update_steps(num_inference_steps)
+ self._set_seed(seed)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -653,11 +675,10 @@ class CogVideoXPipeline(VideoSysPipeline):
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
- # with self.progress_bar(total=num_inference_steps) as progress_bar:
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
- old_pred_original_sample = None
- for i, t in progress_wrap(list(enumerate(timesteps))):
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -674,7 +695,6 @@ class CogVideoXPipeline(VideoSysPipeline):
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
- all_timesteps=timesteps,
)[0]
noise_pred = noise_pred.float()
@@ -713,8 +733,8 @@ class CogVideoXPipeline(VideoSysPipeline):
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- # progress_bar.update()
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents)
diff --git a/videosys/pipelines/latte/pipeline_latte.py b/videosys/pipelines/latte/pipeline_latte.py
index 7c1d590..437baaa 100644
--- a/videosys/pipelines/latte/pipeline_latte.py
+++ b/videosys/pipelines/latte/pipeline_latte.py
@@ -29,13 +29,12 @@ from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.transformers.latte_transformer_3d import LatteT2V
from videosys.utils.logging import logger
-from videosys.utils.utils import save_video
+from videosys.utils.utils import save_video, set_seed
class LattePABConfig(PABConfig):
def __init__(
self,
- steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 800],
spatial_range: int = 2,
@@ -62,7 +61,6 @@ class LattePABConfig(PABConfig):
},
):
super().__init__(
- steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
@@ -188,7 +186,7 @@ class LattePipeline(VideoSysPipeline):
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
- _optional_components = ["tokenizer", "text_encoder"]
+ _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
@@ -238,9 +236,6 @@ class LattePipeline(VideoSysPipeline):
if config.enable_pab:
set_pab_manager(config.pab_config)
- # set eval and device
- self.set_eval_and_device(device, vae, transformer)
-
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
@@ -249,11 +244,36 @@ class LattePipeline(VideoSysPipeline):
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
- self.set_eval_and_device(device, text_encoder)
+ self.set_eval_and_device(device, text_encoder, vae, transformer)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ # parallel
+ self._set_parallel()
+
+ def _set_seed(self, seed):
+ if dist.get_world_size() == 1:
+ set_seed(seed)
+ else:
+ set_seed(seed, self.transformer.parallel_manager.dp_rank)
+
+ def _set_parallel(
+ self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
+ ):
+ # init sequence parallel
+ if sp_size is None:
+ sp_size = dist.get_world_size()
+ dp_size = 1
+ else:
+ assert (
+ dist.get_world_size() % sp_size == 0
+ ), f"world_size {dist.get_world_size()} must be divisible by sp_size"
+ dp_size = dist.get_world_size() // sp_size
+
+ # transformer parallel
+ self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
+
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
@@ -658,6 +678,7 @@ class LattePipeline(VideoSysPipeline):
negative_prompt: str = "",
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
+ seed: int = -1,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -745,6 +766,7 @@ class LattePipeline(VideoSysPipeline):
width = 512
update_steps(num_inference_steps)
self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
+ self._set_seed(seed)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
diff --git a/videosys/pipelines/open_sora/pipeline_open_sora.py b/videosys/pipelines/open_sora/pipeline_open_sora.py
index 5fd3473..6199f26 100644
--- a/videosys/pipelines/open_sora/pipeline_open_sora.py
+++ b/videosys/pipelines/open_sora/pipeline_open_sora.py
@@ -2,19 +2,22 @@ import html
import json
import os
import re
+import urllib.parse as ul
from typing import Optional, Tuple, Union
import ftfy
import torch
+import torch.distributed as dist
+from bs4 import BeautifulSoup
from diffusers.models import AutoencoderKL
from transformers import AutoTokenizer, T5EncoderModel
-from videosys.core.pab_mgr import PABConfig, set_pab_manager
+from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.autoencoders.autoencoder_kl_open_sora import OpenSoraVAE_V1_2
from videosys.models.transformers.open_sora_transformer_3d import STDiT3
from videosys.schedulers.scheduling_rflow_open_sora import RFLOW
-from videosys.utils.utils import save_video
+from videosys.utils.utils import save_video, set_seed
from .data_process import get_image_size, get_num_frames, prepare_multi_resolution_info, read_from_path
@@ -29,7 +32,6 @@ BAD_PUNCT_REGEX = re.compile(
class OpenSoraPABConfig(PABConfig):
def __init__(
self,
- steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [450, 930],
spatial_range: int = 2,
@@ -52,7 +54,6 @@ class OpenSoraPABConfig(PABConfig):
},
):
super().__init__(
- steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
@@ -187,11 +188,7 @@ class OpenSoraPipeline(VideoSysPipeline):
bad_punct_regex = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
-
- _optional_components = [
- "text_encoder",
- "tokenizer",
- ]
+ _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
@@ -234,9 +231,6 @@ class OpenSoraPipeline(VideoSysPipeline):
if config.enable_pab:
set_pab_manager(config.pab_config)
- # set eval and device
- self.set_eval_and_device(device, vae, transformer)
-
self.register_modules(
text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, tokenizer=tokenizer
)
@@ -245,7 +239,32 @@ class OpenSoraPipeline(VideoSysPipeline):
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
- self.set_eval_and_device(self._device, text_encoder)
+ self.set_eval_and_device(self._device, vae, transformer, text_encoder)
+
+ # parallel
+ self._set_parallel()
+
+ def _set_seed(self, seed):
+ if dist.get_world_size() == 1:
+ set_seed(seed)
+ else:
+ set_seed(seed, self.transformer.parallel_manager.dp_rank)
+
+ def _set_parallel(
+ self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
+ ):
+ # init sequence parallel
+ if sp_size is None:
+ sp_size = dist.get_world_size()
+ dp_size = 1
+ else:
+ assert (
+ dist.get_world_size() % sp_size == 0
+ ), f"world_size {dist.get_world_size()} must be divisible by sp_size"
+ dp_size = dist.get_world_size() // sp_size
+
+ # transformer parallel
+ self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
def get_text_embeddings(self, texts):
text_tokens_and_mask = self.tokenizer(
@@ -283,10 +302,6 @@ class OpenSoraPipeline(VideoSysPipeline):
return text.strip()
def _clean_caption(self, caption):
- import urllib.parse as ul
-
- from bs4 import BeautifulSoup
-
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
@@ -418,6 +433,7 @@ class OpenSoraPipeline(VideoSysPipeline):
loop: int = 1,
llm_refine: bool = False,
negative_prompt: str = "",
+ seed: int = -1,
ms: Optional[str] = "",
refs: Optional[str] = "",
aes: float = 6.5,
@@ -460,10 +476,6 @@ class OpenSoraPipeline(VideoSysPipeline):
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
- height (`int`, *optional*, defaults to self.unet.config.sample_size):
- The height in pixels of the generated image.
- width (`int`, *optional*, defaults to self.unet.config.sample_size):
- The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@@ -508,6 +520,8 @@ class OpenSoraPipeline(VideoSysPipeline):
fps = 24
image_size = get_image_size(resolution, aspect_ratio)
num_frames = get_num_frames(num_frames)
+ self._set_seed(seed)
+ update_steps(self._config.num_sampling_steps)
# == prepare batch prompts ==
batch_prompts = [prompt]
@@ -621,7 +635,7 @@ class OpenSoraPipeline(VideoSysPipeline):
progress=verbose,
mask=masks,
)
- samples = self.vae.decode(samples.to(self._dtype), num_frames=num_frames)
+ samples = self.vae(samples.to(self._dtype), decode_only=True, num_frames=num_frames)
video_clips.append(samples)
for i in range(1, loop):
diff --git a/videosys/pipelines/open_sora_plan/__init__.py b/videosys/pipelines/open_sora_plan/__init__.py
index 7a1ddb8..ddf790d 100644
--- a/videosys/pipelines/open_sora_plan/__init__.py
+++ b/videosys/pipelines/open_sora_plan/__init__.py
@@ -1,3 +1,8 @@
-from .pipeline_open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
+from .pipeline_open_sora_plan import (
+ OpenSoraPlanConfig,
+ OpenSoraPlanPipeline,
+ OpenSoraPlanV110PABConfig,
+ OpenSoraPlanV120PABConfig,
+)
-__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanPABConfig"]
+__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig"]
diff --git a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
index b973361..ce34321 100644
--- a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
+++ b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
@@ -20,39 +20,27 @@ import torch.distributed as dist
import tqdm
from bs4 import BeautifulSoup
from diffusers.models import AutoencoderKL, Transformer2DModel
-from diffusers.schedulers import PNDMScheduler
+from diffusers.schedulers import EulerAncestralDiscreteScheduler, PNDMScheduler
from diffusers.utils.torch_utils import randn_tensor
-from transformers import T5EncoderModel, T5Tokenizer
+from transformers import AutoTokenizer, MT5EncoderModel, T5EncoderModel, T5Tokenizer
from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
+from videosys.models.autoencoders.autoencoder_kl_open_sora_plan_v110 import (
+ CausalVAEModelWrapper as CausalVAEModelWrapperV110,
+)
+from videosys.models.autoencoders.autoencoder_kl_open_sora_plan_v120 import (
+ CausalVAEModelWrapper as CausalVAEModelWrapperV120,
+)
+from videosys.models.transformers.open_sora_plan_v110_transformer_3d import LatteT2V
+from videosys.models.transformers.open_sora_plan_v120_transformer_3d import OpenSoraT2V
from videosys.utils.logging import logger
-from videosys.utils.utils import save_video
-
-from ...models.autoencoders.autoencoder_kl_open_sora_plan import ae_stride_config, getae_wrapper
-from ...models.transformers.open_sora_plan_transformer_3d import LatteT2V
-
-EXAMPLE_DOC_STRING = """
- Examples:
- ```py
- >>> import torch
- >>> from diffusers import PixArtAlphaPipeline
-
- >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
- >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
- >>> # Enable memory optimizations.
- >>> pipe.enable_model_cpu_offload()
-
- >>> prompt = "A small cactus with a happy face in the Sahara desert."
- >>> image = pipe(prompt).images[0]
- ```
-"""
+from videosys.utils.utils import save_video, set_seed
-class OpenSoraPlanPABConfig(PABConfig):
+class OpenSoraPlanV110PABConfig(PABConfig):
def __init__(
self,
- steps: int = 150,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
@@ -97,7 +85,6 @@ class OpenSoraPlanPABConfig(PABConfig):
},
):
super().__init__(
- steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
@@ -113,6 +100,26 @@ class OpenSoraPlanPABConfig(PABConfig):
)
+class OpenSoraPlanV120PABConfig(PABConfig):
+ def __init__(
+ self,
+ spatial_broadcast: bool = True,
+ spatial_threshold: list = [100, 850],
+ spatial_range: int = 2,
+ cross_broadcast: bool = True,
+ cross_threshold: list = [100, 850],
+ cross_range: int = 6,
+ ):
+ super().__init__(
+ spatial_broadcast=spatial_broadcast,
+ spatial_threshold=spatial_threshold,
+ spatial_range=spatial_range,
+ cross_broadcast=cross_broadcast,
+ cross_threshold=cross_threshold,
+ cross_range=cross_range,
+ )
+
+
class OpenSoraPlanConfig:
"""
This config is to instantiate a `OpenSoraPlanPipeline` class for video generation.
@@ -163,10 +170,10 @@ class OpenSoraPlanConfig:
def __init__(
self,
- transformer: str = "LanguageBind/Open-Sora-Plan-v1.1.0",
- ae: str = "CausalVAEModel_4x8x8",
- text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
- num_frames: int = 65,
+ version: str = "v120",
+ transformer_type: str = "29x480p",
+ transformer: str = None,
+ text_encoder: str = None,
# ======= distributed ========
num_gpus: int = 1,
# ======= memory =======
@@ -175,15 +182,32 @@ class OpenSoraPlanConfig:
tile_overlap_factor: float = 0.25,
# ======= pab ========
enable_pab: bool = False,
- pab_config: PABConfig = OpenSoraPlanPABConfig(),
+ pab_config: PABConfig = None,
):
self.pipeline_cls = OpenSoraPlanPipeline
- self.ae = ae
- self.text_encoder = text_encoder
- self.transformer = transformer
- assert num_frames in [65, 221], "num_frames must be one of [65, 221]"
- self.num_frames = num_frames
- self.version = f"{num_frames}x512x512"
+
+ # get version
+ assert version in ["v110", "v120"], f"Unknown Open-Sora-Plan version: {version}"
+ self.version = version
+ self.transformer_type = transformer_type
+
+ # check transformer_type
+ if version == "v110":
+ assert transformer_type in ["65x512x512", "221x512x512"]
+ elif version == "v120":
+ assert transformer_type in ["93x480p", "93x720p", "29x480p", "29x720p"]
+ self.num_frames = int(transformer_type.split("x")[0])
+
+ # set default values according to version
+ if version == "v110":
+ transformer_default = "LanguageBind/Open-Sora-Plan-v1.1.0"
+ text_encoder_default = "DeepFloyd/t5-v1_1-xxl"
+ elif version == "v120":
+ transformer_default = "LanguageBind/Open-Sora-Plan-v1.2.0"
+ text_encoder_default = "google/mt5-xxl"
+ self.text_encoder = text_encoder or text_encoder_default
+ self.transformer = transformer or transformer_default
+
# ======= distributed ========
self.num_gpus = num_gpus
# ======= memory ========
@@ -192,7 +216,13 @@ class OpenSoraPlanConfig:
self.tile_overlap_factor = tile_overlap_factor
# ======= pab ========
self.enable_pab = enable_pab
- self.pab_config = pab_config
+ if self.enable_pab and pab_config is None:
+ if version == "v110":
+ self.pab_config = OpenSoraPlanV110PABConfig()
+ elif version == "v120":
+ self.pab_config = OpenSoraPlanV120PABConfig()
+ else:
+ self.pab_config = pab_config
class OpenSoraPlanPipeline(VideoSysPipeline):
@@ -221,7 +251,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
- _optional_components = ["tokenizer", "text_encoder"]
+ _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
@@ -240,26 +270,57 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
# init
if tokenizer is None:
- tokenizer = T5Tokenizer.from_pretrained(config.text_encoder)
+ if config.version == "v110":
+ tokenizer = T5Tokenizer.from_pretrained(config.text_encoder)
+ elif config.version == "v120":
+ tokenizer = AutoTokenizer.from_pretrained(config.text_encoder)
+
if text_encoder is None:
- text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=torch.float16)
+ if config.version == "v110":
+ text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=dtype)
+ elif config.version == "v120":
+ text_encoder = MT5EncoderModel.from_pretrained(
+ config.text_encoder, low_cpu_mem_usage=True, torch_dtype=dtype
+ )
+
if vae is None:
- vae = getae_wrapper(config.ae)(config.transformer, subfolder="vae").to(dtype=dtype)
+ if config.version == "v110":
+ vae = CausalVAEModelWrapperV110(config.transformer, subfolder="vae").to(dtype=dtype)
+ elif config.version == "v120":
+ vae = CausalVAEModelWrapperV120(config.transformer, subfolder="vae").to(dtype=dtype)
+
if transformer is None:
- transformer = LatteT2V.from_pretrained(config.transformer, subfolder=config.version, torch_dtype=dtype)
+ if config.version == "v110":
+ transformer = LatteT2V.from_pretrained(
+ config.transformer, subfolder=config.transformer_type, torch_dtype=dtype
+ )
+ elif config.version == "v120":
+ transformer = OpenSoraT2V.from_pretrained(
+ config.transformer, subfolder=config.transformer_type, torch_dtype=dtype
+ )
+
if scheduler is None:
- scheduler = PNDMScheduler()
+ if config.version == "v110":
+ scheduler = PNDMScheduler()
+ elif config.version == "v120":
+ scheduler = EulerAncestralDiscreteScheduler()
# setting
if config.enable_tiling:
vae.vae.enable_tiling()
vae.vae.tile_overlap_factor = config.tile_overlap_factor
- vae.vae_scale_factor = ae_stride_config[config.ae]
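+            # VAE tiling window sizes: pixel-space (sample) and latent-space (latent) sizes for the spatial and temporal dims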
+ vae.vae.tile_sample_min_size = 512
+ vae.vae.tile_latent_min_size = 64
+ vae.vae.tile_sample_min_size_t = 29
+ vae.vae.tile_latent_min_size_t = 8
+ # if low_mem:
+ # vae.vae.tile_sample_min_size = 256
+ # vae.vae.tile_latent_min_size = 32
+ # vae.vae.tile_sample_min_size_t = 29
+ # vae.vae.tile_latent_min_size_t = 8
+ vae.vae_scale_factor = [4, 8, 8]
transformer.force_images = False
- # set eval and device
- self.set_eval_and_device(device, vae, transformer)
-
# pab
if config.enable_pab:
set_pab_manager(config.pab_config)
@@ -272,10 +333,35 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
- self.set_eval_and_device(device, text_encoder)
+ self.set_eval_and_device(device, text_encoder, vae, transformer)
# self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ # parallel
+ self._set_parallel()
+
+ def _set_seed(self, seed):
+ if dist.get_world_size() == 1:
+ set_seed(seed)
+ else:
+ set_seed(seed, self.transformer.parallel_manager.dp_rank)
+
+ def _set_parallel(
+ self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
+ ):
+ # init sequence parallel
+ if sp_size is None:
+ sp_size = dist.get_world_size()
+ dp_size = 1
+ else:
+ assert (
+ dist.get_world_size() % sp_size == 0
+ ), f"world_size {dist.get_world_size()} must be divisible by sp_size"
+ dp_size = dist.get_world_size() // sp_size
+
+ # transformer parallel
+ self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
+
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
@@ -449,6 +535,143 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
return prompt_embeds, negative_prompt_embeds
+ def encode_prompt_v120(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 120,
+ **kwargs,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+                Open-Sora-Plan, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. For Open-Sora-Plan, these should be the embeddings of the ""
+                string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
+ """
+ if device is None:
+ device = getattr(self, "_execution_device", None) or getattr(self, "device", None) or torch.device("cuda")
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ negative_prompt_attention_mask = uncond_input.attention_mask
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -476,6 +699,8 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -519,6 +744,12 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
@@ -685,10 +916,13 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
+ seed: int = -1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -697,6 +931,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
mask_feature: bool = True,
enable_temporal_attentions: bool = True,
verbose: bool = True,
+ max_sequence_length: int = 512,
) -> Union[VideoSysPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -768,11 +1003,22 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
returned where the first element is a list with the generated images
"""
# 1. Check inputs. Raise error if not correct
- height = 512
- width = 512
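+        # pixel resolution = transformer latent sample size * spatial VAE scale factor (vae_scale_factor[1:] are the spatial strides)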
+ height = self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]
+ width = self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]
num_frames = self._config.num_frames
update_steps(num_inference_steps)
- self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+ self._set_seed(seed)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
@@ -782,7 +1028,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
else:
batch_size = prompt_embeds.shape[0]
- device = self._execution_device
+ device = getattr(self, "_execution_device", None) or getattr(self, "device", None) or torch.device("cuda")
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -790,23 +1036,50 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
- prompt,
- do_classifier_free_guidance,
- negative_prompt=negative_prompt,
- num_images_per_prompt=num_images_per_prompt,
- device=device,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- clean_caption=clean_caption,
- mask_feature=mask_feature,
- )
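+        # v1.1.0 keeps the original encode_prompt path; v1.2.0 uses encode_prompt_v120, which also returns attention masks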
+ if self._config.version == "v110":
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ mask_feature=mask_feature,
+ )
+ prompt_attention_mask = None
+ negative_prompt_attention_mask = None
+ else:
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt_v120(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ )
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ if prompt_attention_mask is not None and negative_prompt_attention_mask is not None:
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
- self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps = self.scheduler.timesteps
+ if self._config.version == "v110":
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ else:
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
@@ -827,15 +1100,10 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
- # if self.transformer.config.sample_size == 128:
- # resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
- # aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
- # resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
- # aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
- # added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
for i, t in progress_wrap(list(enumerate(timesteps))):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -856,9 +1124,15 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
- if prompt_embeds.ndim == 3:
+ if prompt_embeds is not None and prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
+ if prompt_attention_mask is not None and prompt_attention_mask.ndim == 2:
+ prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l
+ # prepare attention_mask.
+ # b c t h w -> b t h w
+ attention_mask = torch.ones_like(latent_model_input)[:, 0]
+
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
@@ -867,6 +1141,8 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
enable_temporal_attentions=enable_temporal_attentions,
+ attention_mask=attention_mask,
+ encoder_attention_mask=prompt_attention_mask,
return_dict=False,
)[0]
@@ -906,10 +1182,57 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
return VideoSysPipelineOutput(video=video)
def decode_latents(self, latents):
- video = self.vae.decode(latents) # b t c h w
- # b t c h w -> b t h w c
- video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
+ video = self.vae(latents.to(self.vae.vae.dtype))
+ video = (
+ ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
+ ) # b t h w c
+        # decoded frames are scaled to [0, 255] and cast to uint8 for saving; decoding runs in the VAE's own (e.g. bfloat16) dtype
return video
def save_video(self, video, output_path):
save_video(video, output_path, fps=24)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
diff --git a/videosys/pipelines/vchitect/__init__.py b/videosys/pipelines/vchitect/__init__.py
new file mode 100644
index 0000000..0ee6ac1
--- /dev/null
+++ b/videosys/pipelines/vchitect/__init__.py
@@ -0,0 +1,3 @@
+from .pipeline_vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline
+
+__all__ = ["VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"]
diff --git a/videosys/pipelines/vchitect/pipeline_vchitect.py b/videosys/pipelines/vchitect/pipeline_vchitect.py
new file mode 100644
index 0000000..e46927d
--- /dev/null
+++ b/videosys/pipelines/vchitect/pipeline_vchitect.py
@@ -0,0 +1,1057 @@
+# Adapted from Vchitect
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Vchitect: https://github.com/Vchitect/Vchitect-2.0
+# diffusers: https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.distributed as dist
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils.torch_utils import randn_tensor
+from torch.amp import autocast
+from tqdm import tqdm
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
+from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
+from videosys.models.transformers.vchitect_transformer_3d import VchitectXLTransformerModel
+from videosys.utils.logging import logger
+from videosys.utils.utils import save_video, set_seed
+
+
+class VchitectPABConfig(PABConfig):
+ def __init__(
+ self,
+ spatial_broadcast: bool = True,
+ spatial_threshold: list = [100, 800],
+ spatial_range: int = 2,
+ temporal_broadcast: bool = True,
+ temporal_threshold: list = [100, 800],
+ temporal_range: int = 4,
+ cross_broadcast: bool = True,
+ cross_threshold: list = [100, 800],
+ cross_range: int = 6,
+ ):
+ super().__init__(
+ spatial_broadcast=spatial_broadcast,
+ spatial_threshold=spatial_threshold,
+ spatial_range=spatial_range,
+ temporal_broadcast=temporal_broadcast,
+ temporal_threshold=temporal_threshold,
+ temporal_range=temporal_range,
+ cross_broadcast=cross_broadcast,
+ cross_threshold=cross_threshold,
+ cross_range=cross_range,
+ )
+
+
+class VchitectConfig:
+ """
+    This config is used to instantiate a `VchitectXLPipeline` class for video generation.
+
+    To be specific, this config will be passed to the engine by `VideoSysEngine(config)`.
+    In the engine, it will be used to instantiate the corresponding pipeline class,
+    and the engine will call the `generate` function of the pipeline to generate the video.
+    If you want to explore the details of generation, please refer to the pipeline class below.
+
+ Args:
+ model_path (str):
+ The model path to use. Defaults to "Vchitect/Vchitect-2.0-2B".
+ num_gpus (int):
+ The number of GPUs to use. Defaults to 1.
+ cpu_offload (bool):
+ Whether to enable cpu offload. Defaults to False.
+ enable_pab (bool):
+ Whether to enable Pyramid Attention Broadcast. Defaults to False.
+ pab_config (VchitectPABConfig):
+ The configuration for Pyramid Attention Broadcast. Defaults to `VchitectPABConfig()`.
+
+ Examples:
+ ```python
+    from videosys import VchitectConfig, VideoSysEngine
+
+ # change num_gpus for multi-gpu inference
+ config = VchitectConfig("Vchitect/Vchitect-2.0-2B", num_gpus=1)
+ engine = VideoSysEngine(config)
+
+ prompt = "Sunset over the sea."
+    # seed=-1 means random seed; >=0 means fixed seed.
+ # WxH: 480x288 624x352 432x240 768x432
+ video = engine.generate(
+ prompt=prompt,
+ negative_prompt="",
+ num_inference_steps=100,
+ guidance_scale=7.5,
+ width=480,
+ height=288,
+ frames=40,
+ seed=0,
+ ).video[0]
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
+ ```
+ """
+
+ def __init__(
+ self,
+ model_path: str = "Vchitect/Vchitect-2.0-2B",
+ # ======= distributed ========
+ num_gpus: int = 1,
+ # ======= memory =======
+ cpu_offload: bool = False,
+ # ======= pab ========
+ enable_pab: bool = False,
+ pab_config: VchitectPABConfig = VchitectPABConfig(),
+ ):
+ self.model_path = model_path
+ self.pipeline_cls = VchitectXLPipeline
+ # ======= distributed ========
+ self.num_gpus = num_gpus
+ # ======= memory ========
+ self.cpu_offload = cpu_offload
+ # ======= pab ========
+ self.enable_pab = enable_pab
+ self.pab_config = pab_config
+
+
+class VchitectXLPipeline(VideoSysPipeline):
+ r"""
+ Args:
+ transformer ([`VchitectXLTransformerModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
+ as its dimension.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ text_encoder_3 ([`T5EncoderModel`]):
+ Frozen text-encoder. Stable Diffusion 3 uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_3 (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
+ _optional_components = [
+ "text_encoder",
+ "text_encoder_2",
+ "text_encoder_3",
+ "tokenizer",
+ "tokenizer_2",
+ "tokenizer_3",
+ "vae",
+ "transformer",
+ "scheduler",
+ ]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
+
+ def __init__(
+ self,
+ config: VchitectConfig,
+ text_encoder: Optional[CLIPTextModelWithProjection] = None,
+ text_encoder_2: Optional[CLIPTextModelWithProjection] = None,
+ text_encoder_3: Optional[T5EncoderModel] = None,
+ tokenizer: Optional[CLIPTokenizer] = None,
+ tokenizer_2: Optional[CLIPTokenizer] = None,
+ tokenizer_3: Optional[T5TokenizerFast] = None,
+ vae: Optional[AutoencoderKL] = None,
+ transformer: Optional[VchitectXLTransformerModel] = None,
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+ device: torch.device = torch.device("cuda"),
+ dtype: torch.dtype = torch.bfloat16,
+ ):
+ super().__init__()
+ self._config = config
+ self._device = device
+ self._dtype = dtype
+
+ if text_encoder is None:
+ self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
+ config.model_path, subfolder="text_encoder", torch_dtype=dtype
+ )
+ if text_encoder_2 is None:
+ self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+ config.model_path, subfolder="text_encoder_2", torch_dtype=dtype
+ )
+ if text_encoder_3 is None:
+ self.text_encoder_3 = T5EncoderModel.from_pretrained(
+ config.model_path, subfolder="text_encoder_3", torch_dtype=dtype
+ )
+
+ if tokenizer is None:
+ self.tokenizer = CLIPTokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
+ if tokenizer_2 is None:
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(config.model_path, subfolder="tokenizer_2")
+ if tokenizer_3 is None:
+ self.tokenizer_3 = T5TokenizerFast.from_pretrained(config.model_path, subfolder="tokenizer_3")
+
+ if vae is None:
+ self.vae = AutoencoderKL.from_pretrained(config.model_path, subfolder="vae", torch_dtype=dtype)
+
+ if transformer is None:
+ self.transformer = VchitectXLTransformerModel.from_pretrained(
+ config.model_path, subfolder="transformer", torch_dtype=dtype
+ )
+
+ if scheduler is None:
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(config.model_path, subfolder="scheduler")
+
+ self.register_modules(
+ tokenizer=self.tokenizer,
+ tokenizer_2=self.tokenizer_2,
+ tokenizer_3=self.tokenizer_3,
+ text_encoder=self.text_encoder,
+ text_encoder_3=self.text_encoder_3,
+ text_encoder_2=self.text_encoder_2,
+ vae=self.vae,
+ transformer=self.transformer,
+ scheduler=self.scheduler,
+ )
+
+ # pab
+ if config.enable_pab:
+ set_pab_manager(config.pab_config)
+
+ # cpu offload
+ if config.cpu_offload:
+ self.enable_model_cpu_offload()
+ else:
+ self.set_eval_and_device(
+ device, self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.transformer, self.vae
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.max_sequence_length_t5 = 256
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer") and self.transformer is not None
+ else 128
+ )
+
+ # parallel
+ self._set_parallel()
+
+ def _set_parallel(
+ self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
+ ):
+ # init sequence parallel
+ if sp_size is None:
+ sp_size = dist.get_world_size()
+ dp_size = 1
+ else:
+ assert (
+ dist.get_world_size() % sp_size == 0
+ ), f"world_size {dist.get_world_size()} must be divisible by sp_size"
+ dp_size = dist.get_world_size() // sp_size
+
+ # transformer parallel
+ self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self.execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if self.text_encoder_3 is None:
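+            # no T5 encoder loaded: fall back to zero embeddings with the expected sequence length and hidden size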
+ return torch.zeros(
+ (batch_size, self.max_sequence_length_t5, self.transformer.config.joint_attention_dim),
+ device=device,
+ dtype=dtype,
+ )
+
+ text_inputs = self.tokenizer_3(
+ prompt,
+ padding="max_length",
+ max_length=self.max_sequence_length_t5,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.max_sequence_length_t5 - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.max_sequence_length_t5} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
+
+ dtype = self.text_encoder_3.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ clip_skip: Optional[int] = None,
+ clip_model_index: int = 0,
+ ):
+ device = device or self.execution_device
+
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
+
+ tokenizer = clip_tokenizers[clip_model_index]
+ text_encoder = clip_text_encoders[clip_model_index]
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+ pooled_prompt_embeds = prompt_embeds[0]
+
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+ def _set_seed(self, seed):
+ if dist.get_world_size() == 1:
+ set_seed(seed)
+ else:
+ set_seed(seed, self.transformer.parallel_manager.dp_rank)
+
+ @autocast("cuda", enabled=False)
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ prompt_3: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+            negative_prompt_3 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self.execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ prompt_3 = prompt_3 or prompt
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=clip_skip,
+ clip_model_index=0,
+ )
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+ prompt=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=clip_skip,
+ clip_model_index=1,
+ )
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
+
+ t5_prompt_embed = self._get_t5_prompt_embeds(
+ prompt=prompt_3,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ )
+
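+            # zero-pad the concatenated CLIP embeddings up to the T5 hidden size so both can be joined along the sequence dim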
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+ negative_prompt_3 = (
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
+ negative_prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=0,
+ )
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+ negative_prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=1,
+ )
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
+
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
+ prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
+ )
+
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
+ negative_clip_prompt_embeds,
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
+ )
+
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
+ negative_pooled_prompt_embeds = torch.cat(
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
+ )
+ prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
+ prompt_embeds.to(self._device),
+ negative_prompt_embeds.to(self._device),
+ pooled_prompt_embeds.to(self._device),
+ negative_pooled_prompt_embeds.to(self._device),
+ )
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ negative_prompt_3=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_3 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ frames,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+ # 1, 60, 16, 32, 32
+ shape = (
+ batch_size,
+ frames,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @autocast("cuda", dtype=torch.bfloat16)
+ def generate(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: int = 288,
+ width: int = 480,
+ frames: int = 40,
+ num_inference_steps: int = 100,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ seed: int = -1,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+                will be used instead.
+            prompt_3 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
+                will be used instead.
+            height (`int`, *optional*, defaults to 288):
+                The height in pixels of the generated video.
+            width (`int`, *optional*, defaults to 480):
+                The width in pixels of the generated video.
+            num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`VideoSysPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference, with the following
+ arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
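+ A minimal usage sketch (hypothetical: checkpoint loading is omitted and `pipe` is assumed to be an
+ already constructed instance of this pipeline on GPU):
+
+ ```py
+ >>> # `pipe` construction omitted; assumed to be an instance of this pipeline
+ >>> output = pipe(prompt="a corgi running on the beach", num_inference_steps=50)
+ >>> video = output.video[0]
+ >>> pipe.save_video(video, "./output.mp4")
+ ```
+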
+ Returns:
+ [`VideoSysPipelineOutput`] or `tuple`:
+ [`VideoSysPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`. When returning a tuple,
+ the first element is a list with the generated videos.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ frames = frames or 24
+ self._set_seed(seed)
+ update_steps(num_inference_steps)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
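+ # 3. Encode input prompt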
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=self.text_encoder.device,
+ clip_skip=self.clip_skip,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, self._device, timesteps
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ frames,
+ prompt_embeds.dtype,
+ self._device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ # with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0])
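+ # with classifier-free guidance enabled, batch index 0 holds the unconditional inputs and
+ # index 1 the text-conditioned ones; the two branches are run as separate forward passes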
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input[0, :].unsqueeze(0),
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds[0, :].unsqueeze(0),
+ pooled_projections=pooled_prompt_embeds[0, :].unsqueeze(0),
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred_text = self.transformer(
+ hidden_states=latent_model_input[1, :].unsqueeze(0),
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds[1, :].unsqueeze(0),
+ pooled_projections=pooled_prompt_embeds[1, :].unsqueeze(0),
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
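+ # rescale the guidance weight on the fly with a cosine-shaped schedule driven by the current timestep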
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+
+ # call the callback, if provided
+ # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ # progress_bar.update()
+
+ # if output_type == "latent":
+ # image = latents
+
+ # else:
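+ # map the latents back to the VAE's latent range and decode them frame by frame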
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ videos = []
+ for v_idx in range(latents.shape[1]):
+ image = self.vae.decode(latents[:, v_idx], return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ videos.append(image[0])
+ videos = [videos]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (videos,)
+
+ return VideoSysPipelineOutput(video=videos)
+
+ def save_video(self, video, output_path):
+ save_video(video, output_path, fps=8)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigma schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
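+
+
+# A minimal sketch of how `retrieve_timesteps` is typically called. The scheduler class below is an
+# assumption for illustration; any scheduler whose `set_timesteps` accepts the chosen arguments works
+# the same way.
+#
+#   from diffusers import FlowMatchEulerDiscreteScheduler
+#
+#   scheduler = FlowMatchEulerDiscreteScheduler()
+#   # default spacing: just pass the number of steps
+#   timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cuda")
+#   # custom spacing: pass `timesteps` or `sigmas` instead, if the scheduler supports them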
diff --git a/videosys/utils/__init__.py b/videosys/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/videosys/utils/test.py b/videosys/utils/test.py
new file mode 100644
index 0000000..aec6f7b
--- /dev/null
+++ b/videosys/utils/test.py
@@ -0,0 +1,12 @@
+import functools
+
+import torch
+
+
+def empty_cache(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ torch.cuda.empty_cache()
+ return func(*args, **kwargs)
+
+ return wrapper
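+
+
+# Usage sketch (the decorated function below is hypothetical): emptying the CUDA cache before each call
+# can reduce memory fragmentation when generation is run repeatedly.
+#
+#   @empty_cache
+#   def generate(pipe, prompt):
+#       return pipe(prompt)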
diff --git a/videosys/utils/utils.py b/videosys/utils/utils.py
index 868e37e..622a36d 100644
--- a/videosys/utils/utils.py
+++ b/videosys/utils/utils.py
@@ -16,7 +16,17 @@ def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
p.requires_grad = flag
-def set_seed(seed):
+def set_seed(seed, dp_rank=None):
+ if seed == -1:
+ seed = random.randint(0, 1000000)
+
+ if dp_rank is not None:
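+ # broadcast the base seed from rank 0 so all ranks share it, then offset it by the
+ # data-parallel rank so each DP rank gets a distinct seed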
+ seed = torch.tensor(seed, dtype=torch.int64).cuda()
+ if dist.get_world_size() > 1:
+ dist.broadcast(seed, 0)
+ seed = seed + dp_rank
+
+ seed = int(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)