From 21ea48e71a884cf9326f34c187a3704f3422aa09 Mon Sep 17 00:00:00 2001 From: LiewFeng <1361871897@qq.com> Date: Mon, 16 Dec 2024 17:27:04 +0800 Subject: [PATCH] update the version of videosys --- eval/teacache/README.md | 30 + eval/teacache/common_metrics/batch_eval.py | 205 ++ eval/teacache/experiments/latte.py | 52 +- eval/teacache/experiments/opensora.py | 64 +- eval/teacache/experiments/opensora_plan.py | 61 +- requirements.txt | 8 +- setup.py | 38 +- videosys/__init__.py | 13 +- videosys/core/comm.py | 42 +- videosys/core/engine.py | 2 +- videosys/core/pab_mgr.py | 11 +- videosys/core/parallel_mgr.py | 93 +- videosys/core/pipeline.py | 7 +- .../autoencoders/autoencoder_kl_open_sora.py | 5 +- .../autoencoder_kl_open_sora_plan_v110.py | 1643 ++++++++++ .../autoencoder_kl_open_sora_plan_v120.py | 1139 +++++++ videosys/models/modules/attentions.py | 644 +++- videosys/models/modules/normalization.py | 30 + .../transformers/cogvideox_transformer_3d.py | 188 +- .../transformers/latte_transformer_3d.py | 77 +- .../open_sora_plan_v110_transformer_3d.py | 2826 +++++++++++++++++ .../open_sora_plan_v120_transformer_3d.py | 2183 +++++++++++++ .../transformers/open_sora_transformer_3d.py | 83 +- .../transformers/vchitect_transformer_3d.py | 644 ++++ .../pipelines/cogvideox/pipeline_cogvideox.py | 52 +- videosys/pipelines/latte/pipeline_latte.py | 38 +- .../pipelines/open_sora/pipeline_open_sora.py | 58 +- videosys/pipelines/open_sora_plan/__init__.py | 9 +- .../open_sora_plan/pipeline_open_sora_plan.py | 473 ++- videosys/pipelines/vchitect/__init__.py | 3 + .../pipelines/vchitect/pipeline_vchitect.py | 1057 ++++++ videosys/utils/__init__.py | 0 videosys/utils/test.py | 12 + videosys/utils/utils.py | 12 +- 34 files changed, 11296 insertions(+), 506 deletions(-) create mode 100644 eval/teacache/README.md create mode 100644 eval/teacache/common_metrics/batch_eval.py create mode 100644 videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v110.py create mode 100644 videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v120.py create mode 100644 videosys/models/transformers/open_sora_plan_v110_transformer_3d.py create mode 100644 videosys/models/transformers/open_sora_plan_v120_transformer_3d.py create mode 100644 videosys/models/transformers/vchitect_transformer_3d.py create mode 100644 videosys/pipelines/vchitect/__init__.py create mode 100644 videosys/pipelines/vchitect/pipeline_vchitect.py create mode 100644 videosys/utils/__init__.py create mode 100644 videosys/utils/test.py diff --git a/eval/teacache/README.md b/eval/teacache/README.md new file mode 100644 index 0000000..2fc8e65 --- /dev/null +++ b/eval/teacache/README.md @@ -0,0 +1,30 @@ +# Evaluation of TeaCache + +We first generate videos according to VBench's prompts. + +And then calculate Vbench, PSNR, LPIPS and SSIM based on the video generated. + +1. Generate video +``` +cd eval/teacache +python experiments/latte.py +python experiments/opensora.py +python experiments/open_sora_plan.py +``` + +2. Calculate Vbench score +``` +# vbench is calculated independently +# get scores for all metrics +python vbench/run_vbench.py --video_path aaa --save_path bbb +# calculate final score +python vbench/cal_vbench.py --score_dir bbb +``` + +3. Calculate other metrics +``` +# these metrics are calculated compared with original model +# gt video is the video of original model +# generated video is our methods's results +python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb +``` diff --git a/eval/teacache/common_metrics/batch_eval.py b/eval/teacache/common_metrics/batch_eval.py new file mode 100644 index 0000000..1696235 --- /dev/null +++ b/eval/teacache/common_metrics/batch_eval.py @@ -0,0 +1,205 @@ +import argparse +import os + +import imageio +import torch +import torchvision.transforms.functional as F +import tqdm +from calculate_lpips import calculate_lpips +from calculate_psnr import calculate_psnr +from calculate_ssim import calculate_ssim + + +def load_video(video_path): + """ + Load a video from the given path and convert it to a PyTorch tensor. + """ + # Read the video using imageio + reader = imageio.get_reader(video_path, "ffmpeg") + + # Extract frames and convert to a list of tensors + frames = [] + for frame in reader: + # Convert the frame to a tensor and permute the dimensions to match (C, H, W) + frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1) + frames.append(frame_tensor) + + # Stack the list of tensors into a single tensor with shape (T, C, H, W) + video_tensor = torch.stack(frames) + + return video_tensor + + +def resize_video(video, target_height, target_width): + resized_frames = [] + for frame in video: + resized_frame = F.resize(frame, [target_height, target_width]) + resized_frames.append(resized_frame) + return torch.stack(resized_frames) + + +def resize_gt_video(gt_video, gen_video): + gen_video_shape = gen_video.shape + T_gen, _, H_gen, W_gen = gen_video_shape + T_eval, _, H_eval, W_eval = gt_video.shape + + if T_eval < T_gen: + raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).") + + if H_eval < H_gen or W_eval < W_gen: + # Resize the video maintaining the aspect ratio + resize_height = max(H_gen, int(H_gen * (H_eval / W_eval))) + resize_width = max(W_gen, int(W_gen * (W_eval / H_eval))) + gt_video = resize_video(gt_video, resize_height, resize_width) + # Recalculate the dimensions + T_eval, _, H_eval, W_eval = gt_video.shape + + # Center crop + start_h = (H_eval - H_gen) // 2 + start_w = (W_eval - W_gen) // 2 + cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen] + + return cropped_video + + +def get_video_ids(gt_video_dirs, gen_video_dirs): + video_ids = [] + for f in os.listdir(gt_video_dirs[0]): + if f.endswith(f".mp4"): + video_ids.append(f.replace(f".mp4", "")) + video_ids.sort() + + for video_dir in gt_video_dirs + gen_video_dirs: + tmp_video_ids = [] + for f in os.listdir(video_dir): + if f.endswith(f".mp4"): + tmp_video_ids.append(f.replace(f".mp4", "")) + tmp_video_ids.sort() + if tmp_video_ids != video_ids: + raise ValueError(f"Video IDs in {video_dir} are different.") + return video_ids + + +def get_videos(video_ids, gt_video_dirs, gen_video_dirs): + gt_videos = {} + generated_videos = {} + + for gt_video_dir in gt_video_dirs: + tmp_gt_videos_tensor = [] + for video_id in video_ids: + gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4")) + tmp_gt_videos_tensor.append(gt_video) + gt_videos[gt_video_dir] = tmp_gt_videos_tensor + + for generated_video_dir in gen_video_dirs: + tmp_generated_videos_tensor = [] + for video_id in video_ids: + generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4")) + tmp_generated_videos_tensor.append(generated_video) + generated_videos[generated_video_dir] = tmp_generated_videos_tensor + + return gt_videos, generated_videos + + +def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs): + out_str = "" + + for gt_video_dir in gt_video_dirs: + for generated_video_dir in gen_video_dirs: + if gt_video_dir == generated_video_dir: + continue + lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len( + lpips_results[gt_video_dir][generated_video_dir] + ) + psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len( + psnr_results[gt_video_dir][generated_video_dir] + ) + ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len( + ssim_results[gt_video_dir][generated_video_dir] + ) + out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}" + + return out_str + + +def main(args): + device = "cuda" + gt_video_dirs = args.gt_video_dirs + gen_video_dirs = args.gen_video_dirs + + video_ids = get_video_ids(gt_video_dirs, gen_video_dirs) + print(f"Find {len(video_ids)} videos") + + prompt_interval = 1 + batch_size = 8 + calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True + + lpips_results = {} + psnr_results = {} + ssim_results = {} + for gt_video_dir in gt_video_dirs: + lpips_results[gt_video_dir] = {} + psnr_results[gt_video_dir] = {} + ssim_results[gt_video_dir] = {} + for generated_video_dir in gen_video_dirs: + lpips_results[gt_video_dir][generated_video_dir] = [] + psnr_results[gt_video_dir][generated_video_dir] = [] + ssim_results[gt_video_dir][generated_video_dir] = [] + + total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0) + + for idx in tqdm.tqdm(range(total_len)): + video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size] + gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs) + + for gt_video_dir, gt_videos_tensor in gt_videos.items(): + for generated_video_dir, generated_videos_tensor in generated_videos.items(): + if gt_video_dir == generated_video_dir: + continue + + if not isinstance(gt_videos_tensor, torch.Tensor): + for i in range(len(gt_videos_tensor)): + gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0]) + gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu() + + generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu() + + if calculate_lpips_flag: + result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device) + result = result["value"].values() + result = float(sum(result) / len(result)) + lpips_results[gt_video_dir][generated_video_dir].append(result) + + if calculate_psnr_flag: + result = calculate_psnr(gt_videos_tensor, generated_videos_tensor) + result = result["value"].values() + result = float(sum(result) / len(result)) + psnr_results[gt_video_dir][generated_video_dir].append(result) + + if calculate_ssim_flag: + result = calculate_ssim(gt_videos_tensor, generated_videos_tensor) + result = result["value"].values() + result = float(sum(result) / len(result)) + ssim_results[gt_video_dir][generated_video_dir].append(result) + + if (idx + 1) % prompt_interval == 0: + out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs) + print(f"Processed {idx + 1} / {total_len} videos. {out_str}") + + out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs) + + # save + with open(f"./batch_eval.txt", "w+") as f: + f.write(out_str) + + print(f"Processed all videos. {out_str}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--gt_video_dirs", type=str, nargs="+") + parser.add_argument("--gen_video_dirs", type=str, nargs="+") + + args = parser.parse_args() + + main(args) diff --git a/eval/teacache/experiments/latte.py b/eval/teacache/experiments/latte.py index 962da93..fbd6498 100644 --- a/eval/teacache/experiments/latte.py +++ b/eval/teacache/experiments/latte.py @@ -5,12 +5,7 @@ from einops import rearrange, repeat from torch import nn import numpy as np from typing import Any, Dict, Optional, Tuple -from videosys.core.parallel_mgr import ( - enable_sequence_parallel, - get_cfg_parallel_size, - get_data_parallel_group, - get_sequence_parallel_group, -) +from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence def teacache_forward( self, @@ -67,7 +62,7 @@ def teacache_forward( """ # 0. Split batch for data parallelism - if get_cfg_parallel_size() > 1: + if self.parallel_manager.cp_size > 1: ( hidden_states, timestep, @@ -77,7 +72,7 @@ def teacache_forward( attention_mask, encoder_attention_mask, ) = batch_func( - partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0), + partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), hidden_states, timestep, encoder_hidden_states, @@ -193,14 +188,14 @@ def teacache_forward( if not should_calc: hidden_states += self.previous_residual else: - if enable_sequence_parallel(): - set_temporal_pad(frame + use_image_num) - set_spatial_pad(num_patches) + if self.parallel_manager.sp_size > 1: + set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group) + set_pad("spatial", num_patches, self.parallel_manager.sp_group) hidden_states = self.split_from_second_dim(hidden_states, input_batch_size) encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size) timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size) temp_pos_embed = split_sequence( - self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad() + self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") ) else: temp_pos_embed = self.temp_pos_embed @@ -323,14 +318,14 @@ def teacache_forward( ).contiguous() self.previous_residual = hidden_states - hidden_states_origin else: - if enable_sequence_parallel(): - set_temporal_pad(frame + use_image_num) - set_spatial_pad(num_patches) + if self.parallel_manager.sp_size > 1: + set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group) + set_pad("spatial", num_patches, self.parallel_manager.sp_group) hidden_states = self.split_from_second_dim(hidden_states, input_batch_size) encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size) timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size) temp_pos_embed = split_sequence( - self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad() + self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") ) else: temp_pos_embed = self.temp_pos_embed @@ -451,7 +446,8 @@ def teacache_forward( hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size ).contiguous() - if enable_sequence_parallel(): + + if self.parallel_manager.sp_size > 1: if self.enable_teacache: if should_calc: hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size) @@ -488,8 +484,8 @@ def teacache_forward( output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous() # 3. Gather batch for data parallelism - if get_cfg_parallel_size() > 1: - output = gather_sequence(output, get_cfg_parallel_group(), dim=0) + if self.parallel_manager.cp_size > 1: + output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) if not return_dict: return (output,) @@ -497,10 +493,6 @@ def teacache_forward( return Transformer3DModelOutput(sample=output) -def eval_base(prompt_list): - config = LatteConfig() - engine = VideoSysEngine(config) - generate_func(engine, prompt_list, "./samples/latte_base", loop=5) def eval_teacache_slow(prompt_list): config = LatteConfig() @@ -523,10 +515,18 @@ def eval_teacache_fast(prompt_list): engine.driver_worker.transformer.previous_residual = None engine.driver_worker.transformer.__class__.forward = teacache_forward generate_func(engine, prompt_list, "./samples/latte_teacache_fast", loop=5) + + + +def eval_base(prompt_list): + config = LatteConfig() + engine = VideoSysEngine(config) + generate_func(engine, prompt_list, "./samples/latte_base", loop=5) + + if __name__ == "__main__": prompt_list = read_prompt_list("vbench/VBench_full_info.json") - # eval_base(prompt_list) + eval_base(prompt_list) eval_teacache_slow(prompt_list) - # eval_teacache_fast(prompt_list) - + eval_teacache_fast(prompt_list) \ No newline at end of file diff --git a/eval/teacache/experiments/opensora.py b/eval/teacache/experiments/opensora.py index a001c5a..3ef72fd 100644 --- a/eval/teacache/experiments/opensora.py +++ b/eval/teacache/experiments/opensora.py @@ -3,23 +3,23 @@ from videosys import OpenSoraConfig, VideoSysEngine import torch from einops import rearrange from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint -from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_temporal_pad, set_spatial_pad, set_temporal_pad, split_sequence +from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence import numpy as np from videosys.utils.utils import batch_func -from videosys.core.parallel_mgr import ( - enable_sequence_parallel, - get_cfg_parallel_size, - get_data_parallel_group, - get_sequence_parallel_group, -) +from functools import partial def teacache_forward( self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs ): # === Split batch === - if get_cfg_parallel_size() > 1: + if self.parallel_manager.cp_size > 1: x, timestep, y, x_mask, mask = batch_func( - partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask + partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), + x, + timestep, + y, + x_mask, + mask, ) dtype = self.x_embedder.proj.weight.dtype @@ -60,7 +60,7 @@ def teacache_forward( # === get x embed === x = self.x_embedder(x) # [B, N, C] x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) - x = x + pos_emb + x = x + pos_emb if self.enable_teacache: inp = x.clone() @@ -84,7 +84,6 @@ def teacache_forward( self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = modulated_inp - # === blocks === if self.enable_teacache: if not should_calc: @@ -92,16 +91,15 @@ def teacache_forward( x += self.previous_residual else: # shard over the sequence dim if sp is enabled - if enable_sequence_parallel(): - set_temporal_pad(T) - set_spatial_pad(S) - x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()) + if self.parallel_manager.sp_size > 1: + set_pad("temporal", T, self.parallel_manager.sp_group) + set_pad("spatial", S, self.parallel_manager.sp_group) + x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")) T = x.shape[1] x_mask_org = x_mask x_mask = split_sequence( - x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad() + x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") ) - x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) origin_x = x.clone().detach() for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks): @@ -135,16 +133,17 @@ def teacache_forward( self.previous_residual = x - origin_x else: # shard over the sequence dim if sp is enabled - if enable_sequence_parallel(): - set_temporal_pad(T) - set_spatial_pad(S) - x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()) + if self.parallel_manager.sp_size > 1: + set_pad("temporal", T, self.parallel_manager.sp_group) + set_pad("spatial", S, self.parallel_manager.sp_group) + x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")) T = x.shape[1] x_mask_org = x_mask x_mask = split_sequence( - x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad() + x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") ) x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) + for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks): x = auto_grad_checkpoint( spatial_block, @@ -174,25 +173,23 @@ def teacache_forward( all_timesteps=all_timesteps, ) - if enable_sequence_parallel(): + if self.parallel_manager.sp_size > 1: if self.enable_teacache: if should_calc: x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S) - x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal")) - self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal")) + x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")) + self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")) T, S = x.shape[1], x.shape[2] x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S) x_mask = x_mask_org else: x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) - x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal")) + x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")) T, S = x.shape[1], x.shape[2] x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) x_mask = x_mask_org - - # === final layer === x = self.final_layer(x, t, x_mask, t0, T, S) x = self.unpatchify(x, T, H, W, Tx, Hx, Wx) @@ -201,12 +198,11 @@ def teacache_forward( x = x.to(torch.float32) # === Gather Output === - if get_cfg_parallel_size() > 1: - x = gather_sequence(x, get_data_parallel_group(), dim=0) + if self.parallel_manager.cp_size > 1: + x = gather_sequence(x, self.parallel_manager.cp_group, dim=0) return x - def eval_base(prompt_list): config = OpenSoraConfig() engine = VideoSysEngine(config) @@ -235,9 +231,9 @@ def eval_teacache_fast(prompt_list): generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5) - if __name__ == "__main__": prompt_list = read_prompt_list("vbench/VBench_full_info.json") - # eval_base(prompt_list) + eval_base(prompt_list) eval_teacache_slow(prompt_list) - # eval_teacache_fast(prompt_list) + eval_teacache_fast(prompt_list) + \ No newline at end of file diff --git a/eval/teacache/experiments/opensora_plan.py b/eval/teacache/experiments/opensora_plan.py index 377f307..b056687 100644 --- a/eval/teacache/experiments/opensora_plan.py +++ b/eval/teacache/experiments/opensora_plan.py @@ -5,12 +5,7 @@ import torch.nn.functional as F from einops import rearrange, repeat import numpy as np from typing import Any, Dict, Optional, Tuple -from videosys.core.parallel_mgr import ( - enable_sequence_parallel, - get_cfg_parallel_group, - get_cfg_parallel_size, - get_sequence_parallel_group, -) +from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence def teacache_forward( self, @@ -66,7 +61,7 @@ def teacache_forward( `tuple` where the first element is the sample tensor. """ # 0. Split batch - if get_cfg_parallel_size() > 1: + if self.parallel_manager.cp_size > 1: ( hidden_states, timestep, @@ -75,7 +70,7 @@ def teacache_forward( attention_mask, encoder_attention_mask, ) = batch_func( - partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0), + partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), hidden_states, timestep, encoder_hidden_states, @@ -204,20 +199,20 @@ def teacache_forward( if not should_calc: hidden_states += self.previous_residual else: - if enable_sequence_parallel(): - set_temporal_pad(frame + use_image_num) - set_spatial_pad(num_patches) + if self.parallel_manager.sp_size > 1: + set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group) + set_pad("spatial", num_patches, self.parallel_manager.sp_group) hidden_states = self.split_from_second_dim(hidden_states, input_batch_size) encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size) timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size) attention_mask = self.split_from_second_dim(attention_mask, input_batch_size) attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size) temp_pos_embed = split_sequence( - self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad() + self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") ) else: temp_pos_embed = self.temp_pos_embed - ori_hidden_states = hidden_states.clone() + ori_hidden_states = hidden_states.clone().detach() for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)): if self.training and self.gradient_checkpointing: hidden_states = torch.utils.checkpoint.checkpoint( @@ -359,20 +354,20 @@ def teacache_forward( ).contiguous() self.previous_residual = hidden_states - ori_hidden_states else: - if enable_sequence_parallel(): - set_temporal_pad(frame + use_image_num) - set_spatial_pad(num_patches) + if self.parallel_manager.sp_size > 1: + set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group) + set_pad("spatial", num_patches, self.parallel_manager.sp_group) hidden_states = self.split_from_second_dim(hidden_states, input_batch_size) + self.previous_residual = self.split_from_second_dim(self.previous_residual, input_batch_size) if self.previous_residual is not None else None encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size) timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size) attention_mask = self.split_from_second_dim(attention_mask, input_batch_size) attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size) temp_pos_embed = split_sequence( - self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad() + self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") ) else: temp_pos_embed = self.temp_pos_embed - for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)): if self.training and self.gradient_checkpointing: hidden_states = torch.utils.checkpoint.checkpoint( @@ -512,8 +507,8 @@ def teacache_forward( hidden_states = rearrange( hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size ).contiguous() - - if enable_sequence_parallel(): + + if self.parallel_manager.sp_size > 1: if self.enable_teacache: if should_calc: hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size) @@ -550,8 +545,8 @@ def teacache_forward( output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous() # 3. Gather batch for data parallelism - if get_cfg_parallel_size() > 1: - output = gather_sequence(output, get_cfg_parallel_group(), dim=0) + if self.parallel_manager.cp_size > 1: + output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) if not return_dict: return (output,) @@ -559,14 +554,8 @@ def teacache_forward( return Transformer3DModelOutput(sample=output) - -def eval_base(prompt_list): - config = OpenSoraPlanConfig() - engine = VideoSysEngine(config) - generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5) - def eval_teacache_slow(prompt_list): - config = OpenSoraPlanConfig() + config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512") engine = VideoSysEngine(config) engine.driver_worker.transformer.enable_teacache = True engine.driver_worker.transformer.rel_l1_thresh = 0.1 @@ -577,7 +566,7 @@ def eval_teacache_slow(prompt_list): generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5) def eval_teacache_fast(prompt_list): - config = OpenSoraPlanConfig() + config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512") engine = VideoSysEngine(config) engine.driver_worker.transformer.enable_teacache = True engine.driver_worker.transformer.rel_l1_thresh = 0.2 @@ -587,8 +576,16 @@ def eval_teacache_fast(prompt_list): engine.driver_worker.transformer.__class__.forward = teacache_forward generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5) + +def eval_base(prompt_list): + config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", ) + engine = VideoSysEngine(config) + generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5) + + if __name__ == "__main__": prompt_list = read_prompt_list("vbench/VBench_full_info.json") - # eval_base(prompt_list) + eval_base(prompt_list) eval_teacache_slow(prompt_list) - # eval_teacache_fast(prompt_list) + eval_teacache_fast(prompt_list) + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ea6d973..3b8a98d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ +accelerate>0.17.0 +bs4 click -colossalai +colossalai==0.4.0 diffusers==0.30.0 einops fabric @@ -16,7 +18,9 @@ pydantic ray rich safetensors +sentencepiece timm torch>=1.13 tqdm -transformers +peft==0.13.2 +transformers==4.39.3 diff --git a/setup.py b/setup.py index 50b9537..9a2ed05 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,9 @@ from typing import List from setuptools import find_packages, setup +from setuptools.command.develop import develop +from setuptools.command.egg_info import egg_info +from setuptools.command.install import install def fetch_requirements(path) -> List[str]: @@ -14,7 +17,9 @@ def fetch_requirements(path) -> List[str]: The lines in the requirements file. """ with open(path, "r") as fd: - return [r.strip() for r in fd.readlines()] + requirements = [r.strip() for r in fd.readlines()] + # requirements.remove("colossalai") + return requirements def fetch_readme() -> str: @@ -28,6 +33,28 @@ def fetch_readme() -> str: return f.read() +def custom_install(): + return ["pip", "install", "colossalai", "--no-deps"] + + +class CustomInstallCommand(install): + def run(self): + install.run(self) + self.spawn(custom_install()) + + +class CustomDevelopCommand(develop): + def run(self): + develop.run(self) + self.spawn(custom_install()) + + +class CustomEggInfoCommand(egg_info): + def run(self): + egg_info.run(self) + self.spawn(custom_install()) + + setup( name="videosys", version="2.0.0", @@ -39,12 +66,17 @@ setup( "*.egg-info", ) ), - description="VideoSys", + description="TeaCache", long_description=fetch_readme(), long_description_content_type="text/markdown", license="Apache Software License 2.0", install_requires=fetch_requirements("requirements.txt"), - python_requires=">=3.6", + python_requires=">=3.7", + # cmdclass={ + # "install": CustomInstallCommand, + # "develop": CustomDevelopCommand, + # "egg_info": CustomEggInfoCommand, + # }, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", diff --git a/videosys/__init__.py b/videosys/__init__.py index 859fb7c..6c539c0 100644 --- a/videosys/__init__.py +++ b/videosys/__init__.py @@ -3,13 +3,20 @@ from .core.parallel_mgr import initialize from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline -from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline +from .pipelines.open_sora_plan import ( + OpenSoraPlanConfig, + OpenSoraPlanPipeline, + OpenSoraPlanV110PABConfig, + OpenSoraPlanV120PABConfig, +) +from .pipelines.vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline __all__ = [ "initialize", "VideoSysEngine", "LattePipeline", "LatteConfig", "LattePABConfig", - "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig", + "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig", "OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig", - "CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig" + "CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig", + "VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig" ] # fmt: skip diff --git a/videosys/core/comm.py b/videosys/core/comm.py index 175fba5..75c242b 100644 --- a/videosys/core/comm.py +++ b/videosys/core/comm.py @@ -7,8 +7,6 @@ from einops import rearrange from torch import Tensor from torch.distributed import ProcessGroup -from videosys.core.parallel_mgr import get_sequence_parallel_size - # ====================================================== # Model # ====================================================== @@ -369,30 +367,18 @@ def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0): # Pad # ============================== -SPTIAL_PAD = 0 -TEMPORAL_PAD = 0 +PAD_DICT = {} -def set_spatial_pad(dim_size: int): - sp_size = get_sequence_parallel_size() +def set_pad(name: str, dim_size: int, parallel_group: dist.ProcessGroup): + sp_size = dist.get_world_size(parallel_group) pad = (sp_size - (dim_size % sp_size)) % sp_size - global SPTIAL_PAD - SPTIAL_PAD = pad + global PAD_DICT + PAD_DICT[name] = pad -def get_spatial_pad() -> int: - return SPTIAL_PAD - - -def set_temporal_pad(dim_size: int): - sp_size = get_sequence_parallel_size() - pad = (sp_size - (dim_size % sp_size)) % sp_size - global TEMPORAL_PAD - TEMPORAL_PAD = pad - - -def get_temporal_pad() -> int: - return TEMPORAL_PAD +def get_pad(name) -> int: + return PAD_DICT[name] def all_to_all_with_pad( @@ -418,3 +404,17 @@ def all_to_all_with_pad( input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad) return input_ + + +def split_from_second_dim(x, batch_size, parallel_group): + x = x.view(batch_size, -1, *x.shape[1:]) + x = split_sequence(x, parallel_group, dim=1, grad_scale="down", pad=get_pad("temporal")) + x = x.reshape(-1, *x.shape[2:]) + return x + + +def gather_from_second_dim(x, batch_size, parallel_group): + x = x.view(batch_size, -1, *x.shape[1:]) + x = gather_sequence(x, parallel_group, dim=1, grad_scale="up", pad=get_pad("temporal")) + x = x.reshape(-1, *x.shape[2:]) + return x diff --git a/videosys/core/engine.py b/videosys/core/engine.py index 6d1868f..086b193 100644 --- a/videosys/core/engine.py +++ b/videosys/core/engine.py @@ -66,7 +66,7 @@ class VideoSysEngine: # TODO: add more options here for pipeline, or wrap all options into config def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None): - videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method, seed=42) + videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method) pipeline = pipeline_cls(self.config) return pipeline diff --git a/videosys/core/pab_mgr.py b/videosys/core/pab_mgr.py index ecd387b..a582105 100644 --- a/videosys/core/pab_mgr.py +++ b/videosys/core/pab_mgr.py @@ -6,7 +6,6 @@ PAB_MANAGER = None class PABConfig: def __init__( self, - steps: int, cross_broadcast: bool = False, cross_threshold: list = None, cross_range: int = None, @@ -20,7 +19,7 @@ class PABConfig: mlp_spatial_broadcast_config: dict = None, mlp_temporal_broadcast_config: dict = None, ): - self.steps = steps + self.steps = None self.cross_broadcast = cross_broadcast self.cross_threshold = cross_threshold @@ -45,7 +44,7 @@ class PABManager: def __init__(self, config: PABConfig): self.config: PABConfig = config - init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}." + init_prompt = f"Init Pyramid Attention Broadcast." init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}." init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}." init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}." @@ -78,7 +77,7 @@ class PABManager: count = (count + 1) % self.config.steps return flag, count - def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int): + def if_broadcast_spatial(self, timestep: int, count: int): if ( self.config.spatial_broadcast and (timestep is not None) @@ -213,10 +212,10 @@ def if_broadcast_temporal(timestep: int, count: int): return PAB_MANAGER.if_broadcast_temporal(timestep, count) -def if_broadcast_spatial(timestep: int, count: int, block_idx: int): +def if_broadcast_spatial(timestep: int, count: int): if not enable_pab(): return False, count - return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx) + return PAB_MANAGER.if_broadcast_spatial(timestep, count) def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False): diff --git a/videosys/core/parallel_mgr.py b/videosys/core/parallel_mgr.py index fbb9ccf..2036930 100644 --- a/videosys/core/parallel_mgr.py +++ b/videosys/core/parallel_mgr.py @@ -1,14 +1,9 @@ -from typing import Optional - import torch import torch.distributed as dist from colossalai.cluster.process_group_mesh import ProcessGroupMesh from torch.distributed import ProcessGroup from videosys.utils.logging import init_dist_logger, logger -from videosys.utils.utils import set_seed - -PARALLEL_MANAGER = None class ParallelManager(ProcessGroupMesh): @@ -21,71 +16,28 @@ class ParallelManager(ProcessGroupMesh): self.dp_rank = dist.get_rank(self.dp_group) self.cp_size = cp_size - self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis) - self.cp_rank = dist.get_rank(self.cp_group) + if cp_size > 1: + self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis) + self.cp_rank = dist.get_rank(self.cp_group) + else: + self.cp_group = None + self.cp_rank = None self.sp_size = sp_size - self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis) - self.sp_rank = dist.get_rank(self.sp_group) - self.enable_sp = sp_size > 1 + if sp_size > 1: + self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis) + self.sp_rank = dist.get_rank(self.sp_group) + else: + self.sp_group = None + self.sp_rank = None logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}") -def set_parallel_manager(dp_size, cp_size, sp_size): - global PARALLEL_MANAGER - PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size) - - -def get_data_parallel_group(): - return PARALLEL_MANAGER.dp_group - - -def get_data_parallel_size(): - return PARALLEL_MANAGER.dp_size - - -def get_data_parallel_rank(): - return PARALLEL_MANAGER.dp_rank - - -def get_sequence_parallel_group(): - return PARALLEL_MANAGER.sp_group - - -def get_sequence_parallel_size(): - return PARALLEL_MANAGER.sp_size - - -def get_sequence_parallel_rank(): - return PARALLEL_MANAGER.sp_rank - - -def get_cfg_parallel_group(): - return PARALLEL_MANAGER.cp_group - - -def get_cfg_parallel_size(): - return PARALLEL_MANAGER.cp_size - - -def enable_sequence_parallel(): - if PARALLEL_MANAGER is None: - return False - return PARALLEL_MANAGER.enable_sp - - -def get_parallel_manager(): - return PARALLEL_MANAGER - - def initialize( rank=0, world_size=1, init_method=None, - seed: Optional[int] = None, - sp_size: Optional[int] = None, - enable_cp: bool = False, ): if not dist.is_initialized(): try: @@ -97,24 +49,3 @@ def initialize( init_dist_logger() torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - - # init sequence parallel - if sp_size is None: - sp_size = dist.get_world_size() - dp_size = 1 - else: - assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size" - dp_size = dist.get_world_size() // sp_size - - # update cfg parallel - # NOTE: enable cp parallel will be slower. disable it for now. - if False and enable_cp and sp_size % 2 == 0: - sp_size = sp_size // 2 - cp_size = 2 - else: - cp_size = 1 - - set_parallel_manager(dp_size, cp_size, sp_size) - - if seed is not None: - set_seed(seed + get_data_parallel_rank()) diff --git a/videosys/core/pipeline.py b/videosys/core/pipeline.py index 75b79d3..0aafb96 100644 --- a/videosys/core/pipeline.py +++ b/videosys/core/pipeline.py @@ -13,9 +13,10 @@ class VideoSysPipeline(DiffusionPipeline): @staticmethod def set_eval_and_device(device: torch.device, *modules): - for module in modules: - module.eval() - module.to(device) + modules = list(modules) + for i in range(len(modules)): + modules[i] = modules[i].eval() + modules[i] = modules[i].to(device) @abstractmethod def generate(self, *args, **kwargs): diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora.py b/videosys/models/autoencoders/autoencoder_kl_open_sora.py index d073fe4..9a69d80 100644 --- a/videosys/models/autoencoders/autoencoder_kl_open_sora.py +++ b/videosys/models/autoencoders/autoencoder_kl_open_sora.py @@ -694,7 +694,10 @@ class VideoAutoencoderPipeline(PreTrainedModel): else: return x - def forward(self, x): + def forward(self, x, decode_only=False, **kwargs): + if decode_only: + return self.decode(x, **kwargs) + assert self.cal_loss, "This method is only available when cal_loss is True" z, posterior, x_z = self.encode(x) x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2]) diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v110.py b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v110.py new file mode 100644 index 0000000..30338fc --- /dev/null +++ b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v110.py @@ -0,0 +1,1643 @@ +# Adapted from Open-Sora-Plan + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan +# -------------------------------------------------------- +import glob +import os +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging +from einops import rearrange +from torch import nn + +logging.set_verbosity_error() + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +def tensor_to_video(x): + x = x.detach().cpu() + x = torch.clamp(x, -1, 1) + x = (x + 1) / 2 + x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> + x = (255 * x).astype(np.uint8) + return x + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +def resolve_str_to_obj(str_val, append=True): + return globals()[str_val] + + +class VideoBaseAE_PL(ModelMixin, ConfigMixin): + config_name = "config.json" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def encode(self, x: torch.Tensor, *args, **kwargs): + pass + + def decode(self, encoding: torch.Tensor, *args, **kwargs): + pass + + @property + def num_training_steps(self) -> int: + """Total training steps inferred from datamodule and devices.""" + if self.trainer.max_steps: + return self.trainer.max_steps + + limit_batches = self.trainer.limit_train_batches + batches = len(self.train_dataloader()) + batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches) + + num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) + if self.trainer.tpu_cores: + num_devices = max(num_devices, self.trainer.tpu_cores) + + effective_accum = self.trainer.accumulate_grad_batches * num_devices + return (batches // effective_accum) * self.trainer.max_epochs + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt")) + if ckpt_files: + # Adapt to PyTorch Lightning + last_ckpt_file = ckpt_files[-1] + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + model = cls.from_config(config_file) + # print("init from {}".format(last_ckpt_file)) + model.init_from_ckpt(last_ckpt_file) + return model + else: + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + + +class Encoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in: str = "Conv2d", + conv_out: str = "CasualConv3d", + attention: str = "AttnBlock", + resnet_blocks: Tuple[str] = ( + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock3D", + ), + spatial_downsample: Tuple[str] = ( + "Downsample", + "Downsample", + "Downsample", + "", + ), + temporal_downsample: Tuple[str] = ("", "", "TimeDownsampleRes2x", ""), + mid_resnet: str = "ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + double_z: bool = True, + ) -> None: + super().__init__() + assert len(resnet_blocks) == len(hidden_size_mult), print(hidden_size_mult, resnet_blocks) + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1) + + # ---- Downsample ---- + curr_res = resolution + in_ch_mult = (1,) + tuple(hidden_size_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = hidden_size * in_ch_mult[i_level] + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if spatial_downsample[i_level]: + down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in) + curr_res = curr_res // 2 + if temporal_downsample[i_level]: + down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in) + self.down.append(down) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if hasattr(self.down[i_level], "downsample"): + hs.append(self.down[i_level].downsample(hs[-1])) + if hasattr(self.down[i_level], "time_downsample"): + hs_down = self.down[i_level].time_downsample(hs[-1]) + hs.append(hs_down) + + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in: str = "Conv2d", + conv_out: str = "CasualConv3d", + attention: str = "AttnBlock", + resnet_blocks: Tuple[str] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + spatial_upsample: Tuple[str] = ( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + temporal_upsample: Tuple[str] = ("", "", "", "TimeUpsampleRes2x"), + mid_resnet: str = "ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + ): + super().__init__() + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.conv_in = resolve_str_to_obj(conv_in)(z_channels, block_in, kernel_size=3, padding=1) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + + # ---- Upsample ---- + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if spatial_upsample[i_level]: + up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in) + curr_res = curr_res * 2 + if temporal_upsample[i_level]: + up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in) + self.up.insert(0, up) + + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)(block_in, 3, kernel_size=3, padding=1) + + def forward(self, z): + h = self.conv_in(z) + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if hasattr(self.up[i_level], "upsample"): + h = self.up[i_level].upsample(h) + if hasattr(self.up[i_level], "time_upsample"): + h = self.up[i_level].time_upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class CausalVAEModel(VideoBaseAE_PL): + @register_to_config + def __init__( + self, + lr: float = 1e-5, + hidden_size: int = 128, + z_channels: int = 4, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = [], + dropout: float = 0.0, + resolution: int = 256, + double_z: bool = True, + embed_dim: int = 4, + num_res_blocks: int = 2, + loss_type: str = "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator", + loss_params: dict = { + "kl_weight": 0.000001, + "logvar_init": 0.0, + "disc_start": 2001, + "disc_weight": 0.5, + }, + q_conv: str = "CausalConv3d", + encoder_conv_in: str = "CausalConv3d", + encoder_conv_out: str = "CausalConv3d", + encoder_attention: str = "AttnBlock3D", + encoder_resnet_blocks: Tuple[str] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + encoder_spatial_downsample: Tuple[str] = ( + "SpatialDownsample2x", + "SpatialDownsample2x", + "SpatialDownsample2x", + "", + ), + encoder_temporal_downsample: Tuple[str] = ( + "", + "TimeDownsample2x", + "TimeDownsample2x", + "", + ), + encoder_mid_resnet: str = "ResnetBlock3D", + decoder_conv_in: str = "CausalConv3d", + decoder_conv_out: str = "CausalConv3d", + decoder_attention: str = "AttnBlock3D", + decoder_resnet_blocks: Tuple[str] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + decoder_spatial_upsample: Tuple[str] = ( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsample2x", "TimeUpsample2x"), + decoder_mid_resnet: str = "ResnetBlock3D", + ) -> None: + super().__init__() + self.tile_sample_min_size = 256 + self.tile_sample_min_size_t = 65 + self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1))) + t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0] + self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1 + self.tile_overlap_factor = 0.25 + self.use_tiling = False + + self.learning_rate = lr + self.lr_g_factor = 1.0 + + self.encoder = Encoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=encoder_conv_in, + conv_out=encoder_conv_out, + attention=encoder_attention, + resnet_blocks=encoder_resnet_blocks, + spatial_downsample=encoder_spatial_downsample, + temporal_downsample=encoder_temporal_downsample, + mid_resnet=encoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + double_z=double_z, + ) + + self.decoder = Decoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=decoder_conv_in, + conv_out=decoder_conv_out, + attention=decoder_attention, + resnet_blocks=decoder_resnet_blocks, + spatial_upsample=decoder_spatial_upsample, + temporal_upsample=decoder_temporal_upsample, + mid_resnet=decoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + ) + + quant_conv_cls = resolve_str_to_obj(q_conv) + self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1) + self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1) + + def encode(self, x): + if self.use_tiling and ( + x.shape[-1] > self.tile_sample_min_size + or x.shape[-2] > self.tile_sample_min_size + or x.shape[-3] > self.tile_sample_min_size_t + ): + return self.tiled_encode(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + if self.use_tiling and ( + z.shape[-1] > self.tile_latent_min_size + or z.shape[-2] > self.tile_latent_min_size + or z.shape[-3] > self.tile_latent_min_size_t + ): + return self.tiled_decode(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx): + if hasattr(self.loss, "discriminator"): + return self._training_step_gan(batch, batch_idx=batch_idx) + else: + return self._training_step(batch, batch_idx=batch_idx) + + def _training_step(self, batch, batch_idx): + inputs = self.get_input(batch, "video") + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def _training_step_gan(self, batch, batch_idx): + inputs = self.get_input(batch, "video") + reconstructions, posterior = self(inputs) + opt1, opt2 = self.optimizers() + + # ---- AE Loss ---- + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + opt1.zero_grad() + self.manual_backward(aeloss) + self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm") + opt1.step() + # ---- GAN Loss ---- + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + opt2.zero_grad() + self.manual_backward(discloss) + self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm") + opt2.step() + self.log_dict( + {**log_dict_ae, **log_dict_disc}, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False, + ) + + def configure_optimizers(self): + from itertools import chain + + lr = self.learning_rate + modules_to_train = [ + self.encoder.named_parameters(), + self.decoder.named_parameters(), + self.post_quant_conv.named_parameters(), + self.quant_conv.named_parameters(), + ] + params_with_time = [] + params_without_time = [] + for name, param in chain(*modules_to_train): + if "time" in name: + params_with_time.append(param) + else: + params_without_time.append(param) + optimizers = [] + opt_ae = torch.optim.Adam( + [ + {"params": params_with_time, "lr": lr}, + {"params": params_without_time, "lr": lr}, + ], + lr=lr, + betas=(0.5, 0.9), + ) + optimizers.append(opt_ae) + + if hasattr(self.loss, "discriminator"): + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) + optimizers.append(opt_disc) + + return optimizers, [] + + def get_last_layer(self): + if hasattr(self.decoder.conv_out, "conv"): + return self.decoder.conv_out.conv.weight + else: + return self.decoder.conv_out.weight + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + moments = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start:end] + if idx != 0: + moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:] + else: + moment = self.tiled_encode2d(chunk_x, return_moments=True) + moments.append(moment) + moments = torch.cat(moments, dim=2) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def tiled_decode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + dec_ = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start:end] + if idx != 0: + dec = self.tiled_decode2d(chunk_x)[:, :, 1:] + else: + dec = self.tiled_decode2d(chunk_x) + dec_.append(dec) + dec_ = torch.cat(dec_, dim=2) + return dec_ + + def tiled_encode2d(self, x, return_moments=False): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + posterior = DiagonalGaussianDistribution(moments) + if return_moments: + return moments + return posterior + + def tiled_decode2d(self, z): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def enable_tiling(self, use_tiling: bool = True): + self.use_tiling = use_tiling + + def disable_tiling(self): + self.enable_tiling(False) + + def init_from_ckpt(self, path, ignore_keys=list(), remove_loss=False): + sd = torch.load(path, map_location="cpu") + # print("init from " + path) + if "state_dict" in sd: + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, "video") + latents = self.encode(inputs).sample() + video_recon = self.decode(latents) + for idx in range(len(video_recon)): + self.logger.log_video(f"recon {batch_idx} {idx}", [tensor_to_video(video_recon[idx])], fps=[10]) + + +class CausalVAEModelWrapper(nn.Module): + def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs): + super(CausalVAEModelWrapper, self).__init__() + # if os.path.exists(ckpt): + # self.vae = CausalVAEModel.load_from_checkpoint(ckpt) + self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) + + def encode(self, x): # b c t h w + # x = self.vae.encode(x).sample() + x = self.vae.encode(x).sample().mul_(0.18215) + return x + + def decode(self, x): + # x = self.vae.decode(x) + x = self.vae.decode(x / 0.18215) + x = rearrange(x, "b c t h w -> b t c h w").contiguous() + return x + + def dtype(self): + return self.vae.dtype + + # + # def device(self): + # return self.vae.device + + def forward(self, x): + return self.decode(x) + + +videobase_ae_stride = { + "CausalVAEModel_4x8x8": [4, 8, 8], +} + +videobase_ae_channel = { + "CausalVAEModel_4x8x8": 4, +} + +videobase_ae = { + "CausalVAEModel_4x8x8": CausalVAEModelWrapper, +} + + +ae_stride_config = {} +ae_stride_config.update(videobase_ae_stride) + +ae_channel_config = {} +ae_channel_config.update(videobase_ae_channel) + + +def getae_wrapper(ae): + """deprecation""" + ae = videobase_ae.get(ae, None) + assert ae is not None + return ae + + +def video_to_image(func): + def wrapper(self, x, *args, **kwargs): + if x.dim() == 5: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = func(self, x, *args, **kwargs) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + return wrapper + + +class Block(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + +class LinearAttention(Block): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock3D(Block): + """Compatible with old versions, there are issues, use with caution.""" + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, t, h, w = q.shape + q = q.reshape(b * t, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b * t, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b * t, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, t, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class AttnBlock3DFix(nn.Module): + """ + Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172. + """ + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c) + b, c, t, h, w = q.shape + q = q.permute(0, 2, 1, 3, 4) + q = q.reshape(b * t, c, h * w) + q = q.permute(0, 2, 1) + + # k: (b c t h w) -> (b t c h w) -> (b*t c h*w) + k = k.permute(0, 2, 1, 3, 4) + k = k.reshape(b * t, c, h * w) + + # w: (b*t hw hw) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + # v: (b c t h w) -> (b t c h w) -> (bt c hw) + # w_: (bt hw hw) -> (bt hw hw) + v = v.permute(0, 2, 1, 3, 4) + v = v.reshape(b * t, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + + # h_: (b*t c hw) -> (b t c h w) -> (b c t h w) + h_ = h_.reshape(b, t, c, h, w) + h_ = h_.permute(0, 2, 1, 3, 4) + + h_ = self.proj_out(h_) + + return x + h_ + + +class AttnBlock(Block): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + @video_to_image + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class TemporalAttnBlock(Block): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, t, h, w = q.shape + q = rearrange(q, "b c t h w -> (b h w) t c") + k = rearrange(k, "b c t h w -> (b h w) c t") + v = rearrange(v, "b c t h w -> (b h w) c t") + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w) + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + print(attn_type) + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "vanilla3D": + return AttnBlock3D(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Conv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]] = 3, + stride: Union[int, Tuple[int]] = 1, + padding: Union[str, int, Tuple[int]] = 0, + dilation: Union[int, Tuple[int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + @video_to_image + def forward(self, x): + return super().forward(x) + + +class CausalConv3d(nn.Module): + def __init__( + self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = self.kernel_size[0] + self.chan_in = chan_in + self.chan_out = chan_out + stride = kwargs.pop("stride", 1) + padding = kwargs.pop("padding", 0) + padding = list(cast_tuple(padding, 3)) + padding[0] = 0 + stride = cast_tuple(stride, 3) + self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding) + self._init_weights(init_method) + + def _init_weights(self, init_method): + torch.tensor(self.kernel_size) + if init_method == "avg": + assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample" + assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out" + weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)) + + eyes = torch.concat( + [ + torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3, + torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3, + torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3, + ], + dim=-1, + ) + weight[:, :, :, 0, 0] = eyes + + self.conv.weight = nn.Parameter( + weight, + requires_grad=True, + ) + elif init_method == "zero": + self.conv.weight = nn.Parameter( + torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)), + requires_grad=True, + ) + if self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + # 1 + 16 16 as video, 1 as image + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w + x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16 + return self.conv(x) + + +class GroupNorm(Block): + def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True) + + def forward(self, x): + return self.norm(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) + std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:, :, None, None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height * width * torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + dims = list(range(n_dims)) + del dims[src_dim] + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim): + super().__init__() + self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) + self.register_buffer("N", torch.zeros(n_codes)) + self.register_buffer("z_avg", self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + y = self._tile(flat_inputs) + + y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + distances = ( + (flat_inputs**2).sum(dim=1, keepdim=True) + - 2 * flat_inputs @ self.embeddings.t() + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) + ) + + encoding_indices = torch.argmin(distances, dim=1) + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) + encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) + + embeddings = F.embedding(encoding_indices, self.embeddings) + embeddings = shift_dim(embeddings, -1, 1) + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) + encode_sum = flat_inputs.t() @ encode_onehot + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + usage = (self.N.view(self.n_codes, 1) >= 1).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return dict( + embeddings=embeddings_st, + encodings=encoding_indices, + commitment_loss=commitment_loss, + perplexity=perplexity, + ) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, self.embeddings) + return embeddings + + +class ResnetBlock2D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + @video_to_image + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + x = x + h + return x + + +class ResnetBlock3D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1) + else: + self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h + + +class Upsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.with_conv = True + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + @video_to_image + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.with_conv = True + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + + @video_to_image + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class SpatialDownsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (2, 2), + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 2) + stride = cast_tuple(stride, 2) + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1, 0, 0) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class SpatialUpsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (1, 1), + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1) + + def forward(self, x): + t = x.shape[2] + x = rearrange(x, "b c t h w -> b (c t) h w") + x = F.interpolate(x, scale_factor=(2, 2), mode="nearest") + x = rearrange(x, "b (c t) h w -> b c t h w", t=t) + x = self.conv(x) + return x + + +class TimeDownsample2x(Block): + def __init__(self, chan_in, chan_out, kernel_size: int = 3): + super().__init__() + self.kernel_size = kernel_size + self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1)) + + def forward(self, x): + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + return self.conv(x) + + +class TimeUpsample2x(Block): + def __init__(self, chan_in, chan_out): + super().__init__() + + def forward(self, x): + if x.size(2) > 1: + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear") + x = torch.concat([x, x_], dim=2) + return x + + +class TimeDownsampleRes2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 2.0, + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1)) + self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1)) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + alpha = torch.sigmoid(self.mix_factor) + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x) + + +class TimeUpsampleRes2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 2.0, + ): + super().__init__() + self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + alpha = torch.sigmoid(self.mix_factor) + if x.size(2) > 1: + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear") + x = torch.concat([x, x_], dim=2) + return alpha * x + (1 - alpha) * self.conv(x) + + +class TimeDownsampleResAdv2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 1.5, + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1)) + self.attn = TemporalAttnBlock(in_channels) + self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0) + self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1)) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + alpha = torch.sigmoid(self.mix_factor) + return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x)))) + + +class TimeUpsampleResAdv2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 1.5, + ): + super().__init__() + self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0) + self.attn = TemporalAttnBlock(in_channels) + self.norm = Normalize(in_channels=in_channels) + self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + if x.size(2) > 1: + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear") + x = torch.concat([x, x_], dim=2) + alpha = torch.sigmoid(self.mix_factor) + return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x))) diff --git a/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v120.py b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v120.py new file mode 100644 index 0000000..0bb4aa8 --- /dev/null +++ b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan_v120.py @@ -0,0 +1,1139 @@ +# Adapted from Open-Sora-Plan + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan +# -------------------------------------------------------- + +import glob +import os +from copy import deepcopy +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange +from huggingface_hub import hf_hub_download +from torch import nn +from torch.utils.checkpoint import checkpoint +from torchvision.transforms import Lambda + +npu_config = None + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) or isinstance(t, list) else ((t,) * length) + + +class Block(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + +class CausalConv3d(nn.Module): + def __init__( + self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = self.kernel_size[0] + self.chan_in = chan_in + self.chan_out = chan_out + stride = kwargs.pop("stride", 1) + padding = kwargs.pop("padding", 0) + padding = list(cast_tuple(padding, 3)) + padding[0] = 0 + stride = cast_tuple(stride, 3) + self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding) + self.pad = nn.ReplicationPad2d((0, 0, self.time_kernel_size - 1, 0)) + self._init_weights(init_method) + + def _init_weights(self, init_method): + torch.tensor(self.kernel_size) + if init_method == "avg": + assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample" + assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out" + weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)) + + eyes = torch.concat( + [ + torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3, + torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3, + torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3, + ], + dim=-1, + ) + weight[:, :, :, 0, 0] = eyes + + self.conv.weight = nn.Parameter( + weight, + requires_grad=True, + ) + elif init_method == "zero": + self.conv.weight = nn.Parameter( + torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)), + requires_grad=True, + ) + if self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + if npu_config is not None and npu_config.on_npu: + x_dtype = x.dtype + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w + x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16 + return npu_config.run_conv3d(self.conv, x, x_dtype) + else: + # 1 + 16 16 as video, 1 as image + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w + x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16 + return self.conv(x) + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +def video_to_image(func): + def wrapper(self, x, *args, **kwargs): + if x.dim() == 5: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = func(self, x, *args, **kwargs) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + return wrapper + + +class Conv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]] = 3, + stride: Union[int, Tuple[int]] = 1, + padding: Union[str, int, Tuple[int]] = 0, + dilation: Union[int, Tuple[int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + @video_to_image + def forward(self, x): + return super().forward(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class VideoBaseAE(ModelMixin, ConfigMixin): + config_name = "config.json" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def encode(self, x: torch.Tensor, *args, **kwargs): + pass + + def decode(self, encoding: torch.Tensor, *args, **kwargs): + pass + + @property + def num_training_steps(self) -> int: + """Total training steps inferred from datamodule and devices.""" + if self.trainer.max_steps: + return self.trainer.max_steps + + limit_batches = self.trainer.limit_train_batches + batches = len(self.train_dataloader()) + batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches) + + num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) + if self.trainer.tpu_cores: + num_devices = max(num_devices, self.trainer.tpu_cores) + + effective_accum = self.trainer.accumulate_grad_batches * num_devices + return (batches // effective_accum) * self.trainer.max_epochs + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt")) + if not ckpt_files: + ckpt_file = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="checkpoint.ckpt") + config_file = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="config.json") + else: + ckpt_file = ckpt_files[-1] + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + + # Adapt to checkpoint + model = cls.from_config(config_file) + model.init_from_ckpt(ckpt_file) + return model + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +class ResnetBlock2D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + @video_to_image + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + x = x + h + return x + + +class ResnetBlock3D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1) + else: + self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0) + + def forward(self, x): + h = x + if npu_config is None: + h = self.norm1(h) + else: + h = npu_config.run_group_norm(self.norm1, h) + h = nonlinearity(h) + h = self.conv1(h) + if npu_config is None: + h = self.norm2(h) + else: + h = npu_config.run_group_norm(self.norm2, h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h + + +class SpatialUpsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (1, 1), + unup=False, + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.unup = unup + self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1) + + def forward(self, x): + if not self.unup: + t = x.shape[2] + x = rearrange(x, "b c t h w -> b (c t) h w") + x = F.interpolate(x, scale_factor=(2, 2), mode="nearest") + x = rearrange(x, "b (c t) h w -> b c t h w", t=t) + x = self.conv(x) + return x + + +class Spatial2xTime2x3DUpsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, x): + if x.size(2) > 1: + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = F.interpolate(x_, scale_factor=(2, 2, 2), mode="trilinear") + x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear") + x = torch.concat([x, x_], dim=2) + else: + x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear") + return self.conv(x) + + +class AttnBlock3DFix(nn.Module): + """ + Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172. + """ + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x): + h_ = x + if npu_config is None: + h_ = self.norm(h_) + else: + h_ = npu_config.run_group_norm(self.norm, h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c) + b, c, t, h, w = q.shape + q = q.permute(0, 2, 1, 3, 4) + q = q.reshape(b * t, c, h * w) + q = q.permute(0, 2, 1) + + # k: (b c t h w) -> (b t c h w) -> (b*t c h*w) + k = k.permute(0, 2, 1, 3, 4) + k = k.reshape(b * t, c, h * w) + + # w: (b*t hw hw) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + # v: (b c t h w) -> (b t c h w) -> (bt c hw) + # w_: (bt hw hw) -> (bt hw hw) + v = v.permute(0, 2, 1, 3, 4) + v = v.reshape(b * t, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + + # h_: (b*t c hw) -> (b t c h w) -> (b c t h w) + h_ = h_.reshape(b, t, c, h, w) + h_ = h_.permute(0, 2, 1, 3, 4) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Spatial2xTime2x3DDownsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=2) + + def forward(self, x): + pad = (0, 1, 0, 1, 0, 0) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Downsample(Block): + def __init__(self, in_channels, out_channels, undown=False): + super().__init__() + self.with_conv = True + self.undown = undown + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + if self.undown: + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + + @video_to_image + def forward(self, x): + if self.with_conv: + if self.undown: + if npu_config is not None and npu_config.on_npu: + x_dtype = x.dtype + x = x.to(npu_config.replaced_type) + x = npu_config.run_conv3d(self.conv, x, x_dtype) + else: + x = self.conv(x) + else: + pad = (0, 1, 0, 1) + if npu_config is not None and npu_config.on_npu: + x_dtype = x.dtype + x = x.to(npu_config.replaced_type) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = npu_config.run_conv3d(self.conv, x, x_dtype) + else: + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock3D_GC(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1) + else: + self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0) + + def forward(self, x): + return checkpoint(self._forward, x, use_reentrant=True) + + def _forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h + + +def resolve_str_to_obj(str_val, append=True): + return globals()[str_val] + + +class Encoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in="Conv2d", + conv_out="CasualConv3d", + attention="AttnBlock", + resnet_blocks=( + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock3D", + ), + spatial_downsample=( + "Downsample", + "Downsample", + "Downsample", + "", + ), + temporal_downsample=("", "", "TimeDownsampleRes2x", ""), + mid_resnet="ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + double_z: bool = True, + ) -> None: + super().__init__() + assert len(resnet_blocks) == len(hidden_size_mult), print(hidden_size_mult, resnet_blocks) + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1) + + # ---- Downsample ---- + curr_res = resolution + in_ch_mult = (1,) + tuple(hidden_size_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = hidden_size * in_ch_mult[i_level] + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if spatial_downsample[i_level]: + down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in) + curr_res = curr_res // 2 + if temporal_downsample[i_level]: + down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in) + self.down.append(down) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if hasattr(self.down[i_level], "downsample"): + hs.append(self.down[i_level].downsample(hs[-1])) + if hasattr(self.down[i_level], "time_downsample"): + hs_down = self.down[i_level].time_downsample(hs[-1]) + hs.append(hs_down) + + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + if npu_config is None: + h = self.norm_out(h) + else: + h = npu_config.run_group_norm(self.norm_out, h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in="Conv2d", + conv_out="CasualConv3d", + attention="AttnBlock", + resnet_blocks=( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + spatial_upsample=( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + temporal_upsample=("", "", "", "TimeUpsampleRes2x"), + mid_resnet="ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + ): + super().__init__() + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.conv_in = resolve_str_to_obj(conv_in)(z_channels, block_in, kernel_size=3, padding=1) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + + # ---- Upsample ---- + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if spatial_upsample[i_level]: + up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in) + curr_res = curr_res * 2 + if temporal_upsample[i_level]: + up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in) + self.up.insert(0, up) + + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)(block_in, 3, kernel_size=3, padding=1) + + def forward(self, z): + h = self.conv_in(z) + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if hasattr(self.up[i_level], "upsample"): + h = self.up[i_level].upsample(h) + if hasattr(self.up[i_level], "time_upsample"): + h = self.up[i_level].time_upsample(h) + if npu_config is None: + h = self.norm_out(h) + else: + h = npu_config.run_group_norm(self.norm_out, h) + h = nonlinearity(h) + if npu_config is None: + h = self.conv_out(h) + else: + h_dtype = h.dtype + h = npu_config.run_conv3d(self.conv_out, h, h_dtype) + return h + + +class CausalVAEModel(VideoBaseAE): + @register_to_config + def __init__( + self, + hidden_size: int = 128, + z_channels: int = 4, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = [], + dropout: float = 0.0, + resolution: int = 256, + double_z: bool = True, + embed_dim: int = 4, + num_res_blocks: int = 2, + q_conv: str = "CausalConv3d", + encoder_conv_in="CausalConv3d", + encoder_conv_out="CausalConv3d", + encoder_attention="AttnBlock3D", + encoder_resnet_blocks=( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + encoder_spatial_downsample=( + "SpatialDownsample2x", + "SpatialDownsample2x", + "SpatialDownsample2x", + "", + ), + encoder_temporal_downsample=( + "", + "TimeDownsample2x", + "TimeDownsample2x", + "", + ), + encoder_mid_resnet="ResnetBlock3D", + decoder_conv_in="CausalConv3d", + decoder_conv_out="CausalConv3d", + decoder_attention="AttnBlock3D", + decoder_resnet_blocks=( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + decoder_spatial_upsample=( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + decoder_temporal_upsample=("", "", "TimeUpsample2x", "TimeUpsample2x"), + decoder_mid_resnet="ResnetBlock3D", + use_quant_layer: bool = True, + ) -> None: + super().__init__() + + self.tile_sample_min_size = 256 + self.tile_sample_min_size_t = 33 + self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1))) + + # t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0] + # self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** len(t_down_ratio))) + 1 + self.tile_latent_min_size_t = 16 + self.tile_overlap_factor = 0.125 + self.use_tiling = False + + self.use_quant_layer = use_quant_layer + + self.encoder = Encoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=encoder_conv_in, + conv_out=encoder_conv_out, + attention=encoder_attention, + resnet_blocks=encoder_resnet_blocks, + spatial_downsample=encoder_spatial_downsample, + temporal_downsample=encoder_temporal_downsample, + mid_resnet=encoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + double_z=double_z, + ) + + self.decoder = Decoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=decoder_conv_in, + conv_out=decoder_conv_out, + attention=decoder_attention, + resnet_blocks=decoder_resnet_blocks, + spatial_upsample=decoder_spatial_upsample, + temporal_upsample=decoder_temporal_upsample, + mid_resnet=decoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + ) + if self.use_quant_layer: + quant_conv_cls = resolve_str_to_obj(q_conv) + self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1) + self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1) + + def get_encoder(self): + if self.use_quant_layer: + return [self.quant_conv, self.encoder] + return [self.encoder] + + def get_decoder(self): + if self.use_quant_layer: + return [self.post_quant_conv, self.decoder] + return [self.decoder] + + def encode(self, x): + if self.use_tiling and ( + x.shape[-1] > self.tile_sample_min_size + or x.shape[-2] > self.tile_sample_min_size + or x.shape[-3] > self.tile_sample_min_size_t + ): + return self.tiled_encode(x) + h = self.encoder(x) + if self.use_quant_layer: + h = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(h) + return posterior + + def decode(self, z): + if self.use_tiling and ( + z.shape[-1] > self.tile_latent_min_size + or z.shape[-2] > self.tile_latent_min_size + or z.shape[-3] > self.tile_latent_min_size_t + ): + return self.tiled_decode(z) + if self.use_quant_layer: + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def on_train_start(self): + self.ema = deepcopy(self) if self.save_ema == True else None + + def get_last_layer(self): + if hasattr(self.decoder.conv_out, "conv"): + return self.decoder.conv_out.conv.weight + else: + return self.decoder.conv_out.weight + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + moments = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start:end] + if idx != 0: + moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:] + else: + moment = self.tiled_encode2d(chunk_x, return_moments=True) + moments.append(moment) + moments = torch.cat(moments, dim=2) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def tiled_decode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + dec_ = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start:end] + if idx != 0: + dec = self.tiled_decode2d(chunk_x)[:, :, 1:] + else: + dec = self.tiled_decode2d(chunk_x) + dec_.append(dec) + dec_ = torch.cat(dec_, dim=2) + return dec_ + + def tiled_encode2d(self, x, return_moments=False): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + if self.use_quant_layer: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + posterior = DiagonalGaussianDistribution(moments) + if return_moments: + return moments + return posterior + + def tiled_decode2d(self, z): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + if self.use_quant_layer: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def enable_tiling(self, use_tiling: bool = True): + self.use_tiling = use_tiling + + def disable_tiling(self): + self.enable_tiling(False) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu") + # print("init from " + path) + + if "ema_state_dict" in sd and len(sd["ema_state_dict"]) > 0 and os.environ.get("NOT_USE_EMA_MODEL", 0) == 0: + # print("Load from ema model!") + sd = sd["ema_state_dict"] + sd = {key.replace("module.", ""): value for key, value in sd.items()} + elif "state_dict" in sd: + # print("Load from normal model!") + if "gen_model" in sd["state_dict"]: + sd = sd["state_dict"]["gen_model"] + else: + sd = sd["state_dict"] + + keys = list(sd.keys()) + + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + miss, unexpected = self.load_state_dict(sd, strict=False) + assert len(miss) == 0, f"miss key: {miss}" + if len(unexpected) > 0: + for i in unexpected: + assert "loss" in i, "unexpected key: {i}" + + +ae_stride_config = { + "CausalVAEModel_D4_2x8x8": [2, 8, 8], + "CausalVAEModel_D8_2x8x8": [2, 8, 8], + "CausalVAEModel_D4_4x8x8": [4, 8, 8], + "CausalVAEModel_D8_4x8x8": [4, 8, 8], +} + + +ae_channel_config = { + "CausalVAEModel_D4_2x8x8": 4, + "CausalVAEModel_D8_2x8x8": 8, + "CausalVAEModel_D4_4x8x8": 4, + "CausalVAEModel_D8_4x8x8": 8, +} + + +ae_denorm = { + "CausalVAEModel_D4_2x8x8": lambda x: (x + 1.0) / 2.0, + "CausalVAEModel_D8_2x8x8": lambda x: (x + 1.0) / 2.0, + "CausalVAEModel_D4_4x8x8": lambda x: (x + 1.0) / 2.0, + "CausalVAEModel_D8_4x8x8": lambda x: (x + 1.0) / 2.0, +} + +ae_norm = { + "CausalVAEModel_D4_2x8x8": Lambda(lambda x: 2.0 * x - 1.0), + "CausalVAEModel_D8_2x8x8": Lambda(lambda x: 2.0 * x - 1.0), + "CausalVAEModel_D4_4x8x8": Lambda(lambda x: 2.0 * x - 1.0), + "CausalVAEModel_D8_4x8x8": Lambda(lambda x: 2.0 * x - 1.0), +} + + +class CausalVAEModelWrapper(nn.Module): + def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs): + super(CausalVAEModelWrapper, self).__init__() + # if os.path.exists(ckpt): + # self.vae = CausalVAEModel.load_from_checkpoint(ckpt) + # hf_hub_download(model_path, subfolder="vae", filename="checkpoint.ckpt") + # cached_download(hf_hub_url(model_path, subfolder="vae", filename="checkpoint.ckpt")) + self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) + if use_ema: + self.vae.init_from_ema(model_path) + self.vae = self.vae.ema + + def encode(self, x): # b c t h w + # x = self.vae.encode(x).sample() + x = self.vae.encode(x).sample().mul_(0.18215) + return x + + def decode(self, x): + # x = self.vae.decode(x) + x = self.vae.decode(x / 0.18215) + x = rearrange(x, "b c t h w -> b t c h w").contiguous() + return x + + def forward(self, x): + return self.decode(x) + + def dtype(self): + return self.vae.dtype diff --git a/videosys/models/modules/attentions.py b/videosys/models/modules/attentions.py index 8e2c20c..4bbba72 100644 --- a/videosys/models/modules/attentions.py +++ b/videosys/models/modules/attentions.py @@ -1,12 +1,21 @@ +import inspect from dataclasses import dataclass -from typing import Iterable, List, Tuple +from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint +from diffusers.models.attention import Attention +from diffusers.models.attention_processor import AttnProcessor +from einops import rearrange +from torch import nn +from torch.amp import autocast -from videosys.models.modules.normalization import LlamaRMSNorm +from videosys.core.comm import all_to_all_with_pad, get_pad, set_pad +from videosys.core.pab_mgr import enable_pab, if_broadcast_cross, if_broadcast_spatial, if_broadcast_temporal +from videosys.models.modules.normalization import LlamaRMSNorm, VchitectSpatialNorm +from videosys.utils.logging import logger class OpenSoraAttention(nn.Module): @@ -203,3 +212,634 @@ class _SeqLenInfo: seqstart=seqstart, seqstart_py=seqstart_py, ) + + +class VchitectAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional[AttnProcessor] = None, + out_dim: int = None, + context_pre_only: bool = None, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = VchitectSpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.to_q_cross = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_q_temp = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k_temp = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v_temp = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k_temp = None + self.to_v_temp = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + self.to_out_temporal = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + nn.init.constant_(self.to_out_temporal.weight, 0) + nn.init.constant_(self.to_out_temporal.bias, 0) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + self.to_add_out_temporal = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + nn.init.constant_(self.to_add_out_temporal.weight, 0) + nn.init.constant_(self.to_add_out_temporal.bias, 0) + + self.to_out_context = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + nn.init.constant_(self.to_out_context.weight, 0) + nn.init.constant_(self.to_out_context.bias, 0) + + # set attention processor + self.set_processor(processor) + + # parallel + self.parallel_manager = None + + # pab + self.spatial_count = 0 + self.last_spatial = None + self.temporal_count = 0 + self.last_temporal = None + self.cross_count = 0 + self.last_cross = None + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + full_seqlen: Optional[int] = None, + Frame: Optional[int] = None, + timestep: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + freqs_cis=freqs_cis, + full_seqlen=full_seqlen, + Frame=Frame, + timestep=timestep, + **cross_attention_kwargs, + ) + + @torch.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + + +class VchitectAttnProcessor: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + @autocast("cuda", enabled=False) + def apply_rotary_emb( + self, + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + ): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + def spatial_attn( + self, + attn, + hidden_states, + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + batch_size, + head_dim, + ): + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + xq, xk = query.to(value.dtype), key.to(value.dtype) + + query_spatial, key_spatial, value_spatial = xq, xk, value + query_spatial = query_spatial.transpose(1, 2) + key_spatial = key_spatial.transpose(1, 2) + value_spatial = value_spatial.transpose(1, 2) + + hidden_states = hidden_states = F.scaled_dot_product_attention( + query_spatial, key_spatial, value_spatial, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + return hidden_states + + def temporal_attention( + self, + attn, + hidden_states, + residual, + batch_size, + batchsize, + Frame, + head_dim, + freqs_cis, + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ): + query_t = attn.to_q_temp(hidden_states) + key_t = attn.to_k_temp(hidden_states) + value_t = attn.to_v_temp(hidden_states) + + query_t = torch.cat([query_t, encoder_hidden_states_query_proj], dim=1) + key_t = torch.cat([key_t, encoder_hidden_states_key_proj], dim=1) + value_t = torch.cat([value_t, encoder_hidden_states_value_proj], dim=1) + + query_t = query_t.view(batch_size, -1, attn.heads, head_dim) + key_t = key_t.view(batch_size, -1, attn.heads, head_dim) + value_t = value_t.view(batch_size, -1, attn.heads, head_dim) + + query_t, key_t = query_t.to(value_t.dtype), key_t.to(value_t.dtype) + + if attn.parallel_manager.sp_size > 1: + func = lambda x: self.dynamic_switch(attn, x, batchsize, to_spatial_shard=True) + query_t, key_t, value_t = map(func, [query_t, key_t, value_t]) + + func = lambda x: rearrange(x, "(B T) S H C -> (B S) T H C", T=Frame, B=batchsize) + xq_gather_temporal, xk_gather_temporal, xv_gather_temporal = map(func, [query_t, key_t, value_t]) + + freqs_cis_temporal = freqs_cis[: xq_gather_temporal.shape[1], :] + xq_gather_temporal, xk_gather_temporal = self.apply_rotary_emb( + xq_gather_temporal, xk_gather_temporal, freqs_cis=freqs_cis_temporal + ) + + xq_gather_temporal = xq_gather_temporal.transpose(1, 2) + xk_gather_temporal = xk_gather_temporal.transpose(1, 2) + xv_gather_temporal = xv_gather_temporal.transpose(1, 2) + + batch_size_temp = xv_gather_temporal.shape[0] + hidden_states_temp = F.scaled_dot_product_attention( + xq_gather_temporal, xk_gather_temporal, xv_gather_temporal, dropout_p=0.0, is_causal=False + ) + hidden_states_temp = hidden_states_temp.transpose(1, 2).reshape(batch_size_temp, -1, attn.heads * head_dim) + hidden_states_temp = hidden_states_temp.to(value_t.dtype) + hidden_states_temp = rearrange(hidden_states_temp, "(B S) T C -> (B T) S C", T=Frame, B=batchsize) + if attn.parallel_manager.sp_size > 1: + hidden_states_temp = self.dynamic_switch(attn, hidden_states_temp, batchsize, to_spatial_shard=False) + + hidden_states_temporal, encoder_hidden_states_temporal = ( + hidden_states_temp[:, : residual.shape[1]], + hidden_states_temp[:, residual.shape[1] :], + ) + hidden_states_temporal = attn.to_out_temporal(hidden_states_temporal) + return hidden_states_temporal, encoder_hidden_states_temporal + + def cross_attention( + self, + attn, + hidden_states, + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + batch_size, + head_dim, + cur_frame, + batchsize, + ): + query_cross = attn.to_q_cross(hidden_states) + query_cross = torch.cat([query_cross, encoder_hidden_states_query_proj], dim=1) + + key_y = encoder_hidden_states_key_proj[0].unsqueeze(0) + value_y = encoder_hidden_states_value_proj[0].unsqueeze(0) + + query_y = query_cross.view(batch_size, -1, attn.heads, head_dim) + key_y = key_y.view(batchsize, -1, attn.heads, head_dim) + value_y = value_y.view(batchsize, -1, attn.heads, head_dim) + + query_y = rearrange(query_y, "(B T) S H C -> B (S T) H C", T=cur_frame, B=batchsize) + + query_y = query_y.transpose(1, 2) + key_y = key_y.transpose(1, 2) + value_y = value_y.transpose(1, 2) + + cross_output = F.scaled_dot_product_attention(query_y, key_y, value_y, dropout_p=0.0, is_causal=False) + cross_output = cross_output.transpose(1, 2).reshape(batchsize, -1, attn.heads * head_dim) + cross_output = cross_output.to(query_cross.dtype) + + cross_output = rearrange(cross_output, "B (S T) C -> (B T) S C", T=cur_frame, B=batchsize) + cross_output = attn.to_out_context(cross_output) + return cross_output + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + full_seqlen: Optional[int] = None, + Frame: Optional[int] = None, + timestep: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + batch_size = encoder_hidden_states.shape[0] + inner_dim = encoder_hidden_states_key_proj.shape[-1] + head_dim = inner_dim // attn.heads + batchsize = full_seqlen // Frame + # same as Frame if shard, otherwise sharded frame + cur_frame = batch_size // batchsize + + # temporal attention + if enable_pab(): + broadcast_temporal, attn.temporal_count = if_broadcast_temporal(int(timestep[0]), attn.temporal_count) + if enable_pab() and broadcast_temporal: + hidden_states_temporal, encoder_hidden_states_temporal = attn.last_temporal + else: + hidden_states_temporal, encoder_hidden_states_temporal = self.temporal_attention( + attn, + hidden_states, + residual, + batch_size, + batchsize, + Frame, + head_dim, + freqs_cis, + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) + if enable_pab(): + attn.last_temporal = (hidden_states_temporal, encoder_hidden_states_temporal) + + # cross attn + if enable_pab(): + broadcast_cross, attn.cross_count = if_broadcast_cross(int(timestep[0]), attn.cross_count) + if enable_pab() and broadcast_cross: + cross_output = attn.last_cross + else: + cross_output = self.cross_attention( + attn, + hidden_states, + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + batch_size, + head_dim, + cur_frame, + batchsize, + ) + if enable_pab(): + attn.last_cross = cross_output + + # spatial attn + if enable_pab(): + broadcast_spatial, attn.spatial_count = if_broadcast_spatial(int(timestep[0]), attn.spatial_count) + if enable_pab() and broadcast_spatial: + hidden_states = attn.last_spatial + else: + hidden_states = self.spatial_attn( + attn, + hidden_states, + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + batch_size, + head_dim, + ) + if enable_pab(): + attn.last_spatial = hidden_states + + # processs attention outputs. + hidden_states = hidden_states * 1.1 + cross_output + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if cur_frame == 1: + hidden_states_temporal = hidden_states_temporal * 0 + hidden_states = hidden_states + hidden_states_temporal + + # encoder + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + encoder_hidden_states_temporal = attn.to_add_out_temporal(encoder_hidden_states_temporal) + if cur_frame == 1: + encoder_hidden_states_temporal = encoder_hidden_states_temporal * 0 + encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_temporal + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + + def dynamic_switch(self, attn, x, batchsize, to_spatial_shard: bool): + if to_spatial_shard: + scatter_dim, gather_dim = 2, 1 + set_pad("spatial", x.shape[1], attn.parallel_manager.sp_group) + scatter_pad = get_pad("spatial") + gather_pad = get_pad("temporal") + else: + scatter_dim, gather_dim = 1, 2 + scatter_pad = get_pad("temporal") + gather_pad = get_pad("spatial") + + x = rearrange(x, "(B T) S ... -> B T S ...", B=batchsize) + x = all_to_all_with_pad( + x, + attn.parallel_manager.sp_group, + scatter_dim=scatter_dim, + gather_dim=gather_dim, + scatter_pad=scatter_pad, + gather_pad=gather_pad, + ) + x = rearrange(x, "B T ... -> (B T) ...") + return x diff --git a/videosys/models/modules/normalization.py b/videosys/models/modules/normalization.py index 7985e56..4df9ae3 100644 --- a/videosys/models/modules/normalization.py +++ b/videosys/models/modules/normalization.py @@ -2,6 +2,7 @@ from typing import Optional, Tuple import torch import torch.nn as nn +import torch.nn.functional as F class LlamaRMSNorm(nn.Module): @@ -100,3 +101,32 @@ class AdaLayerNorm(nn.Module): x = self.norm(x) * (1 + scale) + shift return x + + +class VchitectSpatialNorm(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f diff --git a/videosys/models/transformers/cogvideox_transformer_3d.py b/videosys/models/transformers/cogvideox_transformer_3d.py index 4b86d70..e568e06 100644 --- a/videosys/models/transformers/cogvideox_transformer_3d.py +++ b/videosys/models/transformers/cogvideox_transformer_3d.py @@ -22,15 +22,9 @@ from diffusers.utils import is_torch_version from diffusers.utils.torch_utils import maybe_allow_in_graph from torch import nn -from videosys.core.comm import all_to_all_comm, gather_sequence, get_spatial_pad, set_spatial_pad, split_sequence +from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial -from videosys.core.parallel_mgr import ( - enable_sequence_parallel, - get_cfg_parallel_group, - get_cfg_parallel_size, - get_sequence_parallel_group, - get_sequence_parallel_size, -) +from videosys.core.parallel_mgr import ParallelManager from videosys.models.modules.embeddings import apply_rotary_emb from videosys.utils.utils import batch_func @@ -48,6 +42,49 @@ class CogVideoXAttnProcessor2_0: if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + def _remove_extra_encoder(self, hidden_states, text_seq_length, attn): + # current layout is [text, 1/n seq, text, 1/n seq, ...] + # we want to remove the all the text info [text, seq] + sp_size = attn.parallel_manager.sp_size + split_seq = hidden_states.split(hidden_states.size(2) // sp_size, dim=2) + encoder_hidden_states = split_seq[0][:, :, :text_seq_length] + new_seq = [encoder_hidden_states] + for i in range(sp_size): + new_seq.append(split_seq[i][:, :, text_seq_length:]) + hidden_states = torch.cat(new_seq, dim=2) + + # remove padding added when all2all + # if pad is removed earlier than this + # the split size will be wrong + pad = get_pad("pad") + if pad > 0: + hidden_states = hidden_states.narrow(2, 0, hidden_states.size(2) - pad) + return hidden_states + + def _add_extra_encoder(self, hidden_states, text_seq_length, attn): + # add padding for split and later all2all + # if pad is removed later than this + # the split size will be wrong + pad = get_pad("pad") + if pad > 0: + pad_shape = list(hidden_states.shape) + pad_shape[1] = pad + pad_tensor = torch.zeros(pad_shape, device=hidden_states.device, dtype=hidden_states.dtype) + hidden_states = torch.cat([hidden_states, pad_tensor], dim=1) + + # current layout is [text, seq] + # we want to add the extra encoder info [text, 1/n seq, text, 1/n seq, ...] + sp_size = attn.parallel_manager.sp_size + encoder = hidden_states[:, :text_seq_length] + seq = hidden_states[:, text_seq_length:] + seq = seq.split(seq.size(1) // sp_size, dim=1) + new_seq = [] + for i in range(sp_size): + new_seq.append(encoder) + new_seq.append(seq[i]) + hidden_states = torch.cat(new_seq, dim=1) + return hidden_states + def __call__( self, attn: Attention, @@ -72,13 +109,15 @@ class CogVideoXAttnProcessor2_0: key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - if enable_sequence_parallel(): + if attn.parallel_manager.sp_size > 1: assert ( - attn.heads % get_sequence_parallel_size() == 0 - ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {get_sequence_parallel_size()}" - attn_heads = attn.heads // get_sequence_parallel_size() + attn.heads % attn.parallel_manager.sp_size == 0 + ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}" + attn_heads = attn.heads // attn.parallel_manager.sp_size + # normally we operate pad for every all2all. but for more convient implementation + # we move pad operation to encoder add and remove in cogvideo query, key, value = map( - lambda x: all_to_all_comm(x, get_sequence_parallel_group(), scatter_dim=2, gather_dim=1), + lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1), [query, key, value], ) else: @@ -96,6 +135,13 @@ class CogVideoXAttnProcessor2_0: if attn.norm_k is not None: key = attn.norm_k(key) + if attn.parallel_manager.sp_size > 1: + # remove extra encoder for attention + query, key, value = map( + lambda x: self._remove_extra_encoder(x, text_seq_length, attn), + [query, key, value], + ) + # Apply RoPE if needed if image_rotary_emb is not None: emb_len = image_rotary_emb[0].shape[0] @@ -113,77 +159,10 @@ class CogVideoXAttnProcessor2_0: hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) - if enable_sequence_parallel(): - hidden_states = all_to_all_comm(hidden_states, get_sequence_parallel_group(), scatter_dim=1, gather_dim=2) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) - return hidden_states, encoder_hidden_states - - -class FusedCogVideoXAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) - if not attn.is_cross_attention: - key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + if attn.parallel_manager.sp_size > 1: + # add extra encoder for all_to_all + hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn) + hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2) # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -266,6 +245,9 @@ class CogVideoXBlock(nn.Module): processor=CogVideoXAttnProcessor2_0(), ) + # parallel + self.attn1.parallel_manager = None + # 2. Feed Forward self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) @@ -300,7 +282,7 @@ class CogVideoXBlock(nn.Module): # attention if enable_pab(): - broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx) + broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count) if enable_pab() and broadcast_attn: attn_hidden_states, attn_encoder_hidden_states = self.last_attn else: @@ -474,6 +456,23 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): self.gradient_checkpointing = False + # parallel + self.parallel_manager = None + + def enable_parallel(self, dp_size, sp_size, enable_cp): + # update cfg parallel + if enable_cp and sp_size % 2 == 0: + sp_size = sp_size // 2 + cp_size = 2 + else: + cp_size = 1 + + self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size) + + for _, module in self.named_modules(): + if hasattr(module, "parallel_manager"): + module.parallel_manager = self.parallel_manager + def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value @@ -485,9 +484,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): timestep_cond: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, - all_timesteps=None, ): - if get_cfg_parallel_size() > 1: + if self.parallel_manager.cp_size > 1: ( hidden_states, encoder_hidden_states, @@ -495,7 +493,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): timestep_cond, image_rotary_emb, ) = batch_func( - partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0), + partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), hidden_states, encoder_hidden_states, timestep, @@ -530,9 +528,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] - if enable_sequence_parallel(): - set_spatial_pad(hidden_states.shape[1]) - hidden_states = split_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad()) + if self.parallel_manager.sp_size > 1: + set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group) + hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) # 4. Transformer blocks for i, block in enumerate(self.transformer_blocks): @@ -562,8 +560,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): timestep=timesteps if enable_pab() else None, ) - if enable_sequence_parallel(): - hidden_states = gather_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad()) + if self.parallel_manager.sp_size > 1: + hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) if not self.config.use_rotary_positional_embeddings: # CogVideoX-2B @@ -583,8 +581,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - if get_cfg_parallel_size() > 1: - output = gather_sequence(output, get_cfg_parallel_group(), dim=0) + if self.parallel_manager.cp_size > 1: + output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) if not return_dict: return (output,) diff --git a/videosys/models/transformers/latte_transformer_3d.py b/videosys/models/transformers/latte_transformer_3d.py index 3a00759..ef28720 100644 --- a/videosys/models/transformers/latte_transformer_3d.py +++ b/videosys/models/transformers/latte_transformer_3d.py @@ -33,15 +33,7 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph from einops import rearrange, repeat from torch import nn -from videosys.core.comm import ( - all_to_all_with_pad, - gather_sequence, - get_spatial_pad, - get_temporal_pad, - set_spatial_pad, - set_temporal_pad, - split_sequence, -) +from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence from videosys.core.pab_mgr import ( enable_pab, get_mlp_output, @@ -51,12 +43,7 @@ from videosys.core.pab_mgr import ( if_broadcast_temporal, save_mlp_output, ) -from videosys.core.parallel_mgr import ( - enable_sequence_parallel, - get_cfg_parallel_group, - get_cfg_parallel_size, - get_sequence_parallel_group, -) +from videosys.core.parallel_mgr import ParallelManager from videosys.utils.utils import batch_func @@ -388,9 +375,7 @@ class BasicTransformerBlock(nn.Module): gligen_kwargs = cross_attention_kwargs.pop("gligen", None) if enable_pab(): - broadcast_spatial, self.spatial_count = if_broadcast_spatial( - int(org_timestep[0]), self.spatial_count, self.block_idx - ) + broadcast_spatial, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count) if enable_pab() and broadcast_spatial: attn_output = self.spatial_last @@ -681,6 +666,9 @@ class BasicTransformerBlock_(nn.Module): self.block_idx = block_idx self.count = 0 + # parallel + self.parallel_manager: ParallelManager = None + def set_last_out(self, last_out: torch.Tensor): self.last_out = last_out @@ -743,7 +731,7 @@ class BasicTransformerBlock_(nn.Module): if self.pos_embed is not None: norm_hidden_states = self.pos_embed(norm_hidden_states) - if enable_sequence_parallel(): + if self.parallel_manager.sp_size > 1: norm_hidden_states = self.dynamic_switch(norm_hidden_states, to_spatial_shard=True) attn_output = self.attn1( @@ -753,7 +741,7 @@ class BasicTransformerBlock_(nn.Module): **cross_attention_kwargs, ) - if enable_sequence_parallel(): + if self.parallel_manager.sp_size > 1: attn_output = self.dynamic_switch(attn_output, to_spatial_shard=False) if self.use_ada_layer_norm_zero: @@ -838,15 +826,15 @@ class BasicTransformerBlock_(nn.Module): def dynamic_switch(self, x, to_spatial_shard: bool): if to_spatial_shard: scatter_dim, gather_dim = 0, 1 - scatter_pad = get_spatial_pad() - gather_pad = get_temporal_pad() + scatter_pad = get_pad("spatial") + gather_pad = get_pad("temporal") else: scatter_dim, gather_dim = 1, 0 - scatter_pad = get_temporal_pad() - gather_pad = get_spatial_pad() + scatter_pad = get_pad("temporal") + gather_pad = get_pad("spatial") x = all_to_all_with_pad( x, - get_sequence_parallel_group(), + self.parallel_manager.sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim, scatter_pad=scatter_pad, @@ -1133,6 +1121,23 @@ class LatteT2V(ModelMixin, ConfigMixin): temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + # parallel + self.parallel_manager = None + + def enable_parallel(self, dp_size, sp_size, enable_cp): + # update cfg parallel + if enable_cp and sp_size % 2 == 0: + sp_size = sp_size // 2 + cp_size = 2 + else: + cp_size = 1 + + self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size) + + for _, module in self.named_modules(): + if hasattr(module, "parallel_manager"): + module.parallel_manager = self.parallel_manager + def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value @@ -1191,7 +1196,7 @@ class LatteT2V(ModelMixin, ConfigMixin): """ # 0. Split batch for data parallelism - if get_cfg_parallel_size() > 1: + if self.parallel_manager.cp_size > 1: ( hidden_states, timestep, @@ -1201,7 +1206,7 @@ class LatteT2V(ModelMixin, ConfigMixin): attention_mask, encoder_attention_mask, ) = batch_func( - partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0), + partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), hidden_states, timestep, encoder_hidden_states, @@ -1292,14 +1297,14 @@ class LatteT2V(ModelMixin, ConfigMixin): timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous() timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous() - if enable_sequence_parallel(): - set_temporal_pad(frame + use_image_num) - set_spatial_pad(num_patches) + if self.parallel_manager.sp_size > 1: + set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group) + set_pad("spatial", num_patches, self.parallel_manager.sp_group) hidden_states = self.split_from_second_dim(hidden_states, input_batch_size) encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size) timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size) temp_pos_embed = split_sequence( - self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad() + self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") ) else: temp_pos_embed = self.temp_pos_embed @@ -1420,7 +1425,7 @@ class LatteT2V(ModelMixin, ConfigMixin): hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size ).contiguous() - if enable_sequence_parallel(): + if self.parallel_manager.sp_size > 1: hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size) if self.is_input_patches: @@ -1452,8 +1457,8 @@ class LatteT2V(ModelMixin, ConfigMixin): output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous() # 3. Gather batch for data parallelism - if get_cfg_parallel_size() > 1: - output = gather_sequence(output, get_cfg_parallel_group(), dim=0) + if self.parallel_manager.cp_size > 1: + output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) if not return_dict: return (output,) @@ -1466,12 +1471,12 @@ class LatteT2V(ModelMixin, ConfigMixin): def split_from_second_dim(self, x, batch_size): x = x.view(batch_size, -1, *x.shape[1:]) - x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()) + x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")) x = x.reshape(-1, *x.shape[2:]) return x def gather_from_second_dim(self, x, batch_size): x = x.view(batch_size, -1, *x.shape[1:]) - x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad()) + x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")) x = x.reshape(-1, *x.shape[2:]) return x diff --git a/videosys/models/transformers/open_sora_plan_v110_transformer_3d.py b/videosys/models/transformers/open_sora_plan_v110_transformer_3d.py new file mode 100644 index 0000000..9839677 --- /dev/null +++ b/videosys/models/transformers/open_sora_plan_v110_transformer_3d.py @@ -0,0 +1,2826 @@ +# Adapted from Open-Sora-Plan + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan +# -------------------------------------------------------- + + +import json +import os +from dataclasses import dataclass +from functools import partial +from importlib import import_module +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + SlicedAttnAddedKVProcessor, + SlicedAttnProcessor, + SpatialNorm, + XFormersAttnAddedKVProcessor, + XFormersAttnProcessor, +) +from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange, repeat +from torch import nn + +from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence +from videosys.core.pab_mgr import ( + enable_pab, + get_mlp_output, + if_broadcast_cross, + if_broadcast_mlp, + if_broadcast_spatial, + if_broadcast_temporal, + save_mlp_output, +) +from videosys.core.parallel_mgr import ParallelManager +from videosys.utils.logging import logger +from videosys.utils.utils import batch_func + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +SPATIAL_LIST = [] +TEMPROAL_LIST = [] +CROSS_LIST = [] + + +def get_2d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed(embed_dim, length, interpolation_scale=1.0, base_size=16): + pos = torch.arange(0, length).unsqueeze(1) / interpolation_scale + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class RoPE2D(torch.nn.Module): + def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.scaling_factor = scaling_factor + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) + x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens + + +class LinearScalingRoPE2D(RoPE2D): + """Code from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148""" + + def forward(self, tokens, positions): + # difference to the original RoPE: a scaling factor is aplied to the position ids + dtype = positions.dtype + positions = positions.float() / self.scaling_factor + positions = positions.to(dtype) + tokens = super().forward(tokens, positions) + return tokens + + +class RoPE1D(torch.nn.Module): + def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.scaling_factor = scaling_factor + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens (t position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + D = tokens.size(3) + assert positions.ndim == 2 # Batch, Seq + cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype) + tokens = self.apply_rope1d(tokens, positions, cos, sin) + return tokens + + +class LinearScalingRoPE1D(RoPE1D): + """Code from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148""" + + def forward(self, tokens, positions): + # difference to the original RoPE: a scaling factor is aplied to the position ids + dtype = positions.dtype + positions = positions.float() / self.scaling_factor + positions = positions.to(dtype) + tokens = super().forward(tokens, positions) + return tokens + + +class PositionGetter2D(object): + """return positions of patches""" + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if not (h, w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone() + return pos + + +class PositionGetter1D(object): + """return positions of patches""" + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, l, device): + if not (l) in self.cache_positions: + x = torch.arange(l, device=device) + self.cache_positions[l] = x # (l, ) + pos = self.cache_positions[l].view(1, l).expand(b, -1).clone() + return pos + + +class CombinedTimestepSizeEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.use_additional_conditions = True + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module): + if size.ndim == 1: + size = size[:, None] + + if size.shape[0] != batch_size: + size = size.repeat(batch_size // size.shape[0], 1) + if size.shape[0] != batch_size: + raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.") + + current_batch_size, dims = size.shape[0], size.shape[1] + size = size.reshape(-1) + size_freq = self.additional_condition_proj(size).to(size.dtype) + + size_emb = embedder(size_freq) + size_emb = size_emb.reshape(current_batch_size, dims * self.outdim) + return size_emb + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder) + aspect_ratio = self.apply_condition( + aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder + ) + conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + + +class CaptionProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size, num_tokens=120): + super().__init__() + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + self.act_1 = nn.GELU(approximate="tanh") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) + self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5)) + + def forward(self, caption, force_drop_ids=None): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size = patch_size + # See: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 + self.height, self.width = height // patch_size, width // patch_size + + self.base_size = height // patch_size + self.interpolation_scale = interpolation_scale + pos_embed = get_2d_sincos_pos_embed( + embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + # Interpolate positional embeddings if needed. + # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) + if self.height != height or self.width != width: + # raise ValueError + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + return (latent + pos_embed).to(latent.dtype) + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + attention_mode: str = "xformers", + use_rope: bool = False, + rope_scaling: Optional[Dict] = None, + compress_kv_factor: Optional[Tuple] = None, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.use_rope = use_rope + self.rope_scaling = rope_scaling + self.compress_kv_factor = compress_kv_factor + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + if USE_PEFT_BACKEND: + linear_cls = nn.Linear + else: + linear_cls = LoRACompatibleLinear + + assert not ( + self.use_rope and (self.compress_kv_factor is not None) + ), "Can not both enable compressing kv and using rope" + if self.compress_kv_factor is not None: + self._init_compress() + + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0( + self.inner_dim, + attention_mode, + use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=compress_kv_factor, + ) + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + is_lora = hasattr(self, "processor") + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and (is_lora or is_custom_diffusion): + raise NotImplementedError( + f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers + # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_lora: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + processor = attn_processor_class( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + _remove_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to remove LoRA layers from the model. + """ + if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: + deprecate( + "set_processor to offload LoRA", + "0.26.0", + "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.", + ) + # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete + # We need to remove all LoRA layers + # Don't forget to remove ALL `_remove_lora` from the codebase + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False): + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) + lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + def _init_compress(self): + if len(self.compress_kv_factor) == 2: + self.sr = nn.Conv2d( + self.inner_dim, + self.inner_dim, + groups=self.inner_dim, + kernel_size=self.compress_kv_factor, + stride=self.compress_kv_factor, + ) + self.sr.weight.data.fill_(1 / self.compress_kv_factor[0] ** 2) + elif len(self.compress_kv_factor) == 1: + self.kernel_size = self.compress_kv_factor[0] + self.sr = nn.Conv1d( + self.inner_dim, + self.inner_dim, + groups=self.inner_dim, + kernel_size=self.compress_kv_factor[0], + stride=self.compress_kv_factor[0], + ) + self.sr.weight.data.fill_(1 / self.compress_kv_factor[0]) + self.sr.bias.data.zero_() + self.norm = nn.LayerNorm(self.inner_dim) + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, dim=1152, attention_mode="xformers", use_rope=False, rope_scaling=None, compress_kv_factor=None): + self.dim = dim + self.attention_mode = attention_mode + self.use_rope = use_rope + self.rope_scaling = rope_scaling + self.compress_kv_factor = compress_kv_factor + if self.use_rope: + self._init_rope() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def _init_rope(self): + if self.rope_scaling is None: + self.rope2d = RoPE2D() + self.rope1d = RoPE1D() + else: + scaling_type = self.rope_scaling["type"] + scaling_factor_2d = self.rope_scaling["factor_2d"] + scaling_factor_1d = self.rope_scaling["factor_1d"] + if scaling_type == "linear": + self.rope2d = LinearScalingRoPE2D(scaling_factor=scaling_factor_2d) + self.rope1d = LinearScalingRoPE1D(scaling_factor=scaling_factor_1d) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + position_q: Optional[torch.LongTensor] = None, + position_k: Optional[torch.LongTensor] = None, + last_shape: Tuple[int] = None, + ) -> torch.FloatTensor: + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + if self.compress_kv_factor is not None: + batch_size = hidden_states.shape[0] + if len(last_shape) == 2: + encoder_hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, self.dim, *last_shape) + encoder_hidden_states = ( + attn.sr(encoder_hidden_states).reshape(batch_size, self.dim, -1).permute(0, 2, 1) + ) + elif len(last_shape) == 1: + encoder_hidden_states = hidden_states.permute(0, 2, 1) + if last_shape[0] % 2 == 1: + first_frame_pad = encoder_hidden_states[:, :, :1].repeat((1, 1, attn.kernel_size - 1)) + encoder_hidden_states = torch.concatenate((first_frame_pad, encoder_hidden_states), dim=2) + encoder_hidden_states = attn.sr(encoder_hidden_states).permute(0, 2, 1) + else: + raise NotImplementedError(f"NotImplementedError with last_shape {last_shape}") + + encoder_hidden_states = attn.norm(encoder_hidden_states) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if self.use_rope: + # require the shape of (batch_size x nheads x ntokens x dim) + if position_q.ndim == 3: + query = self.rope2d(query, position_q) + elif position_q.ndim == 2: + query = self.rope1d(query, position_q) + else: + raise NotImplementedError + if position_k.ndim == 3: + key = self.rope2d(key, position_k) + elif position_k.ndim == 2: + key = self.rope1d(key, position_k) + else: + raise NotImplementedError + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock_(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + attention_mode: str = "xformers", + use_rope: bool = False, + rope_scaling: Optional[Dict] = None, + compress_kv_factor: Optional[Tuple] = None, + block_idx: Optional[int] = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=compress_kv_factor, + ) + + # # 2. Cross-Attn + # if cross_attention_dim is not None or double_self_attention: + # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # # the second cross attention block. + # self.norm2 = ( + # AdaLayerNorm(dim, num_embeds_ada_norm) + # if self.use_ada_layer_norm + # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + # ) + # self.attn2 = Attention( + # query_dim=dim, + # cross_attention_dim=cross_attention_dim if not double_self_attention else None, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # upcast_attention=upcast_attention, + # ) # is self-attn if encoder_hidden_states is none + # else: + # self.norm2 = None + # self.attn2 = None + + # 3. Feed-forward + # if not self.use_ada_layer_norm_single: + # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # pab + self.last_out = None + self.count = 0 + self.block_idx = block_idx + self.temp_mlp_count = 0 + + # parallel + self.parallel_manager: ParallelManager = None + + def set_last_out(self, last_out: torch.Tensor): + self.last_out = last_out + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + position_q: Optional[torch.LongTensor] = None, + position_k: Optional[torch.LongTensor] = None, + frame: int = None, + org_timestep: Optional[torch.LongTensor] = None, + all_timesteps=None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + broadcast_temporal, self.count = if_broadcast_temporal(int(org_timestep[0]), self.count) + if broadcast_temporal: + attn_output = self.last_out + assert self.use_ada_layer_norm_single + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + else: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + if self.parallel_manager.sp_size > 1: + norm_hidden_states = self.dynamic_switch(norm_hidden_states, to_spatial_shard=True) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + position_q=position_q, + position_k=position_k, + last_shape=frame, + **cross_attention_kwargs, + ) + + if self.parallel_manager.sp_size > 1: + attn_output = self.dynamic_switch(attn_output, to_spatial_shard=False) + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + if enable_pab(): + self.set_last_out(attn_output) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # # 3. Cross-Attention + # if self.attn2 is not None: + # if self.use_ada_layer_norm: + # norm_hidden_states = self.norm2(hidden_states, timestep) + # elif self.use_ada_layer_norm_zero or self.use_layer_norm: + # norm_hidden_states = self.norm2(hidden_states) + # elif self.use_ada_layer_norm_single: + # # For PixArt norm2 isn't applied here: + # # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + # norm_hidden_states = hidden_states + # else: + # raise ValueError("Incorrect norm") + + # if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + # norm_hidden_states = self.pos_embed(norm_hidden_states) + + # attn_output = self.attn2( + # norm_hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # attention_mask=encoder_attention_mask, + # **cross_attention_kwargs, + # ) + # hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # if not self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm3(hidden_states) + + if enable_pab(): + broadcast_mlp, self.temp_mlp_count, broadcast_next, broadcast_range = if_broadcast_mlp( + int(org_timestep[0]), + self.temp_mlp_count, + self.block_idx, + all_timesteps.tolist(), + is_temporal=True, + ) + + if enable_pab() and broadcast_mlp: + ff_output = get_mlp_output( + broadcast_range, + timestep=int(org_timestep[0]), + block_idx=self.block_idx, + is_temporal=True, + ) + else: + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = self.norm3(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + if enable_pab() and broadcast_next: + save_mlp_output( + timestep=int(org_timestep[0]), + block_idx=self.block_idx, + ff_output=ff_output, + is_temporal=True, + ) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + def dynamic_switch(self, x, to_spatial_shard: bool): + if to_spatial_shard: + scatter_dim, gather_dim = 0, 1 + scatter_pad = get_pad("spatial") + gather_pad = get_pad("temporal") + else: + scatter_dim, gather_dim = 1, 0 + scatter_pad = get_pad("temporal") + gather_pad = get_pad("spatial") + x = all_to_all_with_pad( + x, + self.parallel_manager.sp_group, + scatter_dim=scatter_dim, + gather_dim=gather_dim, + scatter_pad=scatter_pad, + gather_pad=gather_pad, + ) + return x + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + attention_mode: str = "xformers", + use_rope: bool = False, + rope_scaling: Optional[Dict] = None, + compress_kv_factor: Optional[Tuple] = None, + block_idx: Optional[int] = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=compress_kv_factor, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + attention_mode=attention_mode, # only xformers support attention_mask + use_rope=False, # do not position in cross attention + compress_kv_factor=None, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # pab + self.cross_last = None + self.cross_count = 0 + self.spatial_last = None + self.spatial_count = 0 + self.block_idx = block_idx + self.spatila_mlp_count = 0 + + def set_cross_last(self, last_out: torch.Tensor): + self.cross_last = last_out + + def set_spatial_last(self, last_out: torch.Tensor): + self.spatial_last = last_out + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + position_q: Optional[torch.LongTensor] = None, + position_k: Optional[torch.LongTensor] = None, + hw: Tuple[int, int] = None, + org_timestep: Optional[torch.LongTensor] = None, + all_timesteps=None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + broadcast_spatial, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count) + if broadcast_spatial: + attn_output = self.spatial_last + assert self.use_ada_layer_norm_single + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + else: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + position_q=position_q, + position_k=position_k, + last_shape=hw, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + if enable_pab(): + self.set_spatial_last(attn_output) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + broadcast_cross, self.cross_count = if_broadcast_cross(int(org_timestep[0]), self.cross_count) + if broadcast_cross: + hidden_states = hidden_states + self.cross_last + else: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_q=None, # cross attn do not need relative position + position_k=None, + last_shape=None, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + if enable_pab(): + self.set_cross_last(attn_output) + + if enable_pab(): + broadcast_mlp, self.spatila_mlp_count, broadcast_next, broadcast_range = if_broadcast_mlp( + int(org_timestep[0]), + self.spatila_mlp_count, + self.block_idx, + all_timesteps.tolist(), + is_temporal=False, + ) + + if enable_pab() and broadcast_mlp: + ff_output = get_mlp_output( + broadcast_range, + timestep=int(org_timestep[0]), + block_idx=self.block_idx, + is_temporal=False, + ) + else: + # 4. Feed-forward + if not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + if enable_pab() and broadcast_next: + save_mlp_output( + timestep=int(org_timestep[0]), + block_idx=self.block_idx, + ff_output=ff_output, + is_temporal=False, + ) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = CombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + batch_size: int = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb( + timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None + ) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class LatteT2V(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + patch_size_t: int = 1, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + video_length: int = 16, + attention_mode: str = "flash", + use_rope: bool = False, + model_max_length: int = 300, + rope_scaling_type: str = "linear", + compress_kv_factor: int = 1, + interpolation_scale_1d: float = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.video_length = video_length + self.use_rope = use_rope + self.model_max_length = model_max_length + self.compress_kv_factor = compress_kv_factor + self.num_layers = num_layers + self.config.hidden_size = model_max_length + + assert not (self.compress_kv_factor != 1 and use_rope), "Can not both enable compressing kv and using rope" + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + # self.is_input_patches = in_channels is not None and patch_size is not None + self.is_input_patches = True + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + # 2. Define input layers + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size[0] + self.width = sample_size[1] + + self.patch_size = patch_size + interpolation_scale_2d = self.config.sample_size[0] // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale_2d = max(interpolation_scale_2d, 1) + self.pos_embed = PatchEmbed( + height=sample_size[0], + width=sample_size[1], + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale_2d, + ) + + # define temporal positional embedding + if interpolation_scale_1d is None: + if self.config.video_length % 2 == 1: + interpolation_scale_1d = ( + self.config.video_length - 1 + ) // 16 # => 16 (= 16 Latte) has interpolation scale 1 + else: + interpolation_scale_1d = self.config.video_length // 16 # => 16 (= 16 Latte) has interpolation scale 1 + # interpolation_scale_1d = self.config.video_length // 5 # + interpolation_scale_1d = max(interpolation_scale_1d, 1) + temp_pos_embed = get_1d_sincos_pos_embed( + inner_dim, video_length, interpolation_scale=interpolation_scale_1d + ) # 1152 hidden size + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + rope_scaling = None + if self.use_rope: + self.position_getter_2d = PositionGetter2D() + self.position_getter_1d = PositionGetter1D() + rope_scaling = dict( + type=rope_scaling_type, factor_2d=interpolation_scale_2d, factor_1d=interpolation_scale_1d + ) + + # 3. Define transformers blocks, spatial attention + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=(compress_kv_factor, compress_kv_factor) + if d >= num_layers // 2 and compress_kv_factor != 1 + else None, # follow pixart-sigma, apply in second-half layers + block_idx=d, + ) + for d in range(num_layers) + ] + ) + + # Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock_( # one attention + inner_dim, + num_attention_heads, # num_attention_heads + attention_head_dim, # attention_head_dim 72 + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=False, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=(compress_kv_factor,) + if d >= num_layers // 2 and compress_kv_factor != 1 + else None, # follow pixart-sigma, apply in second-half layers + block_idx=d, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + # parallel + self.parallel_manager: ParallelManager = None + + def enable_parallel(self, dp_size, sp_size, enable_cp): + # update cfg parallel + if enable_cp and sp_size % 2 == 0: + sp_size = sp_size // 2 + cp_size = 2 + else: + cp_size = 1 + + self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size) + + for _, module in self.named_modules(): + if hasattr(module, "parallel_manager"): + module.parallel_manager = self.parallel_manager + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def make_position(self, b, t, use_image_num, h, w, device): + pos_hw = self.position_getter_2d(b * (t + use_image_num), h, w, device) # fake_b = b*(t+use_image_num) + pos_t = self.position_getter_1d(b * h * w, t, device) # fake_b = b*h*w + return pos_hw, pos_t + + def make_attn_mask(self, attention_mask, frame, dtype): + attention_mask = rearrange(attention_mask, "b t h w -> (b t) 1 (h w)") + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(dtype)) * -10000.0 + attention_mask = attention_mask.to(self.dtype) + return attention_mask + + def vae_to_diff_mask(self, attention_mask, use_image_num): + dtype = attention_mask.dtype + # b, t+use_image_num, h, w, assume t as channel + # this version do not use 3d patch embedding + attention_mask = F.max_pool2d( + attention_mask, kernel_size=(self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size) + ) + attention_mask = attention_mask.bool().to(dtype) + return attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + all_timesteps=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 0. Split batch + if self.parallel_manager.cp_size > 1: + ( + hidden_states, + timestep, + encoder_hidden_states, + class_labels, + attention_mask, + encoder_attention_mask, + ) = batch_func( + partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), + hidden_states, + timestep, + encoder_hidden_states, + class_labels, + attention_mask, + encoder_attention_mask, + ) + input_batch_size, c, frame, h, w = hidden_states.shape + frame = frame - use_image_num # 20-4=16 + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + org_timestep = timestep + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is None: + attention_mask = torch.ones( + (input_batch_size, frame + use_image_num, h, w), device=hidden_states.device, dtype=hidden_states.dtype + ) + attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num) + dtype = attention_mask.dtype + attention_mask_compress = F.max_pool2d( + attention_mask.float(), kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor + ) + attention_mask_compress = attention_mask_compress.to(dtype) + + attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype) + attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype) + + # 1 + 4, 1 -> video condition, 4 -> image condition + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous() + encoder_attention_mask = encoder_attention_mask.to(self.dtype) + elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask_video = encoder_attention_mask[:, :1, ...] + encoder_attention_mask_video = repeat( + encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame + ).contiguous() + encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...] + encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1) + encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1) + encoder_attention_mask = encoder_attention_mask.to(self.dtype) + + # Retrieve lora scale. + cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_patches: # here + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hw = (height, width) + num_patches = height * width + + hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # alrady add positional embeddings + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + # batch_size = hidden_states.shape[0] + batch_size = input_batch_size + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype)) # 3 120 1152 + + if use_image_num != 0 and self.training: + encoder_hidden_states_video = encoder_hidden_states[:, :1, ...] + encoder_hidden_states_video = repeat( + encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame + ).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...] + encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous() + else: + encoder_hidden_states_spatial = repeat( + encoder_hidden_states, "b 1 t d -> (b f) t d", f=frame + ).contiguous() + + # prepare timesteps for spatial and temporal block + timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous() + timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous() + + pos_hw, pos_t = None, None + if self.use_rope: + pos_hw, pos_t = self.make_position( + input_batch_size, frame, use_image_num, height, width, hidden_states.device + ) + + if self.parallel_manager.sp_size > 1: + set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group) + set_pad("spatial", num_patches, self.parallel_manager.sp_group) + hidden_states = self.split_from_second_dim(hidden_states, input_batch_size) + encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size) + timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size) + attention_mask = self.split_from_second_dim(attention_mask, input_batch_size) + attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size) + temp_pos_embed = split_sequence( + self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") + ) + else: + temp_pos_embed = self.temp_pos_embed + + for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)): + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + spatial_block, + hidden_states, + attention_mask_compress if i >= self.num_layers // 2 else attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + pos_hw, + pos_hw, + hw, + use_reentrant=False, + ) + + if enable_temporal_attentions: + hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous() + + if use_image_num != 0: # image-video joitn training + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + # if i == 0 and not self.use_rope: + if i == 0: + hidden_states_video = hidden_states_video + temp_pos_embed + + hidden_states_video = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + pos_t, + pos_t, + (frame,), + use_reentrant=False, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange( + hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size + ).contiguous() + + else: + # if i == 0 and not self.use_rope: + if i == 0: + hidden_states = hidden_states + temp_pos_embed + + hidden_states = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + pos_t, + pos_t, + (frame,), + use_reentrant=False, + ) + + hidden_states = rearrange( + hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size + ).contiguous() + else: + hidden_states = spatial_block( + hidden_states, + attention_mask_compress if i >= self.num_layers // 2 else attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + pos_hw, + pos_hw, + hw, + org_timestep, + all_timesteps=all_timesteps, + ) + + if enable_temporal_attentions: + # b c f h w, f = 16 + 4 + hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous() + + if use_image_num != 0 and self.training: + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + # if i == 0 and not self.use_rope: + # hidden_states_video = hidden_states_video + temp_pos_embed + + hidden_states_video = temp_block( + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + pos_t, + pos_t, + (frame,), + org_timestep, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange( + hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size + ).contiguous() + + else: + # if i == 0 and not self.use_rope: + if i == 0: + hidden_states = hidden_states + temp_pos_embed + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + pos_t, + pos_t, + (frame,), + org_timestep, + all_timesteps=all_timesteps, + ) + + hidden_states = rearrange( + hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size + ).contiguous() + + if self.parallel_manager.sp_size > 1: + hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size) + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous() + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous() + + # 3. Gather batch for data parallelism + if self.parallel_manager.cp_size > 1: + output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) + + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, "config.json") + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + model = cls.from_config(config, **kwargs) + return model + + def split_from_second_dim(self, x, batch_size): + x = x.view(batch_size, -1, *x.shape[1:]) + x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")) + x = x.reshape(-1, *x.shape[2:]) + return x + + def gather_from_second_dim(self, x, batch_size): + x = x.view(batch_size, -1, *x.shape[1:]) + x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")) + x = x.reshape(-1, *x.shape[2:]) + return x + + +# depth = num_layers * 2 +def LatteT2V_XL_122(**kwargs): + return LatteT2V( + num_layers=28, + attention_head_dim=72, + num_attention_heads=16, + patch_size_t=1, + patch_size=2, + norm_type="ada_norm_single", + caption_channels=4096, + cross_attention_dim=1152, + **kwargs, + ) + + +def LatteT2V_D64_XL_122(**kwargs): + return LatteT2V( + num_layers=28, + attention_head_dim=64, + num_attention_heads=18, + patch_size_t=1, + patch_size=2, + norm_type="ada_norm_single", + caption_channels=4096, + cross_attention_dim=1152, + **kwargs, + ) + + +Latte_models = { + "LatteT2V-XL/122": LatteT2V_XL_122, + "LatteT2V-D64-XL/122": LatteT2V_D64_XL_122, +} diff --git a/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py b/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py new file mode 100644 index 0000000..761d0d9 --- /dev/null +++ b/videosys/models/transformers/open_sora_plan_v120_transformer_3d.py @@ -0,0 +1,2183 @@ +# Adapted from Open-Sora-Plan + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan +# -------------------------------------------------------- + +import collections +import re +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import FeedForward, GatedSelfAttentionDense +from diffusers.models.attention_processor import Attention as Attention_ +from diffusers.models.embeddings import PixArtAlphaTextProjection, SinusoidalPositionalEmbedding +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero +from diffusers.utils import deprecate, is_torch_version +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange, repeat +from torch import nn +from torch.nn import functional as F + +from videosys.core.comm import all_to_all_comm, gather_sequence, split_sequence +from videosys.core.pab_mgr import enable_pab, if_broadcast_cross, if_broadcast_spatial +from videosys.core.parallel_mgr import ParallelManager +from videosys.core.pipeline import VideoSysPipelineOutput + +torch_npu = None +npu_config = None +set_run_dtype = None + + +class PositionGetter3D(object): + """return positions of patches""" + + def __init__( + self, + ): + self.cache_positions = {} + + def __call__(self, b, t, h, w, device): + if not (b, t, h, w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + z = torch.arange(t, device=device) + pos = torch.cartesian_prod(z, y, x) + pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone() + poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous()) + max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max())) + + self.cache_positions[b, t, h, w] = (poses, max_poses) + pos = self.cache_positions[b, t, h, w] + + return pos + + +class RoPE3D(torch.nn.Module): + def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)): + super().__init__() + self.base = freq + self.F0 = F0 + self.interpolation_scale_t = interpolation_scale_thw[0] + self.interpolation_scale_h = interpolation_scale_thw[1] + self.interpolation_scale_w = interpolation_scale_thw[2] + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + # for (batch_size x ntokens x nheads x dim) + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 3 (t, y and x position of each token) + output: + * tokens after appplying RoPE3D (batch_size x nheads x ntokens x x dim) + """ + assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three" + D = tokens.size(3) // 3 + poses, max_poses = positions + assert len(poses) == 3 and poses[0].ndim == 2 # Batch, Seq, 3 + cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t) + cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h) + cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) + # split features into three along the feature dimension, and apply rope1d on each half + t, y, x = tokens.chunk(3, dim=-1) + t = self.apply_rope1d(t, poses[0], cos_t, sin_t) + y = self.apply_rope1d(y, poses[1], cos_y, sin_y) + x = self.apply_rope1d(x, poses[2], cos_x, sin_x) + tokens = torch.cat((t, y, x), dim=-1) + return tokens + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_size, + cls_token=False, + extra_tokens=0, + interpolation_scale=1.0, + base_size=16, +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + # if isinstance(grid_size, int): + # grid_size = (grid_size, grid_size) + grid_t = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale[0] + grid_h = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale[1] + grid_w = np.arange(grid_size[2], dtype=np.float32) / (grid_size[2] / base_size[2]) / interpolation_scale[2] + grid = np.meshgrid(grid_w, grid_h, grid_t) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([3, 1, grid_size[2], grid_size[1], grid_size[0]]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + # import ipdb;ipdb.set_trace() + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 3 != 0: + raise ValueError("embed_dim must be divisible by 3") + + # import ipdb;ipdb.set_trace() + # use 1/3 of dimensions to encode grid_t/h/w + emb_t = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (T*H*W, D/3) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (T*H*W, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (T*H*W, D/3) + + emb = np.concatenate([emb_t, emb_h, emb_w], axis=1) # (T*H*W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + cls_token=False, + extra_tokens=0, + interpolation_scale=1.0, + base_size=16, +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + # if isinstance(grid_size, int): + # grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale[0] + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale[1] + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use 1/3 of dimensions to encode grid_t/h/w + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed( + embed_dim, + grid_size, + cls_token=False, + extra_tokens=0, + interpolation_scale=1.0, + base_size=16, +): + """ + grid_size: int of the grid return: pos_embed: [grid_size, embed_dim] or + [1+grid_size, embed_dim] (w/ or w/o cls_token) + """ + # if isinstance(grid_size, int): + # grid_size = (grid_size, grid_size) + + grid = np.arange(grid_size, dtype=np.float32) / (grid_size / base_size) / interpolation_scale + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) # (H*W, D/2) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed2D(nn.Module): + """2D Image to Patch Embedding but with 3D position embedding""" + + def __init__( + self, + num_frames=1, + height=224, + width=224, + patch_size_t=1, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=(1, 1), + interpolation_scale_t=1, + use_abs_pos=True, + ): + super().__init__() + # assert num_frames == 1 + self.use_abs_pos = use_abs_pos + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size_t = patch_size_t + self.patch_size = patch_size + # See: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 + + self.height, self.width = height // patch_size, width // patch_size + self.base_size = (height // patch_size, width // patch_size) + self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1]) + pos_embed = get_2d_sincos_pos_embed( + embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.interpolation_scale_t = interpolation_scale_t + temp_pos_embed = get_1d_sincos_pos_embed( + embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t + ) + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + # self.temp_embed_gate = nn.Parameter(torch.tensor([0.0])) + + def forward(self, latent, num_frames): + b, _, _, _, _ = latent.shape + video_latent, image_latent = None, None + # b c 1 h w + # assert latent.shape[-3] == 1 and num_frames == 1 + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + latent = rearrange(latent, "b c t h w -> (b t) c h w") + latent = self.proj(latent) + + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C + if self.layer_norm: + latent = self.norm(latent) + + if self.use_abs_pos: + # Interpolate positional embeddings if needed. + # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) + if self.height != height or self.width != width: + # raise NotImplementedError + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + if self.num_frames != num_frames: + raise NotImplementedError + else: + temp_pos_embed = self.temp_pos_embed + + latent = (latent + pos_embed).to(latent.dtype) + + latent = rearrange(latent, "(b t) n c -> b t n c", b=b) + video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] + + if self.use_abs_pos: + # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() + temp_pos_embed = temp_pos_embed.unsqueeze(2) + video_latent = ( + (video_latent + temp_pos_embed).to(video_latent.dtype) + if video_latent is not None and video_latent.numel() > 0 + else None + ) + image_latent = ( + (image_latent + temp_pos_embed[:, :1]).to(image_latent.dtype) + if image_latent is not None and image_latent.numel() > 0 + else None + ) + + video_latent = ( + rearrange(video_latent, "b t n c -> b (t n) c") + if video_latent is not None and video_latent.numel() > 0 + else None + ) + image_latent = ( + rearrange(image_latent, "b t n c -> (b t) n c") + if image_latent is not None and image_latent.numel() > 0 + else None + ) + + if num_frames == 1 and image_latent is None: + image_latent = video_latent + video_latent = None + # print('video_latent is None, image_latent is None', video_latent is None, image_latent is None) + return video_latent, image_latent + + +class OverlapPatchEmbed3D(nn.Module): + """2D Image to Patch Embedding but with 3D position embedding""" + + def __init__( + self, + num_frames=1, + height=224, + width=224, + patch_size_t=1, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=(1, 1), + interpolation_scale_t=1, + use_abs_pos=True, + ): + super().__init__() + # assert patch_size_t == 1 and patch_size == 1 + self.use_abs_pos = use_abs_pos + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=(patch_size_t, patch_size, patch_size), + stride=(patch_size_t, patch_size, patch_size), + bias=bias, + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size_t = patch_size_t + self.patch_size = patch_size + # See: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 + + self.height, self.width = height // patch_size, width // patch_size + self.base_size = (height // patch_size, width // patch_size) + self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1]) + pos_embed = get_2d_sincos_pos_embed( + embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.interpolation_scale_t = interpolation_scale_t + temp_pos_embed = get_1d_sincos_pos_embed( + embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t + ) + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + # self.temp_embed_gate = nn.Parameter(torch.tensor([0.0])) + + def forward(self, latent, num_frames): + b, _, _, _, _ = latent.shape + video_latent, image_latent = None, None + # b c 1 h w + # assert latent.shape[-3] == 1 and num_frames == 1 + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + # latent = rearrange(latent, 'b c t h w -> (b t) c h w') + latent = self.proj(latent) + + if self.flatten: + # latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C + latent = rearrange(latent, "b c t h w -> (b t) (h w) c ") + if self.layer_norm: + latent = self.norm(latent) + + if self.use_abs_pos: + # Interpolate positional embeddings if needed. + # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) + if self.height != height or self.width != width: + # raise NotImplementedError + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + if self.num_frames != num_frames: + # import ipdb;ipdb.set_trace() + # raise NotImplementedError + temp_pos_embed = get_1d_sincos_pos_embed( + embed_dim=self.temp_pos_embed.shape[-1], + grid_size=num_frames, + base_size=self.base_size_t, + interpolation_scale=self.interpolation_scale_t, + ) + temp_pos_embed = torch.from_numpy(temp_pos_embed) + temp_pos_embed = temp_pos_embed.float().unsqueeze(0).to(latent.device) + else: + temp_pos_embed = self.temp_pos_embed + + latent = (latent + pos_embed).to(latent.dtype) + + latent = rearrange(latent, "(b t) n c -> b t n c", b=b) + video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] + + if self.use_abs_pos: + # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() + temp_pos_embed = temp_pos_embed.unsqueeze(2) + video_latent = ( + (video_latent + temp_pos_embed).to(video_latent.dtype) + if video_latent is not None and video_latent.numel() > 0 + else None + ) + image_latent = ( + (image_latent + temp_pos_embed[:, :1]).to(image_latent.dtype) + if image_latent is not None and image_latent.numel() > 0 + else None + ) + + video_latent = ( + rearrange(video_latent, "b t n c -> b (t n) c") + if video_latent is not None and video_latent.numel() > 0 + else None + ) + image_latent = ( + rearrange(image_latent, "b t n c -> (b t) n c") + if image_latent is not None and image_latent.numel() > 0 + else None + ) + + if num_frames == 1 and image_latent is None: + image_latent = video_latent + video_latent = None + return video_latent, image_latent + + +class OverlapPatchEmbed2D(nn.Module): + """2D Image to Patch Embedding but with 3D position embedding""" + + def __init__( + self, + num_frames=1, + height=224, + width=224, + patch_size_t=1, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=(1, 1), + interpolation_scale_t=1, + use_abs_pos=True, + ): + super().__init__() + assert patch_size_t == 1 + self.use_abs_pos = use_abs_pos + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size_t = patch_size_t + self.patch_size = patch_size + # See: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 + + self.height, self.width = height // patch_size, width // patch_size + self.base_size = (height // patch_size, width // patch_size) + self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1]) + pos_embed = get_2d_sincos_pos_embed( + embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.interpolation_scale_t = interpolation_scale_t + temp_pos_embed = get_1d_sincos_pos_embed( + embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t + ) + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + # self.temp_embed_gate = nn.Parameter(torch.tensor([0.0])) + + def forward(self, latent, num_frames): + b, _, _, _, _ = latent.shape + video_latent, image_latent = None, None + # b c 1 h w + # assert latent.shape[-3] == 1 and num_frames == 1 + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + latent = rearrange(latent, "b c t h w -> (b t) c h w") + latent = self.proj(latent) + + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C + if self.layer_norm: + latent = self.norm(latent) + + if self.use_abs_pos: + # Interpolate positional embeddings if needed. + # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) + if self.height != height or self.width != width: + # raise NotImplementedError + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + if self.num_frames != num_frames: + # import ipdb;ipdb.set_trace() + # raise NotImplementedError + temp_pos_embed = get_1d_sincos_pos_embed( + embed_dim=self.temp_pos_embed.shape[-1], + grid_size=num_frames, + base_size=self.base_size_t, + interpolation_scale=self.interpolation_scale_t, + ) + temp_pos_embed = torch.from_numpy(temp_pos_embed) + temp_pos_embed = temp_pos_embed.float().unsqueeze(0).to(latent.device) + else: + temp_pos_embed = self.temp_pos_embed + + latent = (latent + pos_embed).to(latent.dtype) + + latent = rearrange(latent, "(b t) n c -> b t n c", b=b) + video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] + + if self.use_abs_pos: + # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() + temp_pos_embed = temp_pos_embed.unsqueeze(2) + video_latent = ( + (video_latent + temp_pos_embed).to(video_latent.dtype) + if video_latent is not None and video_latent.numel() > 0 + else None + ) + image_latent = ( + (image_latent + temp_pos_embed[:, :1]).to(image_latent.dtype) + if image_latent is not None and image_latent.numel() > 0 + else None + ) + + video_latent = ( + rearrange(video_latent, "b t n c -> b (t n) c") + if video_latent is not None and video_latent.numel() > 0 + else None + ) + image_latent = ( + rearrange(image_latent, "b t n c -> (b t) n c") + if image_latent is not None and image_latent.numel() > 0 + else None + ) + + if num_frames == 1 and image_latent is None: + image_latent = video_latent + video_latent = None + return video_latent, image_latent + + +class Attention(Attention_): + def __init__(self, downsampler, attention_mode, use_rope, interpolation_scale_thw, **kwags): + processor = AttnProcessor2_0( + attention_mode=attention_mode, use_rope=use_rope, interpolation_scale_thw=interpolation_scale_thw + ) + super().__init__(processor=processor, **kwags) + self.downsampler = None + if downsampler: # downsampler k155_s122 + downsampler_ker_size = list(re.search(r"k(\d{2,3})", downsampler).group(1)) # 122 + down_factor = list(re.search(r"s(\d{2,3})", downsampler).group(1)) + downsampler_ker_size = [int(i) for i in downsampler_ker_size] + downsampler_padding = [(i - 1) // 2 for i in downsampler_ker_size] + down_factor = [int(i) for i in down_factor] + + if len(downsampler_ker_size) == 2: + self.downsampler = DownSampler2d( + kwags["query_dim"], + kwags["query_dim"], + kernel_size=downsampler_ker_size, + stride=1, + padding=downsampler_padding, + groups=kwags["query_dim"], + down_factor=down_factor, + down_shortcut=True, + ) + elif len(downsampler_ker_size) == 3: + self.downsampler = DownSampler3d( + kwags["query_dim"], + kwags["query_dim"], + kernel_size=downsampler_ker_size, + stride=1, + padding=downsampler_padding, + groups=kwags["query_dim"], + down_factor=down_factor, + down_shortcut=True, + ) + + # parallel + self.parallel_manager: ParallelManager = None + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + +class DownSampler3d(nn.Module): + def __init__(self, *args, **kwargs): + """Required kwargs: down_factor, downsampler""" + super().__init__() + self.down_factor = kwargs.pop("down_factor") + self.down_shortcut = kwargs.pop("down_shortcut") + self.layer = nn.Conv3d(*args, **kwargs) + + def forward(self, x, attention_mask, t, h, w): + x.shape[0] + x = rearrange(x, "b (t h w) d -> b d t h w", t=t, h=h, w=w) + if npu_config is None: + x = self.layer(x) + (x if self.down_shortcut else 0) + else: + x_dtype = x.dtype + x = npu_config.run_conv3d(self.layer, x, x_dtype) + (x if self.down_shortcut else 0) + + self.t = t // self.down_factor[0] + self.h = h // self.down_factor[1] + self.w = w // self.down_factor[2] + x = rearrange( + x, + "b d (t dt) (h dh) (w dw) -> (b dt dh dw) (t h w) d", + t=t // self.down_factor[0], + h=h // self.down_factor[1], + w=w // self.down_factor[2], + dt=self.down_factor[0], + dh=self.down_factor[1], + dw=self.down_factor[2], + ) + + attention_mask = rearrange(attention_mask, "b 1 (t h w) -> b 1 t h w", t=t, h=h, w=w) + attention_mask = rearrange( + attention_mask, + "b 1 (t dt) (h dh) (w dw) -> (b dt dh dw) 1 (t h w)", + t=t // self.down_factor[0], + h=h // self.down_factor[1], + w=w // self.down_factor[2], + dt=self.down_factor[0], + dh=self.down_factor[1], + dw=self.down_factor[2], + ) + return x, attention_mask + + def reverse(self, x, t, h, w): + x = rearrange( + x, + "(b dt dh dw) (t h w) d -> b (t dt h dh w dw) d", + t=t, + h=h, + w=w, + dt=self.down_factor[0], + dh=self.down_factor[1], + dw=self.down_factor[2], + ) + return x + + +class DownSampler2d(nn.Module): + def __init__(self, *args, **kwargs): + """Required kwargs: down_factor, downsampler""" + super().__init__() + self.down_factor = kwargs.pop("down_factor") + self.down_shortcut = kwargs.pop("down_shortcut") + self.layer = nn.Conv2d(*args, **kwargs) + + def forward(self, x, attention_mask, t, h, w): + x.shape[0] + x = rearrange(x, "b (t h w) d -> (b t) d h w", t=t, h=h, w=w) + x = self.layer(x) + (x if self.down_shortcut else 0) + + self.t = 1 + self.h = h // self.down_factor[0] + self.w = w // self.down_factor[1] + + x = rearrange( + x, + "b d (h dh) (w dw) -> (b dh dw) (h w) d", + h=h // self.down_factor[0], + w=w // self.down_factor[1], + dh=self.down_factor[0], + dw=self.down_factor[1], + ) + + attention_mask = rearrange(attention_mask, "b 1 (t h w) -> (b t) 1 h w", h=h, w=w) + attention_mask = rearrange( + attention_mask, + "b 1 (h dh) (w dw) -> (b dh dw) 1 (h w)", + h=h // self.down_factor[0], + w=w // self.down_factor[1], + dh=self.down_factor[0], + dw=self.down_factor[1], + ) + return x, attention_mask + + def reverse(self, x, t, h, w): + x = rearrange( + x, "(b t dh dw) (h w) d -> b (t h dh w dw) d", t=t, h=h, w=w, dh=self.down_factor[0], dw=self.down_factor[1] + ) + return x + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, attention_mode="xformers", use_rope=False, interpolation_scale_thw=(1, 1, 1)): + self.use_rope = use_rope + self.interpolation_scale_thw = interpolation_scale_thw + if self.use_rope: + self._init_rope(interpolation_scale_thw) + self.attention_mode = attention_mode + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def _init_rope(self, interpolation_scale_thw): + self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw) + self.position_getter = PositionGetter3D() + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + frame: int = 8, + height: int = 16, + width: int = 16, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + if attn.downsampler is not None: + hidden_states, attention_mask = attn.downsampler(hidden_states, attention_mask, t=frame, h=height, w=width) + frame, height, width = attn.downsampler.t, attn.downsampler.h, attn.downsampler.w + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.parallel_manager.sp_size > 1 and query.shape[2] == key.shape[2]: + func = lambda x: all_to_all_comm( + x, process_group=attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2 + ) + query, key, value = map(func, [query, key, value]) + + if self.use_rope: + # require the shape of (batch_size x nheads x ntokens x dim) + pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device) + query = self.rope(query, pos_thw) + key = self.rope(key, pos_thw) + + # 0, -10000 ->(bool) False, True ->(any) True ->(not) False + # 0, 0 ->(bool) False, False ->(any) False ->(not) True + if attention_mask is None or not torch.any(attention_mask.bool()): # 0 mean visible + attention_mask = None + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + if attn.parallel_manager.sp_size > 1 and query.shape[2] == key.shape[2]: + hidden_states = all_to_all_comm( + hidden_states, process_group=attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1 + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + if attn.downsampler is not None: + hidden_states = attn.downsampler.reverse(hidden_states, t=frame, h=height, w=width) + return hidden_states + + +class FeedForward_Conv3d(nn.Module): + def __init__(self, downsampler, dim, hidden_features, bias=True): + super(FeedForward_Conv3d, self).__init__() + + self.bias = bias + + self.project_in = nn.Linear(dim, hidden_features, bias=bias) + + self.dwconv = nn.ModuleList( + [ + nn.Conv3d( + hidden_features, + hidden_features, + kernel_size=(5, 5, 5), + stride=1, + padding=(2, 2, 2), + dilation=1, + groups=hidden_features, + bias=bias, + ), + nn.Conv3d( + hidden_features, + hidden_features, + kernel_size=(3, 3, 3), + stride=1, + padding=(1, 1, 1), + dilation=1, + groups=hidden_features, + bias=bias, + ), + nn.Conv3d( + hidden_features, + hidden_features, + kernel_size=(1, 1, 1), + stride=1, + padding=(0, 0, 0), + dilation=1, + groups=hidden_features, + bias=bias, + ), + ] + ) + + self.project_out = nn.Linear(hidden_features, dim, bias=bias) + + def forward(self, x, t, h, w): + # import ipdb;ipdb.set_trace() + if npu_config is None: + x = self.project_in(x) + x = rearrange(x, "b (t h w) d -> b d t h w", t=t, h=h, w=w) + x = F.gelu(x) + out = x + for module in self.dwconv: + out = out + module(x) + out = rearrange(out, "b d t h w -> b (t h w) d", t=t, h=h, w=w) + x = self.project_out(out) + else: + x_dtype = x.dtype + x = npu_config.run_conv3d(self.project_in, x, npu_config.replaced_type) + x = rearrange(x, "b (t h w) d -> b d t h w", t=t, h=h, w=w) + x = F.gelu(x) + out = x + for module in self.dwconv: + out = out + npu_config.run_conv3d(module, x, npu_config.replaced_type) + out = rearrange(out, "b d t h w -> b (t h w) d", t=t, h=h, w=w) + x = npu_config.run_conv3d(self.project_out, out, x_dtype) + return x + + +class FeedForward_Conv2d(nn.Module): + def __init__(self, downsampler, dim, hidden_features, bias=True): + super(FeedForward_Conv2d, self).__init__() + + self.bias = bias + + self.project_in = nn.Linear(dim, hidden_features, bias=bias) + + self.dwconv = nn.ModuleList( + [ + nn.Conv2d( + hidden_features, + hidden_features, + kernel_size=(5, 5), + stride=1, + padding=(2, 2), + dilation=1, + groups=hidden_features, + bias=bias, + ), + nn.Conv2d( + hidden_features, + hidden_features, + kernel_size=(3, 3), + stride=1, + padding=(1, 1), + dilation=1, + groups=hidden_features, + bias=bias, + ), + nn.Conv2d( + hidden_features, + hidden_features, + kernel_size=(1, 1), + stride=1, + padding=(0, 0), + dilation=1, + groups=hidden_features, + bias=bias, + ), + ] + ) + + self.project_out = nn.Linear(hidden_features, dim, bias=bias) + + def forward(self, x, t, h, w): + # import ipdb;ipdb.set_trace() + x = self.project_in(x) + x = rearrange(x, "b (t h w) d -> (b t) d h w", t=t, h=h, w=w) + x = F.gelu(x) + out = x + for module in self.dwconv: + out = out + module(x) + out = rearrange(out, "(b t) d h w -> b (t h w) d", t=t, h=h, w=w) + x = self.project_out(out) + return x + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + attention_mode: str = "xformers", + downsampler: str = None, + use_rope: bool = False, + interpolation_scale_thw: Tuple[int] = (1, 1, 1), + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.downsampler = downsampler + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + attention_mode=attention_mode, + downsampler=downsampler, + use_rope=use_rope, + interpolation_scale_thw=interpolation_scale_thw, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + attention_mode=attention_mode, + downsampler=False, + use_rope=False, + interpolation_scale_thw=interpolation_scale_thw, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + if downsampler: + downsampler_ker_size = list(re.search(r"k(\d{2,3})", downsampler).group(1)) # 122 + # if len(downsampler_ker_size) == 3: + # self.ff = FeedForward_Conv3d( + # downsampler, + # dim, + # 2 * dim, + # bias=ff_bias, + # ) + # elif len(downsampler_ker_size) == 2: + self.ff = FeedForward_Conv2d( + downsampler, + dim, + 2 * dim, + bias=ff_bias, + ) + else: + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # pab + self.spatial_last = None + self.spatial_count = 0 + self.cross_last = None + self.cross_count = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + frame: int = None, + height: int = None, + width: int = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + org_timestep: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + # import ipdb;ipdb.set_trace() + if self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + else: + raise ValueError("Incorrect norm used") + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + broadcast, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count) + if broadcast: + attn_output = self.spatial_last + else: + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + frame=frame, + height=height, + width=width, + **cross_attention_kwargs, + ) + + if enable_pab(): + self.spatial_last = attn_output + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + broadcast, self.cross_count = if_broadcast_cross(int(org_timestep[0]), self.cross_count) + if broadcast: + attn_output = self.cross_last + + else: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + if enable_pab(): + self.cross_last = attn_output + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + # if self._chunk_size is not None: + # # "feed_forward_chunk_size" can be used to save memory + # ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + # else: + + if self.downsampler: + ff_output = self.ff(norm_hidden_states, t=frame, h=height, w=width) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +class OpenSoraT2V(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + sample_size_t: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + patch_size_t: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + interpolation_scale_h: float = None, + interpolation_scale_w: float = None, + interpolation_scale_t: float = None, + use_additional_conditions: Optional[bool] = None, + attention_mode: str = "xformers", + downsampler: str = None, + use_rope: bool = False, + use_stable_fp32: bool = False, + ): + super().__init__() + + # Validate inputs. + if patch_size is not None: + if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]: + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # Set some common variables used across the board. + self.use_rope = use_rope + self.use_linear_projection = use_linear_projection + self.interpolation_scale_t = interpolation_scale_t + self.interpolation_scale_h = interpolation_scale_h + self.interpolation_scale_w = interpolation_scale_w + self.downsampler = downsampler + self.caption_channels = caption_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.gradient_checkpointing = False + self.config.hidden_size = self.inner_dim + use_additional_conditions = False + # if use_additional_conditions is None: + # if norm_type == "ada_norm_single" and sample_size == 128: + # use_additional_conditions = True + # else: + # use_additional_conditions = False + self.use_additional_conditions = use_additional_conditions + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + assert in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + # 2. Initialize the right blocks. + # Initialize the output blocks and other projection blocks when necessary. + self._init_patched_inputs(norm_type=norm_type) + + # parallel + self.parallel_manager: ParallelManager = None + + def _init_patched_inputs(self, norm_type): + assert self.config.sample_size_t is not None, "OpenSoraT2V over patched input must provide sample_size_t" + assert self.config.sample_size is not None, "OpenSoraT2V over patched input must provide sample_size" + # assert not (self.config.sample_size_t == 1 and self.config.patch_size_t == 2), "Image do not need patchfy in t-dim" + + self.num_frames = self.config.sample_size_t + self.config.sample_size = to_2tuple(self.config.sample_size) + self.height = self.config.sample_size[0] + self.width = self.config.sample_size[1] + self.patch_size_t = self.config.patch_size_t + self.patch_size = self.config.patch_size + interpolation_scale_t = ( + ((self.config.sample_size_t - 1) // 16 + 1) + if self.config.sample_size_t % 2 == 1 + else self.config.sample_size_t / 16 + ) + interpolation_scale_t = ( + self.config.interpolation_scale_t + if self.config.interpolation_scale_t is not None + else interpolation_scale_t + ) + interpolation_scale = ( + self.config.interpolation_scale_h + if self.config.interpolation_scale_h is not None + else self.config.sample_size[0] / 30, + self.config.interpolation_scale_w + if self.config.interpolation_scale_w is not None + else self.config.sample_size[1] / 40, + ) + if self.config.downsampler is not None and len(self.config.downsampler) == 9: + self.pos_embed = OverlapPatchEmbed3D( + num_frames=self.config.sample_size_t, + height=self.config.sample_size[0], + width=self.config.sample_size[1], + patch_size_t=self.config.patch_size_t, + patch_size=self.config.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + interpolation_scale_t=interpolation_scale_t, + use_abs_pos=not self.config.use_rope, + ) + elif self.config.downsampler is not None and len(self.config.downsampler) == 7: + self.pos_embed = OverlapPatchEmbed2D( + num_frames=self.config.sample_size_t, + height=self.config.sample_size[0], + width=self.config.sample_size[1], + patch_size_t=self.config.patch_size_t, + patch_size=self.config.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + interpolation_scale_t=interpolation_scale_t, + use_abs_pos=not self.config.use_rope, + ) + + else: + self.pos_embed = PatchEmbed2D( + num_frames=self.config.sample_size_t, + height=self.config.sample_size[0], + width=self.config.sample_size[1], + patch_size_t=self.config.patch_size_t, + patch_size=self.config.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + interpolation_scale_t=interpolation_scale_t, + use_abs_pos=not self.config.use_rope, + ) + interpolation_scale_thw = (interpolation_scale_t, *interpolation_scale) + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + attention_mode=self.config.attention_mode, + downsampler=self.config.downsampler, + use_rope=self.config.use_rope, + interpolation_scale_thw=interpolation_scale_thw, + ) + for _ in range(self.config.num_layers) + ] + ) + + if self.config.norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear( + self.inner_dim, + self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels, + ) + elif self.config.norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear( + self.inner_dim, + self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels, + ) + + # PixArt-Alpha blocks. + self.adaln_single = None + if self.config.norm_type == "ada_norm_single": + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) + + self.caption_projection = None + if self.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, hidden_size=self.inner_dim + ) + + def enable_parallel(self, dp_size, sp_size, enable_cp): + # update cfg parallel + if enable_cp and sp_size % 2 == 0: + sp_size = sp_size // 2 + cp_size = 2 + else: + cp_size = 1 + + self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size) + + for _, module in self.named_modules(): + if hasattr(module, "parallel_manager"): + module.parallel_manager = self.parallel_manager + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_image_num: Optional[int] = 0, + return_dict: bool = True, + **kwargs, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + batch_size, c, frame, h, w = hidden_states.shape + # print('hidden_states.shape', hidden_states.shape) + frame = frame - use_image_num # 21-4=17 + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + attention_mask_vid, attention_mask_img = None, None + if attention_mask is not None and attention_mask.ndim == 4: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + # b, frame+use_image_num, h, w -> a video with images + # b, 1, h, w -> only images + attention_mask = attention_mask.to(self.dtype) + attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w + attention_mask_img = attention_mask[:, frame:] # b, use_image_num, h, w + + if attention_mask_vid.numel() > 0: + attention_mask_vid_first_frame = attention_mask_vid[:, :1].repeat(1, self.patch_size_t - 1, 1, 1) + attention_mask_vid = torch.cat([attention_mask_vid_first_frame, attention_mask_vid], dim=1) + attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w + attention_mask_vid = F.max_pool3d( + attention_mask_vid, + kernel_size=(self.patch_size_t, self.patch_size, self.patch_size), + stride=(self.patch_size_t, self.patch_size, self.patch_size), + ) + attention_mask_vid = rearrange(attention_mask_vid, "b 1 t h w -> (b 1) 1 (t h w)") + if attention_mask_img.numel() > 0: + attention_mask_img = F.max_pool2d( + attention_mask_img, + kernel_size=(self.patch_size, self.patch_size), + stride=(self.patch_size, self.patch_size), + ) + attention_mask_img = rearrange(attention_mask_img, "b i h w -> (b i) 1 (h w)") + + attention_mask_vid = ( + (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None + ) + attention_mask_img = ( + (1 - attention_mask_img.bool().to(self.dtype)) * -10000.0 if attention_mask_img.numel() > 0 else None + ) + + if frame == 1 and use_image_num == 0: + attention_mask_img = attention_mask_vid + attention_mask_vid = None + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: + # b, 1+use_image_num, l -> a video with images + # b, 1, l -> only images + encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 + in_t = encoder_attention_mask.shape[1] + encoder_attention_mask_vid = encoder_attention_mask[:, : in_t - use_image_num] # b, 1, l + encoder_attention_mask_vid = ( + rearrange(encoder_attention_mask_vid, "b 1 l -> (b 1) 1 l") + if encoder_attention_mask_vid.numel() > 0 + else None + ) + + encoder_attention_mask_img = encoder_attention_mask[:, in_t - use_image_num :] # b, use_image_num, l + encoder_attention_mask_img = ( + rearrange(encoder_attention_mask_img, "b i l -> (b i) 1 l") + if encoder_attention_mask_img.numel() > 0 + else None + ) + + if frame == 1 and use_image_num == 0: + encoder_attention_mask_img = encoder_attention_mask_vid + encoder_attention_mask_vid = None + + if npu_config is not None and attention_mask_vid is not None: + attention_mask_vid = npu_config.get_attention_mask(attention_mask_vid, attention_mask_vid.shape[-1]) + encoder_attention_mask_vid = npu_config.get_attention_mask( + encoder_attention_mask_vid, attention_mask_vid.shape[-2] + ) + if npu_config is not None and attention_mask_img is not None: + attention_mask_img = npu_config.get_attention_mask(attention_mask_img, attention_mask_img.shape[-1]) + encoder_attention_mask_img = npu_config.get_attention_mask( + encoder_attention_mask_img, attention_mask_img.shape[-2] + ) + + # 1. Input + frame = ((frame - 1) // self.patch_size_t + 1) if frame % 2 == 1 else frame // self.patch_size_t # patchfy + # print('frame', frame) + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + ( + hidden_states_vid, + hidden_states_img, + encoder_hidden_states_vid, + encoder_hidden_states_img, + timestep_vid, + timestep_img, + embedded_timestep_vid, + embedded_timestep_img, + ) = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, frame, use_image_num + ) + # 2. Blocks + if self.parallel_manager.sp_size > 1: + if hidden_states_vid is not None: + hidden_states_vid = split_sequence( + hidden_states_vid, dim=1, process_group=self.parallel_manager.sp_group, grad_scale="down" + ) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if hidden_states_vid is not None: + hidden_states_vid = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states_vid, + attention_mask_vid, + encoder_hidden_states_vid, + encoder_attention_mask_vid, + timestep_vid, + cross_attention_kwargs, + class_labels, + frame, + height, + width, + **ckpt_kwargs, + ) + if hidden_states_img is not None: + hidden_states_img = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states_img, + attention_mask_img, + encoder_hidden_states_img, + encoder_attention_mask_img, + timestep_img, + cross_attention_kwargs, + class_labels, + 1, + height, + width, + **ckpt_kwargs, + ) + else: + if hidden_states_vid is not None: + hidden_states_vid = block( + hidden_states_vid, + attention_mask=attention_mask_vid, + encoder_hidden_states=encoder_hidden_states_vid, + encoder_attention_mask=encoder_attention_mask_vid, + timestep=timestep_vid, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + frame=frame, + height=height, + width=width, + org_timestep=timestep, + ) + if hidden_states_img is not None: + hidden_states_img = block( + hidden_states_img, + attention_mask=attention_mask_img, + encoder_hidden_states=encoder_hidden_states_img, + encoder_attention_mask=encoder_attention_mask_img, + timestep=timestep_img, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + frame=1, + height=height, + width=width, + org_timestep=timestep, + ) + + if self.parallel_manager.sp_size > 1: + if hidden_states_vid is not None: + hidden_states_vid = gather_sequence( + hidden_states_vid, dim=1, process_group=self.parallel_manager.sp_group, grad_scale="up" + ) + + # 3. Output + output_vid, output_img = None, None + if hidden_states_vid is not None: + output_vid = self._get_output_for_patched_inputs( + hidden_states=hidden_states_vid, + timestep=timestep_vid, + class_labels=class_labels, + embedded_timestep=embedded_timestep_vid, + num_frames=frame, + height=height, + width=width, + ) # b c t h w + if hidden_states_img is not None: + output_img = self._get_output_for_patched_inputs( + hidden_states=hidden_states_img, + timestep=timestep_img, + class_labels=class_labels, + embedded_timestep=embedded_timestep_img, + num_frames=1, + height=height, + width=width, + ) # b c 1 h w + if use_image_num != 0: + output_img = rearrange(output_img, "(b i) c 1 h w -> b c i h w", i=use_image_num) + + if output_vid is not None and output_img is not None: + output = torch.cat([output_vid, output_img], dim=2) + elif output_vid is not None: + output = output_vid + elif output_img is not None: + output = output_img + + if not return_dict: + return (output,) + + return VideoSysPipelineOutput(video=output) + + def _operate_on_patched_inputs( + self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, frame, use_image_num + ): + # batch_size = hidden_states.shape[0] + hidden_states_vid, hidden_states_img = self.pos_embed(hidden_states.to(self.dtype), frame) + timestep_vid, timestep_img = None, None + embedded_timestep_vid, embedded_timestep_img = None, None + encoder_hidden_states_vid, encoder_hidden_states_img = None, None + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype + ) # b 6d, b d + if hidden_states_vid is None: + timestep_img = timestep + embedded_timestep_img = embedded_timestep + else: + timestep_vid = timestep + embedded_timestep_vid = embedded_timestep + if hidden_states_img is not None: + timestep_img = repeat(timestep, "b d -> (b i) d", i=use_image_num).contiguous() + embedded_timestep_img = repeat(embedded_timestep, "b d -> (b i) d", i=use_image_num).contiguous() + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection( + encoder_hidden_states + ) # b, 1+use_image_num, l, d or b, 1, l, d + if hidden_states_vid is None: + encoder_hidden_states_img = rearrange(encoder_hidden_states, "b 1 l d -> (b 1) l d") + else: + encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], "b 1 l d -> (b 1) l d") + if hidden_states_img is not None: + encoder_hidden_states_img = rearrange(encoder_hidden_states[:, 1:], "b i l d -> (b i) l d") + + return ( + hidden_states_vid, + hidden_states_img, + encoder_hidden_states_vid, + encoder_hidden_states_img, + timestep_vid, + timestep_img, + embedded_timestep_vid, + embedded_timestep_img, + ) + + def _get_output_for_patched_inputs( + self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None + ): + # import ipdb;ipdb.set_trace() + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=self.dtype) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=( + -1, + num_frames, + height, + width, + self.patch_size_t, + self.patch_size, + self.patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + num_frames * self.patch_size_t, + height * self.patch_size, + width * self.patch_size, + ) + ) + # import ipdb;ipdb.set_trace() + # if output.shape[2] % 2 == 0: + # output = output[:, :, 1:] + return output + + +def OpenSoraT2V_S_122(**kwargs): + return OpenSoraT2V( + num_layers=28, + attention_head_dim=96, + num_attention_heads=16, + patch_size_t=1, + patch_size=2, + norm_type="ada_norm_single", + caption_channels=4096, + cross_attention_dim=1536, + **kwargs, + ) + + +def OpenSoraT2V_B_122(**kwargs): + return OpenSoraT2V( + num_layers=32, + attention_head_dim=96, + num_attention_heads=16, + patch_size_t=1, + patch_size=2, + norm_type="ada_norm_single", + caption_channels=4096, + cross_attention_dim=1920, + **kwargs, + ) + + +def OpenSoraT2V_L_122(**kwargs): + return OpenSoraT2V( + num_layers=40, + attention_head_dim=128, + num_attention_heads=16, + patch_size_t=1, + patch_size=2, + norm_type="ada_norm_single", + caption_channels=4096, + cross_attention_dim=2048, + **kwargs, + ) + + +def OpenSoraT2V_ROPE_L_122(**kwargs): + return OpenSoraT2V( + num_layers=32, + attention_head_dim=96, + num_attention_heads=24, + patch_size_t=1, + patch_size=2, + norm_type="ada_norm_single", + caption_channels=4096, + cross_attention_dim=2304, + **kwargs, + ) + + +OpenSora_models = { + "OpenSoraT2V-S/122": OpenSoraT2V_S_122, # 1.1B + "OpenSoraT2V-B/122": OpenSoraT2V_B_122, + "OpenSoraT2V-L/122": OpenSoraT2V_L_122, + "OpenSoraT2V-ROPE-L/122": OpenSoraT2V_ROPE_L_122, +} + +OpenSora_models_class = { + "OpenSoraT2V-S/122": OpenSoraT2V, + "OpenSoraT2V-B/122": OpenSoraT2V, + "OpenSoraT2V-L/122": OpenSoraT2V, + "OpenSoraT2V-ROPE-L/122": OpenSoraT2V, +} diff --git a/videosys/models/transformers/open_sora_transformer_3d.py b/videosys/models/transformers/open_sora_transformer_3d.py index 4a9c213..27aec68 100644 --- a/videosys/models/transformers/open_sora_transformer_3d.py +++ b/videosys/models/transformers/open_sora_transformer_3d.py @@ -20,15 +20,7 @@ from timm.models.vision_transformer import Mlp from torch.utils.checkpoint import checkpoint, checkpoint_sequential from transformers import PretrainedConfig, PreTrainedModel -from videosys.core.comm import ( - all_to_all_with_pad, - gather_sequence, - get_spatial_pad, - get_temporal_pad, - set_spatial_pad, - set_temporal_pad, - split_sequence, -) +from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence from videosys.core.pab_mgr import ( enable_pab, get_mlp_output, @@ -38,12 +30,7 @@ from videosys.core.pab_mgr import ( if_broadcast_temporal, save_mlp_output, ) -from videosys.core.parallel_mgr import ( - enable_sequence_parallel, - get_cfg_parallel_size, - get_data_parallel_group, - get_sequence_parallel_group, -) +from videosys.core.parallel_mgr import ParallelManager from videosys.models.modules.activations import approx_gelu from videosys.models.modules.attentions import OpenSoraAttention, OpenSoraMultiHeadCrossAttention from videosys.models.modules.embeddings import ( @@ -143,6 +130,9 @@ class STDiT3Block(nn.Module): self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + # parallel + self.parallel_manager: ParallelManager = None + # pab self.block_idx = block_idx self.attn_count = 0 @@ -188,9 +178,7 @@ class STDiT3Block(nn.Module): if self.temporal: broadcast_attn, self.attn_count = if_broadcast_temporal(int(timestep[0]), self.attn_count) else: - broadcast_attn, self.attn_count = if_broadcast_spatial( - int(timestep[0]), self.attn_count, self.block_idx - ) + broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count) if enable_pab() and broadcast_attn: x_m_s = self.last_attn @@ -203,12 +191,12 @@ class STDiT3Block(nn.Module): # attention if self.temporal: - if enable_sequence_parallel(): + if self.parallel_manager.sp_size > 1: x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=True) x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S) x_m = self.attn(x_m) x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S) - if enable_sequence_parallel(): + if self.parallel_manager.sp_size > 1: x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=False) else: x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S) @@ -287,17 +275,17 @@ class STDiT3Block(nn.Module): def dynamic_switch(self, x, s, t, to_spatial_shard: bool): if to_spatial_shard: scatter_dim, gather_dim = 2, 1 - scatter_pad = get_spatial_pad() - gather_pad = get_temporal_pad() + scatter_pad = get_pad("spatial") + gather_pad = get_pad("temporal") else: scatter_dim, gather_dim = 1, 2 - scatter_pad = get_temporal_pad() - gather_pad = get_spatial_pad() + scatter_pad = get_pad("temporal") + gather_pad = get_pad("spatial") x = rearrange(x, "b (t s) d -> b t s d", t=t, s=s) x = all_to_all_with_pad( x, - get_sequence_parallel_group(), + self.parallel_manager.sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim, scatter_pad=scatter_pad, @@ -449,6 +437,24 @@ class STDiT3(PreTrainedModel): for param in self.y_embedder.parameters(): param.requires_grad = False + # parallel + self.parallel_manager: ParallelManager = None + + def enable_parallel(self, dp_size, sp_size, enable_cp): + # update cfg parallel + if enable_cp and sp_size % 2 == 0: + sp_size = sp_size // 2 + cp_size = 2 + else: + cp_size = 1 + + self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size) + + for name, module in self.named_modules(): + if "spatial_blocks" in name or "temporal_blocks" in name: + if hasattr(module, "parallel_manager"): + module.parallel_manager = self.parallel_manager + def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): @@ -501,9 +507,14 @@ class STDiT3(PreTrainedModel): self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs ): # === Split batch === - if get_cfg_parallel_size() > 1: + if self.parallel_manager.cp_size > 1: x, timestep, y, x_mask, mask = batch_func( - partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask + partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), + x, + timestep, + y, + x_mask, + mask, ) dtype = self.x_embedder.proj.weight.dtype @@ -547,14 +558,14 @@ class STDiT3(PreTrainedModel): x = x + pos_emb # shard over the sequence dim if sp is enabled - if enable_sequence_parallel(): - set_temporal_pad(T) - set_spatial_pad(S) - x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()) + if self.parallel_manager.sp_size > 1: + set_pad("temporal", T, self.parallel_manager.sp_group) + set_pad("spatial", S, self.parallel_manager.sp_group) + x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")) T = x.shape[1] x_mask_org = x_mask x_mask = split_sequence( - x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad() + x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") ) x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) @@ -589,9 +600,9 @@ class STDiT3(PreTrainedModel): all_timesteps=all_timesteps, ) - if enable_sequence_parallel(): + if self.parallel_manager.sp_size > 1: x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) - x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad()) + x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")) T, S = x.shape[1], x.shape[2] x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) x_mask = x_mask_org @@ -604,8 +615,8 @@ class STDiT3(PreTrainedModel): x = x.to(torch.float32) # === Gather Output === - if get_cfg_parallel_size() > 1: - x = gather_sequence(x, get_data_parallel_group(), dim=0) + if self.parallel_manager.cp_size > 1: + x = gather_sequence(x, self.parallel_manager.cp_group, dim=0) return x diff --git a/videosys/models/transformers/vchitect_transformer_3d.py b/videosys/models/transformers/vchitect_transformer_3d.py new file mode 100644 index 0000000..166d654 --- /dev/null +++ b/videosys/models/transformers/vchitect_transformer_3d.py @@ -0,0 +1,644 @@ +# Adapted from Vchitect + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Vchitect: https://github.com/Vchitect/Vchitect-2.0 +# diffusers: https://github.com/huggingface/diffusers +# -------------------------------------------------------- + + +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero +from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, deprecate, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange +from torch import nn + +from videosys.core.comm import gather_from_second_dim, set_pad, split_from_second_dim +from videosys.core.parallel_mgr import ParallelManager +from videosys.models.modules.attentions import VchitectAttention, VchitectAttnProcessor + + +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + + +@maybe_allow_in_graph +class JointTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False): + super().__init__() + + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" + + self.norm1 = AdaLayerNormZero(dim) + + if context_norm_type == "ada_norm_continous": + self.norm1_context = AdaLayerNormContinuous( + dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" + ) + elif context_norm_type == "ada_norm_zero": + self.norm1_context = AdaLayerNormZero(dim) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" + ) + processor = VchitectAttnProcessor() + self.attn = VchitectAttention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim // num_attention_heads, + heads=num_attention_heads, + out_dim=attention_head_dim, + context_pre_only=context_pre_only, + bias=True, + processor=processor, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + if not context_pre_only: + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + else: + self.norm2_context = None + self.ff_context = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + freqs_cis: torch.Tensor, + full_seqlen: int, + Frame: int, + timestep: int, + ): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + freqs_cis=freqs_cis, + full_seqlen=full_seqlen, + Frame=Frame, + timestep=timestep, + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + if self.context_pre_only: + encoder_hidden_states = None + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + context_ff_output = _chunked_feed_forward( + self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return encoder_hidden_states, hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class VchitectXLTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + out_channels (`int`, defaults to 16): Number of output channels. + + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 18, + attention_head_dim: int = 64, + num_attention_heads: int = 18, + joint_attention_dim: int = 4096, + caption_projection_dim: int = 1152, + pooled_projection_dim: int = 2048, + out_channels: int = 16, + pos_embed_max_size: int = 96, + rope_scaling_factor: float = 1.0, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=pos_embed_max_size, # hard-code for now. + ) + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim + ) + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) + # `attention_head_dim` is doubled to account for the mixing. + # It needs to crafted when we get the actual checkpoints. + self.transformer_blocks = nn.ModuleList( + [ + JointTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.inner_dim, + context_pre_only=i == num_layers - 1, + ) + for i in range(self.config.num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + # Video param + # self.scatter_dim_zero = Identity() + self.freqs_cis = VchitectXLTransformerModel.precompute_freqs_cis( + self.inner_dim // self.config.num_attention_heads, + 1000000, + theta=1e6, + rope_scaling_factor=rope_scaling_factor, # todo max pos embeds + ) + + # self.vid_token = nn.Parameter(torch.empty(self.inner_dim)) + + # parallel + self.parallel_manager = None + + def enable_parallel(self, dp_size, sp_size, enable_cp): + # update cfg parallel + if enable_cp and sp_size % 2 == 0: + sp_size = sp_size // 2 + cp_size = 2 + else: + cp_size = 1 + + self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size) + + for _, module in self.named_modules(): + if hasattr(module, "parallel_manager"): + module.parallel_manager = self.parallel_manager + + @staticmethod + def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, rope_scaling_factor: float = 1.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float) + t = t / rope_scaling_factor + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, VchitectAttnProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, module: torch.nn.Module, processors: Dict[str, VchitectAttnProcessor] + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[VchitectAttnProcessor, Dict[str, VchitectAttnProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, VchitectAttention): + module.fuse_projections(fuse=True) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def patchify_and_embed(self, x): + B, F, C, H, W = x.size() + x = rearrange(x, "b f c h w -> (b f) c h w") + x = self.pos_embed(x) # [B L D] + return x, F, [(H, W)] * B + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`VchitectXLTransformerModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + height, width = hidden_states.shape[-2:] + + batch_size = hidden_states.shape[0] + hidden_states, F_num, _ = self.patchify_and_embed( + hidden_states + ) # takes care of adding positional embeddings too. + full_seq = batch_size * F_num + + self.freqs_cis = self.freqs_cis.to(hidden_states.device) + freqs_cis = self.freqs_cis + # seq_length = hidden_states.size(1) + # freqs_cis = self.freqs_cis[:hidden_states.size(1)*F_num] + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if self.parallel_manager.sp_size > 1: + set_pad("temporal", F_num, self.parallel_manager.sp_group) + hidden_states = split_from_second_dim(hidden_states, batch_size, self.parallel_manager.sp_group) + cur_temb = temb.repeat(hidden_states.shape[0] // batch_size, 1) + else: + cur_temb = temb.repeat(F_num, 1) + + for block_idx, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=cur_temb, + freqs_cis=freqs_cis, + full_seqlen=full_seq, + Frame=F_num, + timestep=timestep, + ) + + if self.parallel_manager.sp_size > 1: + hidden_states = gather_from_second_dim(hidden_states, batch_size, self.parallel_manager.sp_group) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + # hidden_states = hidden_states[:, :-1] #Drop the video token + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def get_fsdp_wrap_module_list(self) -> List[nn.Module]: + return list(self.transformer_blocks) + + @classmethod + def from_pretrained_temporal(cls, pretrained_model_path, torch_dtype, logger, subfolder=None, tp_size=1): + import json + import os + + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, "config.json") + + with open(config_file, "r") as f: + config = json.load(f) + + config["tp_size"] = tp_size + from safetensors.torch import load_file + + model = cls.from_config(config) + # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + + model_files = [ + os.path.join(pretrained_model_path, "diffusion_pytorch_model.bin"), + os.path.join(pretrained_model_path, "diffusion_pytorch_model.safetensors"), + ] + + model_file = None + + for fp in model_files: + if os.path.exists(fp): + model_file = fp + + if not model_file: + raise RuntimeError(f"{model_file} does not exist") + + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + + state_dict = load_file(model_file, device="cpu") + m, u = model.load_state_dict(state_dict, strict=False) + model = model.to(torch_dtype) + + params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] + total_params = [p.numel() for n, p in model.named_parameters()] + + if logger is not None: + logger.info(f"model_file: {model_file}") + logger.info(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + logger.info(f"### Temporal Module Parameters: {sum(params) / 1e6} M") + logger.info(f"### Total Parameters: {sum(total_params) / 1e6} M") + + return model diff --git a/videosys/pipelines/cogvideox/pipeline_cogvideox.py b/videosys/pipelines/cogvideox/pipeline_cogvideox.py index 7eaee3a..e8bd151 100644 --- a/videosys/pipelines/cogvideox/pipeline_cogvideox.py +++ b/videosys/pipelines/cogvideox/pipeline_cogvideox.py @@ -13,6 +13,7 @@ import math from typing import Callable, Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor @@ -26,19 +27,17 @@ from videosys.models.transformers.cogvideox_transformer_3d import CogVideoXTrans from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler from videosys.utils.logging import logger -from videosys.utils.utils import save_video +from videosys.utils.utils import save_video, set_seed class CogVideoXPABConfig(PABConfig): def __init__( self, - steps: int = 50, spatial_broadcast: bool = True, spatial_threshold: list = [100, 850], spatial_range: int = 2, ): super().__init__( - steps=steps, spatial_broadcast=spatial_broadcast, spatial_threshold=spatial_threshold, spatial_range=spatial_range, @@ -114,7 +113,7 @@ class CogVideoXConfig: class CogVideoXPipeline(VideoSysPipeline): - _optional_components = ["text_encoder", "tokenizer"] + _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"] model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = [ "latents", @@ -158,9 +157,6 @@ class CogVideoXPipeline(VideoSysPipeline): subfolder="scheduler", ) - # set eval and device - self.set_eval_and_device(self._device, vae, transformer) - self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) @@ -169,7 +165,7 @@ class CogVideoXPipeline(VideoSysPipeline): if config.cpu_offload: self.enable_model_cpu_offload() else: - self.set_eval_and_device(self._device, text_encoder) + self.set_eval_and_device(self._device, text_encoder, vae, transformer) # vae tiling if config.vae_tiling: @@ -187,6 +183,31 @@ class CogVideoXPipeline(VideoSysPipeline): ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + # parallel + self._set_parallel() + + def _set_seed(self, seed): + if dist.get_world_size() == 1: + set_seed(seed) + else: + set_seed(seed, self.transformer.parallel_manager.dp_rank) + + def _set_parallel( + self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False + ): + # init sequence parallel + if sp_size is None: + sp_size = dist.get_world_size() + dp_size = 1 + else: + assert ( + dist.get_world_size() % sp_size == 0 + ), f"world_size {dist.get_world_size()} must be divisible by sp_size" + dp_size = dist.get_world_size() // sp_size + + # transformer parallel + self.transformer.enable_parallel(dp_size, sp_size, enable_cp) + def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -474,6 +495,7 @@ class CogVideoXPipeline(VideoSysPipeline): num_frames: int = 49, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, + seed: int = -1, guidance_scale: float = 6, use_dynamic_cfg: bool = False, num_videos_per_prompt: int = 1, @@ -489,7 +511,6 @@ class CogVideoXPipeline(VideoSysPipeline): ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 226, - verbose=True, ) -> Union[VideoSysPipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -572,6 +593,7 @@ class CogVideoXPipeline(VideoSysPipeline): "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." ) update_steps(num_inference_steps) + self._set_seed(seed) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -653,11 +675,10 @@ class CogVideoXPipeline(VideoSysPipeline): # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x) - # with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar: # for DPM-solver++ - old_pred_original_sample = None - for i, t in progress_wrap(list(enumerate(timesteps))): + old_pred_original_sample = None + for i, t in enumerate(timesteps): if self.interrupt: continue @@ -674,7 +695,6 @@ class CogVideoXPipeline(VideoSysPipeline): timestep=timestep, image_rotary_emb=image_rotary_emb, return_dict=False, - all_timesteps=timesteps, )[0] noise_pred = noise_pred.float() @@ -713,8 +733,8 @@ class CogVideoXPipeline(VideoSysPipeline): prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - # progress_bar.update() + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() if not output_type == "latent": video = self.decode_latents(latents) diff --git a/videosys/pipelines/latte/pipeline_latte.py b/videosys/pipelines/latte/pipeline_latte.py index 7c1d590..437baaa 100644 --- a/videosys/pipelines/latte/pipeline_latte.py +++ b/videosys/pipelines/latte/pipeline_latte.py @@ -29,13 +29,12 @@ from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput from videosys.models.transformers.latte_transformer_3d import LatteT2V from videosys.utils.logging import logger -from videosys.utils.utils import save_video +from videosys.utils.utils import save_video, set_seed class LattePABConfig(PABConfig): def __init__( self, - steps: int = 50, spatial_broadcast: bool = True, spatial_threshold: list = [100, 800], spatial_range: int = 2, @@ -62,7 +61,6 @@ class LattePABConfig(PABConfig): }, ): super().__init__( - steps=steps, spatial_broadcast=spatial_broadcast, spatial_threshold=spatial_threshold, spatial_range=spatial_range, @@ -188,7 +186,7 @@ class LattePipeline(VideoSysPipeline): r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" ) # noqa - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"] model_cpu_offload_seq = "text_encoder->transformer->vae" def __init__( @@ -238,9 +236,6 @@ class LattePipeline(VideoSysPipeline): if config.enable_pab: set_pab_manager(config.pab_config) - # set eval and device - self.set_eval_and_device(device, vae, transformer) - self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) @@ -249,11 +244,36 @@ class LattePipeline(VideoSysPipeline): if config.cpu_offload: self.enable_model_cpu_offload() else: - self.set_eval_and_device(device, text_encoder) + self.set_eval_and_device(device, text_encoder, vae, transformer) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # parallel + self._set_parallel() + + def _set_seed(self, seed): + if dist.get_world_size() == 1: + set_seed(seed) + else: + set_seed(seed, self.transformer.parallel_manager.dp_rank) + + def _set_parallel( + self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False + ): + # init sequence parallel + if sp_size is None: + sp_size = dist.get_world_size() + dp_size = 1 + else: + assert ( + dist.get_world_size() % sp_size == 0 + ), f"world_size {dist.get_world_size()} must be divisible by sp_size" + dp_size = dist.get_world_size() // sp_size + + # transformer parallel + self.transformer.enable_parallel(dp_size, sp_size, enable_cp) + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py def mask_text_embeddings(self, emb, mask): if emb.shape[0] == 1: @@ -658,6 +678,7 @@ class LattePipeline(VideoSysPipeline): negative_prompt: str = "", num_inference_steps: int = 50, guidance_scale: float = 7.5, + seed: int = -1, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -745,6 +766,7 @@ class LattePipeline(VideoSysPipeline): width = 512 update_steps(num_inference_steps) self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds) + self._set_seed(seed) # 2. Default height and width to transformer if prompt is not None and isinstance(prompt, str): diff --git a/videosys/pipelines/open_sora/pipeline_open_sora.py b/videosys/pipelines/open_sora/pipeline_open_sora.py index 5fd3473..6199f26 100644 --- a/videosys/pipelines/open_sora/pipeline_open_sora.py +++ b/videosys/pipelines/open_sora/pipeline_open_sora.py @@ -2,19 +2,22 @@ import html import json import os import re +import urllib.parse as ul from typing import Optional, Tuple, Union import ftfy import torch +import torch.distributed as dist +from bs4 import BeautifulSoup from diffusers.models import AutoencoderKL from transformers import AutoTokenizer, T5EncoderModel -from videosys.core.pab_mgr import PABConfig, set_pab_manager +from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput from videosys.models.autoencoders.autoencoder_kl_open_sora import OpenSoraVAE_V1_2 from videosys.models.transformers.open_sora_transformer_3d import STDiT3 from videosys.schedulers.scheduling_rflow_open_sora import RFLOW -from videosys.utils.utils import save_video +from videosys.utils.utils import save_video, set_seed from .data_process import get_image_size, get_num_frames, prepare_multi_resolution_info, read_from_path @@ -29,7 +32,6 @@ BAD_PUNCT_REGEX = re.compile( class OpenSoraPABConfig(PABConfig): def __init__( self, - steps: int = 50, spatial_broadcast: bool = True, spatial_threshold: list = [450, 930], spatial_range: int = 2, @@ -52,7 +54,6 @@ class OpenSoraPABConfig(PABConfig): }, ): super().__init__( - steps=steps, spatial_broadcast=spatial_broadcast, spatial_threshold=spatial_threshold, spatial_range=spatial_range, @@ -187,11 +188,7 @@ class OpenSoraPipeline(VideoSysPipeline): bad_punct_regex = re.compile( r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" ) # noqa - - _optional_components = [ - "text_encoder", - "tokenizer", - ] + _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"] model_cpu_offload_seq = "text_encoder->transformer->vae" def __init__( @@ -234,9 +231,6 @@ class OpenSoraPipeline(VideoSysPipeline): if config.enable_pab: set_pab_manager(config.pab_config) - # set eval and device - self.set_eval_and_device(device, vae, transformer) - self.register_modules( text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, tokenizer=tokenizer ) @@ -245,7 +239,32 @@ class OpenSoraPipeline(VideoSysPipeline): if config.cpu_offload: self.enable_model_cpu_offload() else: - self.set_eval_and_device(self._device, text_encoder) + self.set_eval_and_device(self._device, vae, transformer, text_encoder) + + # parallel + self._set_parallel() + + def _set_seed(self, seed): + if dist.get_world_size() == 1: + set_seed(seed) + else: + set_seed(seed, self.transformer.parallel_manager.dp_rank) + + def _set_parallel( + self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False + ): + # init sequence parallel + if sp_size is None: + sp_size = dist.get_world_size() + dp_size = 1 + else: + assert ( + dist.get_world_size() % sp_size == 0 + ), f"world_size {dist.get_world_size()} must be divisible by sp_size" + dp_size = dist.get_world_size() // sp_size + + # transformer parallel + self.transformer.enable_parallel(dp_size, sp_size, enable_cp) def get_text_embeddings(self, texts): text_tokens_and_mask = self.tokenizer( @@ -283,10 +302,6 @@ class OpenSoraPipeline(VideoSysPipeline): return text.strip() def _clean_caption(self, caption): - import urllib.parse as ul - - from bs4 import BeautifulSoup - caption = str(caption) caption = ul.unquote_plus(caption) caption = caption.strip().lower() @@ -418,6 +433,7 @@ class OpenSoraPipeline(VideoSysPipeline): loop: int = 1, llm_refine: bool = False, negative_prompt: str = "", + seed: int = -1, ms: Optional[str] = "", refs: Optional[str] = "", aes: float = 6.5, @@ -460,10 +476,6 @@ class OpenSoraPipeline(VideoSysPipeline): usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - height (`int`, *optional*, defaults to self.unet.config.sample_size): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size): - The width in pixels of the generated image. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -508,6 +520,8 @@ class OpenSoraPipeline(VideoSysPipeline): fps = 24 image_size = get_image_size(resolution, aspect_ratio) num_frames = get_num_frames(num_frames) + self._set_seed(seed) + update_steps(self._config.num_sampling_steps) # == prepare batch prompts == batch_prompts = [prompt] @@ -621,7 +635,7 @@ class OpenSoraPipeline(VideoSysPipeline): progress=verbose, mask=masks, ) - samples = self.vae.decode(samples.to(self._dtype), num_frames=num_frames) + samples = self.vae(samples.to(self._dtype), decode_only=True, num_frames=num_frames) video_clips.append(samples) for i in range(1, loop): diff --git a/videosys/pipelines/open_sora_plan/__init__.py b/videosys/pipelines/open_sora_plan/__init__.py index 7a1ddb8..ddf790d 100644 --- a/videosys/pipelines/open_sora_plan/__init__.py +++ b/videosys/pipelines/open_sora_plan/__init__.py @@ -1,3 +1,8 @@ -from .pipeline_open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline +from .pipeline_open_sora_plan import ( + OpenSoraPlanConfig, + OpenSoraPlanPipeline, + OpenSoraPlanV110PABConfig, + OpenSoraPlanV120PABConfig, +) -__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanPABConfig"] +__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig"] diff --git a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py index b973361..ce34321 100644 --- a/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py +++ b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py @@ -20,39 +20,27 @@ import torch.distributed as dist import tqdm from bs4 import BeautifulSoup from diffusers.models import AutoencoderKL, Transformer2DModel -from diffusers.schedulers import PNDMScheduler +from diffusers.schedulers import EulerAncestralDiscreteScheduler, PNDMScheduler from diffusers.utils.torch_utils import randn_tensor -from transformers import T5EncoderModel, T5Tokenizer +from transformers import AutoTokenizer, MT5EncoderModel, T5EncoderModel, T5Tokenizer from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput +from videosys.models.autoencoders.autoencoder_kl_open_sora_plan_v110 import ( + CausalVAEModelWrapper as CausalVAEModelWrapperV110, +) +from videosys.models.autoencoders.autoencoder_kl_open_sora_plan_v120 import ( + CausalVAEModelWrapper as CausalVAEModelWrapperV120, +) +from videosys.models.transformers.open_sora_plan_v110_transformer_3d import LatteT2V +from videosys.models.transformers.open_sora_plan_v120_transformer_3d import OpenSoraT2V from videosys.utils.logging import logger -from videosys.utils.utils import save_video - -from ...models.autoencoders.autoencoder_kl_open_sora_plan import ae_stride_config, getae_wrapper -from ...models.transformers.open_sora_plan_transformer_3d import LatteT2V - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import PixArtAlphaPipeline - - >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. - >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) - >>> # Enable memory optimizations. - >>> pipe.enable_model_cpu_offload() - - >>> prompt = "A small cactus with a happy face in the Sahara desert." - >>> image = pipe(prompt).images[0] - ``` -""" +from videosys.utils.utils import save_video, set_seed -class OpenSoraPlanPABConfig(PABConfig): +class OpenSoraPlanV110PABConfig(PABConfig): def __init__( self, - steps: int = 150, spatial_broadcast: bool = True, spatial_threshold: list = [100, 850], spatial_range: int = 2, @@ -97,7 +85,6 @@ class OpenSoraPlanPABConfig(PABConfig): }, ): super().__init__( - steps=steps, spatial_broadcast=spatial_broadcast, spatial_threshold=spatial_threshold, spatial_range=spatial_range, @@ -113,6 +100,26 @@ class OpenSoraPlanPABConfig(PABConfig): ) +class OpenSoraPlanV120PABConfig(PABConfig): + def __init__( + self, + spatial_broadcast: bool = True, + spatial_threshold: list = [100, 850], + spatial_range: int = 2, + cross_broadcast: bool = True, + cross_threshold: list = [100, 850], + cross_range: int = 6, + ): + super().__init__( + spatial_broadcast=spatial_broadcast, + spatial_threshold=spatial_threshold, + spatial_range=spatial_range, + cross_broadcast=cross_broadcast, + cross_threshold=cross_threshold, + cross_range=cross_range, + ) + + class OpenSoraPlanConfig: """ This config is to instantiate a `OpenSoraPlanPipeline` class for video generation. @@ -163,10 +170,10 @@ class OpenSoraPlanConfig: def __init__( self, - transformer: str = "LanguageBind/Open-Sora-Plan-v1.1.0", - ae: str = "CausalVAEModel_4x8x8", - text_encoder: str = "DeepFloyd/t5-v1_1-xxl", - num_frames: int = 65, + version: str = "v120", + transformer_type: str = "29x480p", + transformer: str = None, + text_encoder: str = None, # ======= distributed ======== num_gpus: int = 1, # ======= memory ======= @@ -175,15 +182,32 @@ class OpenSoraPlanConfig: tile_overlap_factor: float = 0.25, # ======= pab ======== enable_pab: bool = False, - pab_config: PABConfig = OpenSoraPlanPABConfig(), + pab_config: PABConfig = None, ): self.pipeline_cls = OpenSoraPlanPipeline - self.ae = ae - self.text_encoder = text_encoder - self.transformer = transformer - assert num_frames in [65, 221], "num_frames must be one of [65, 221]" - self.num_frames = num_frames - self.version = f"{num_frames}x512x512" + + # get version + assert version in ["v110", "v120"], f"Unknown Open-Sora-Plan version: {version}" + self.version = version + self.transformer_type = transformer_type + + # check transformer_type + if version == "v110": + assert transformer_type in ["65x512x512", "221x512x512"] + elif version == "v120": + assert transformer_type in ["93x480p", "93x720p", "29x480p", "29x720p"] + self.num_frames = int(transformer_type.split("x")[0]) + + # set default values according to version + if version == "v110": + transformer_default = "LanguageBind/Open-Sora-Plan-v1.1.0" + text_encoder_default = "DeepFloyd/t5-v1_1-xxl" + elif version == "v120": + transformer_default = "LanguageBind/Open-Sora-Plan-v1.2.0" + text_encoder_default = "google/mt5-xxl" + self.text_encoder = text_encoder or text_encoder_default + self.transformer = transformer or transformer_default + # ======= distributed ======== self.num_gpus = num_gpus # ======= memory ======== @@ -192,7 +216,13 @@ class OpenSoraPlanConfig: self.tile_overlap_factor = tile_overlap_factor # ======= pab ======== self.enable_pab = enable_pab - self.pab_config = pab_config + if self.enable_pab and pab_config is None: + if version == "v110": + self.pab_config = OpenSoraPlanV110PABConfig() + elif version == "v120": + self.pab_config = OpenSoraPlanV120PABConfig() + else: + self.pab_config = pab_config class OpenSoraPlanPipeline(VideoSysPipeline): @@ -221,7 +251,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline): r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" ) # noqa - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"] model_cpu_offload_seq = "text_encoder->transformer->vae" def __init__( @@ -240,26 +270,57 @@ class OpenSoraPlanPipeline(VideoSysPipeline): # init if tokenizer is None: - tokenizer = T5Tokenizer.from_pretrained(config.text_encoder) + if config.version == "v110": + tokenizer = T5Tokenizer.from_pretrained(config.text_encoder) + elif config.version == "v120": + tokenizer = AutoTokenizer.from_pretrained(config.text_encoder) + if text_encoder is None: - text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=torch.float16) + if config.version == "v110": + text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=dtype) + elif config.version == "v120": + text_encoder = MT5EncoderModel.from_pretrained( + config.text_encoder, low_cpu_mem_usage=True, torch_dtype=dtype + ) + if vae is None: - vae = getae_wrapper(config.ae)(config.transformer, subfolder="vae").to(dtype=dtype) + if config.version == "v110": + vae = CausalVAEModelWrapperV110(config.transformer, subfolder="vae").to(dtype=dtype) + elif config.version == "v120": + vae = CausalVAEModelWrapperV120(config.transformer, subfolder="vae").to(dtype=dtype) + if transformer is None: - transformer = LatteT2V.from_pretrained(config.transformer, subfolder=config.version, torch_dtype=dtype) + if config.version == "v110": + transformer = LatteT2V.from_pretrained( + config.transformer, subfolder=config.transformer_type, torch_dtype=dtype + ) + elif config.version == "v120": + transformer = OpenSoraT2V.from_pretrained( + config.transformer, subfolder=config.transformer_type, torch_dtype=dtype + ) + if scheduler is None: - scheduler = PNDMScheduler() + if config.version == "v110": + scheduler = PNDMScheduler() + elif config.version == "v120": + scheduler = EulerAncestralDiscreteScheduler() # setting if config.enable_tiling: vae.vae.enable_tiling() vae.vae.tile_overlap_factor = config.tile_overlap_factor - vae.vae_scale_factor = ae_stride_config[config.ae] + vae.vae.tile_sample_min_size = 512 + vae.vae.tile_latent_min_size = 64 + vae.vae.tile_sample_min_size_t = 29 + vae.vae.tile_latent_min_size_t = 8 + # if low_mem: + # vae.vae.tile_sample_min_size = 256 + # vae.vae.tile_latent_min_size = 32 + # vae.vae.tile_sample_min_size_t = 29 + # vae.vae.tile_latent_min_size_t = 8 + vae.vae_scale_factor = [4, 8, 8] transformer.force_images = False - # set eval and device - self.set_eval_and_device(device, vae, transformer) - # pab if config.enable_pab: set_pab_manager(config.pab_config) @@ -272,10 +333,35 @@ class OpenSoraPlanPipeline(VideoSysPipeline): if config.cpu_offload: self.enable_model_cpu_offload() else: - self.set_eval_and_device(device, text_encoder) + self.set_eval_and_device(device, text_encoder, vae, transformer) # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + # parallel + self._set_parallel() + + def _set_seed(self, seed): + if dist.get_world_size() == 1: + set_seed(seed) + else: + set_seed(seed, self.transformer.parallel_manager.dp_rank) + + def _set_parallel( + self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False + ): + # init sequence parallel + if sp_size is None: + sp_size = dist.get_world_size() + dp_size = 1 + else: + assert ( + dist.get_world_size() % sp_size == 0 + ), f"world_size {dist.get_world_size()} must be divisible by sp_size" + dp_size = dist.get_world_size() // sp_size + + # transformer parallel + self.transformer.enable_parallel(dp_size, sp_size, enable_cp) + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py def mask_text_embeddings(self, emb, mask): if emb.shape[0] == 1: @@ -449,6 +535,143 @@ class OpenSoraPlanPipeline(VideoSysPipeline): return prompt_embeds, negative_prompt_embeds + def encode_prompt_v120( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 120, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + if device is None: + device = getattr(self, "_execution_device", None) or getattr(self, "device", None) or torch.device("cuda") + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -476,6 +699,8 @@ class OpenSoraPlanPipeline(VideoSysPipeline): callback_steps, prompt_embeds=None, negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -519,6 +744,12 @@ class OpenSoraPlanPipeline(VideoSysPipeline): f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): @@ -685,10 +916,13 @@ class OpenSoraPlanPipeline(VideoSysPipeline): guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, + seed: int = -1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -697,6 +931,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline): mask_feature: bool = True, enable_temporal_attentions: bool = True, verbose: bool = True, + max_sequence_length: int = 512, ) -> Union[VideoSysPipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -768,11 +1003,22 @@ class OpenSoraPlanPipeline(VideoSysPipeline): returned where the first element is a list with the generated images """ # 1. Check inputs. Raise error if not correct - height = 512 - width = 512 + height = self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1] + width = self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2] num_frames = self._config.num_frames update_steps(num_inference_steps) - self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds) + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + self._set_seed(seed) # 2. Default height and width to transformer if prompt is not None and isinstance(prompt, str): @@ -782,7 +1028,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline): else: batch_size = prompt_embeds.shape[0] - device = self._execution_device + device = getattr(self, "_execution_device", None) or getattr(self, "device", None) or torch.device("cuda") # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -790,23 +1036,50 @@ class OpenSoraPlanPipeline(VideoSysPipeline): do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - do_classifier_free_guidance, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - clean_caption=clean_caption, - mask_feature=mask_feature, - ) + if self._config.version == "v110": + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + mask_feature=mask_feature, + ) + prompt_attention_mask = None + negative_prompt_attention_mask = None + else: + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt_v120( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + if prompt_attention_mask is not None and negative_prompt_attention_mask is not None: + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + if self._config.version == "v110": + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels @@ -827,15 +1100,10 @@ class OpenSoraPlanPipeline(VideoSysPipeline): # 6.1 Prepare micro-conditions. added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - # if self.transformer.config.sample_size == 128: - # resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) - # aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) - # resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) - # aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) - # added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} # 7. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x) for i, t in progress_wrap(list(enumerate(timesteps))): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -856,9 +1124,15 @@ class OpenSoraPlanPipeline(VideoSysPipeline): # broadcast to batch dimension in a way that's compatible with ONNX/Core ML current_timestep = current_timestep.expand(latent_model_input.shape[0]) - if prompt_embeds.ndim == 3: + if prompt_embeds is not None and prompt_embeds.ndim == 3: prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + if prompt_attention_mask is not None and prompt_attention_mask.ndim == 2: + prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l + # prepare attention_mask. + # b c t h w -> b t h w + attention_mask = torch.ones_like(latent_model_input)[:, 0] + # predict noise model_output noise_pred = self.transformer( latent_model_input, @@ -867,6 +1141,8 @@ class OpenSoraPlanPipeline(VideoSysPipeline): timestep=current_timestep, added_cond_kwargs=added_cond_kwargs, enable_temporal_attentions=enable_temporal_attentions, + attention_mask=attention_mask, + encoder_attention_mask=prompt_attention_mask, return_dict=False, )[0] @@ -906,10 +1182,57 @@ class OpenSoraPlanPipeline(VideoSysPipeline): return VideoSysPipelineOutput(video=video) def decode_latents(self, latents): - video = self.vae.decode(latents) # b t c h w - # b t c h w -> b t h w c - video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous() + video = self.vae(latents.to(self.vae.vae.dtype)) + video = ( + ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous() + ) # b t h w c + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 return video def save_video(self, video, output_path): save_video(video, output_path, fps=24) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps diff --git a/videosys/pipelines/vchitect/__init__.py b/videosys/pipelines/vchitect/__init__.py new file mode 100644 index 0000000..0ee6ac1 --- /dev/null +++ b/videosys/pipelines/vchitect/__init__.py @@ -0,0 +1,3 @@ +from .pipeline_vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline + +__all__ = ["VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"] diff --git a/videosys/pipelines/vchitect/pipeline_vchitect.py b/videosys/pipelines/vchitect/pipeline_vchitect.py new file mode 100644 index 0000000..e46927d --- /dev/null +++ b/videosys/pipelines/vchitect/pipeline_vchitect.py @@ -0,0 +1,1057 @@ +# Adapted from Vchitect + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Vchitect: https://github.com/Vchitect/Vchitect-2.0 +# diffusers: https://github.com/huggingface/diffusers +# -------------------------------------------------------- + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils.torch_utils import randn_tensor +from torch.amp import autocast +from tqdm import tqdm +from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps +from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput +from videosys.models.transformers.vchitect_transformer_3d import VchitectXLTransformerModel +from videosys.utils.logging import logger +from videosys.utils.utils import save_video, set_seed + + +class VchitectPABConfig(PABConfig): + def __init__( + self, + spatial_broadcast: bool = True, + spatial_threshold: list = [100, 800], + spatial_range: int = 2, + temporal_broadcast: bool = True, + temporal_threshold: list = [100, 800], + temporal_range: int = 4, + cross_broadcast: bool = True, + cross_threshold: list = [100, 800], + cross_range: int = 6, + ): + super().__init__( + spatial_broadcast=spatial_broadcast, + spatial_threshold=spatial_threshold, + spatial_range=spatial_range, + temporal_broadcast=temporal_broadcast, + temporal_threshold=temporal_threshold, + temporal_range=temporal_range, + cross_broadcast=cross_broadcast, + cross_threshold=cross_threshold, + cross_range=cross_range, + ) + + +class VchitectConfig: + """ + This config is to instantiate a `VchitectXLPipeline` class for video generation. + + To be specific, this config will be passed to engine by `VideoSysEngine(config)`. + In the engine, it will be used to instantiate the corresponding pipeline class. + And the engine will call the `generate` function of the pipeline to generate the video. + If you want to explore the detail of generation, please refer to the pipeline class below. + + Args: + model_path (str): + The model path to use. Defaults to "Vchitect/Vchitect-2.0-2B". + num_gpus (int): + The number of GPUs to use. Defaults to 1. + cpu_offload (bool): + Whether to enable cpu offload. Defaults to False. + enable_pab (bool): + Whether to enable Pyramid Attention Broadcast. Defaults to False. + pab_config (VchitectPABConfig): + The configuration for Pyramid Attention Broadcast. Defaults to `VchitectPABConfig()`. + + Examples: + ```python + from videosys import OpenSoraPlanConfig, VideoSysEngine + + # change num_gpus for multi-gpu inference + config = VchitectConfig("Vchitect/Vchitect-2.0-2B", num_gpus=1) + engine = VideoSysEngine(config) + + prompt = "Sunset over the sea." + # seed=-1 means random seed. >0 means fixed seed. + # WxH: 480x288 624x352 432x240 768x432 + video = engine.generate( + prompt=prompt, + negative_prompt="", + num_inference_steps=100, + guidance_scale=7.5, + width=480, + height=288, + frames=40, + seed=0, + ).video[0] + engine.save_video(video, f"./outputs/{prompt}.mp4") + ``` + """ + + def __init__( + self, + model_path: str = "Vchitect/Vchitect-2.0-2B", + # ======= distributed ======== + num_gpus: int = 1, + # ======= memory ======= + cpu_offload: bool = False, + # ======= pab ======== + enable_pab: bool = False, + pab_config: VchitectPABConfig = VchitectPABConfig(), + ): + self.model_path = model_path + self.pipeline_cls = VchitectXLPipeline + # ======= distributed ======== + self.num_gpus = num_gpus + # ======= memory ======== + self.cpu_offload = cpu_offload + # ======= pab ======== + self.enable_pab = enable_pab + self.pab_config = pab_config + + +class VchitectXLPipeline(VideoSysPipeline): + r""" + Args: + transformer ([`VchitectXLTransformerModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [ + "text_encoder", + "text_encoder_2", + "text_encoder_3", + "tokenizer", + "tokenizer_2", + "tokenizer_3", + "vae", + "transformer", + "scheduler", + ] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + config: VchitectConfig, + text_encoder: Optional[CLIPTextModelWithProjection] = None, + text_encoder_2: Optional[CLIPTextModelWithProjection] = None, + text_encoder_3: Optional[T5EncoderModel] = None, + tokenizer: Optional[CLIPTokenizer] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + tokenizer_3: Optional[T5TokenizerFast] = None, + vae: Optional[AutoencoderKL] = None, + transformer: Optional[VchitectXLTransformerModel] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self._config = config + self._device = device + self._dtype = dtype + + if text_encoder is None: + self.text_encoder = CLIPTextModelWithProjection.from_pretrained( + config.model_path, subfolder="text_encoder", torch_dtype=dtype + ) + if text_encoder_2 is None: + self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + config.model_path, subfolder="text_encoder_2", torch_dtype=dtype + ) + if text_encoder_3 is None: + self.text_encoder_3 = T5EncoderModel.from_pretrained( + config.model_path, subfolder="text_encoder_3", torch_dtype=dtype + ) + + if tokenizer is None: + self.tokenizer = CLIPTokenizer.from_pretrained(config.model_path, subfolder="tokenizer") + if tokenizer_2 is None: + self.tokenizer_2 = CLIPTokenizer.from_pretrained(config.model_path, subfolder="tokenizer_2") + if tokenizer_3 is None: + self.tokenizer_3 = T5TokenizerFast.from_pretrained(config.model_path, subfolder="tokenizer_3") + + if vae is None: + self.vae = AutoencoderKL.from_pretrained(config.model_path, subfolder="vae", torch_dtype=dtype) + + if transformer is None: + self.transformer = VchitectXLTransformerModel.from_pretrained( + config.model_path, subfolder="transformer", torch_dtype=dtype + ) + + if scheduler is None: + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(config.model_path, subfolder="scheduler") + + self.register_modules( + tokenizer=self.tokenizer, + tokenizer_2=self.tokenizer_2, + tokenizer_3=self.tokenizer_3, + text_encoder=self.text_encoder, + text_encoder_3=self.text_encoder_3, + text_encoder_2=self.text_encoder_2, + vae=self.vae, + transformer=self.transformer, + scheduler=self.scheduler, + ) + + # pab + if config.enable_pab: + set_pab_manager(config.pab_config) + + # cpu offload + if config.cpu_offload: + self.enable_model_cpu_offload() + else: + self.set_eval_and_device( + device, self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.transformer, self.vae + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.max_sequence_length_t5 = 256 + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + # parallel + self._set_parallel() + + def _set_parallel( + self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False + ): + # init sequence parallel + if sp_size is None: + sp_size = dist.get_world_size() + dp_size = 1 + else: + assert ( + dist.get_world_size() % sp_size == 0 + ), f"world_size {dist.get_world_size()} must be divisible by sp_size" + dp_size = dist.get_world_size() // sp_size + + # transformer parallel + self.transformer.enable_parallel(dp_size, sp_size, enable_cp) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self.execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + (batch_size, self.max_sequence_length_t5, self.transformer.config.joint_attention_dim), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=self.max_sequence_length_t5, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.max_sequence_length_t5 - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.max_sequence_length_t5} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = device or self.execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + def _set_seed(self, seed): + if dist.get_world_size() == 1: + set_seed(seed) + else: + set_seed(seed, self.transformer.parallel_manager.dp_rank) + + @autocast("cuda", enabled=False) + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self.execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = ( + prompt_embeds.to(self._device), + negative_prompt_embeds.to(self._device), + pooled_prompt_embeds.to(self._device), + negative_pooled_prompt_embeds.to(self._device), + ) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + frames, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + # 1, 60, 16, 32, 32 + shape = ( + batch_size, + frames, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @autocast("cuda", dtype=torch.bfloat16) + def generate( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: int = 288, + width: int = 480, + frames: int = 40, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + seed: int = -1, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + frames = frames or 24 + self._set_seed(seed) + update_steps(num_inference_steps) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=self.text_encoder.device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, self._device, timesteps + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + frames, + prompt_embeds.dtype, + self._device, + generator, + latents, + ) + + # 6. Denoising loop + # with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input[0, :].unsqueeze(0), + timestep=timestep, + encoder_hidden_states=prompt_embeds[0, :].unsqueeze(0), + pooled_projections=pooled_prompt_embeds[0, :].unsqueeze(0), + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred_text = self.transformer( + hidden_states=latent_model_input[1, :].unsqueeze(0), + timestep=timestep, + encoder_hidden_states=prompt_embeds[1, :].unsqueeze(0), + pooled_projections=pooled_prompt_embeds[1, :].unsqueeze(0), + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + # perform guidance + if self.do_classifier_free_guidance: + # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + # progress_bar.update() + + # if output_type == "latent": + # image = latents + + # else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + videos = [] + for v_idx in range(latents.shape[1]): + image = self.vae.decode(latents[:, v_idx], return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + videos.append(image[0]) + videos = [videos] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (videos,) + + return VideoSysPipelineOutput(video=videos) + + def save_video(self, video, output_path): + save_video(video, output_path, fps=8) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps diff --git a/videosys/utils/__init__.py b/videosys/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/videosys/utils/test.py b/videosys/utils/test.py new file mode 100644 index 0000000..aec6f7b --- /dev/null +++ b/videosys/utils/test.py @@ -0,0 +1,12 @@ +import functools + +import torch + + +def empty_cache(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + torch.cuda.empty_cache() + return func(*args, **kwargs) + + return wrapper diff --git a/videosys/utils/utils.py b/videosys/utils/utils.py index 868e37e..622a36d 100644 --- a/videosys/utils/utils.py +++ b/videosys/utils/utils.py @@ -16,7 +16,17 @@ def requires_grad(model: torch.nn.Module, flag: bool = True) -> None: p.requires_grad = flag -def set_seed(seed): +def set_seed(seed, dp_rank=None): + if seed == -1: + seed = random.randint(0, 1000000) + + if dp_rank is not None: + seed = torch.tensor(seed, dtype=torch.int64).cuda() + if dist.get_world_size() > 1: + dist.broadcast(seed, 0) + seed = seed + dp_rank + + seed = int(seed) random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed)