update the version of videosys

LiewFeng 2024-12-16 17:27:04 +08:00
parent 66038c512f
commit 21ea48e71a
34 changed files with 11296 additions and 506 deletions

eval/teacache/README.md

@ -0,0 +1,30 @@
# Evaluation of TeaCache
We first generate videos from VBench's prompts.
Then we compute VBench, PSNR, LPIPS, and SSIM scores on the generated videos.
1. Generate videos
```
cd eval/teacache
python experiments/latte.py
python experiments/opensora.py
python experiments/open_sora_plan.py
```
2. Calculate VBench score
```
# VBench is calculated independently of the other metrics
# get scores for all metrics
python vbench/run_vbench.py --video_path aaa --save_path bbb
# calculate final score
python vbench/cal_vbench.py --score_dir bbb
```
3. Calculate other metrics
```
# these metrics are computed against the original model's outputs
# gt video: output of the original model
# generated video: output of our method
python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
```
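
For reference, the evaluation script below rescales videos to [0, 1] before scoring, so PSNR uses a peak value of 1. A minimal sketch of the metric under that convention (the repository's calculate_psnr implementation may differ in details):
```
import torch

def psnr(a: torch.Tensor, b: torch.Tensor, max_val: float = 1.0) -> float:
    # peak signal-to-noise ratio for tensors already scaled to [0, max_val]
    mse = torch.mean((a - b) ** 2)
    return float(10 * torch.log10(max_val**2 / (mse + 1e-12)))
```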


@ -0,0 +1,205 @@
import argparse
import os
import imageio
import torch
import torchvision.transforms.functional as F
import tqdm
from calculate_lpips import calculate_lpips
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim
def load_video(video_path):
"""
Load a video from the given path and convert it to a PyTorch tensor.
"""
# Read the video using imageio
reader = imageio.get_reader(video_path, "ffmpeg")
# Extract frames and convert to a list of tensors
frames = []
for frame in reader:
# Convert the frame to a tensor and permute the dimensions to match (C, H, W)
frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
frames.append(frame_tensor)
# Stack the list of tensors into a single tensor with shape (T, C, H, W)
video_tensor = torch.stack(frames)
return video_tensor
def resize_video(video, target_height, target_width):
resized_frames = []
for frame in video:
resized_frame = F.resize(frame, [target_height, target_width])
resized_frames.append(resized_frame)
return torch.stack(resized_frames)
def resize_gt_video(gt_video, gen_video):
gen_video_shape = gen_video.shape
T_gen, _, H_gen, W_gen = gen_video_shape
T_eval, _, H_eval, W_eval = gt_video.shape
if T_eval < T_gen:
raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
if H_eval < H_gen or W_eval < W_gen:
# Resize the video maintaining the aspect ratio
resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
gt_video = resize_video(gt_video, resize_height, resize_width)
# Recalculate the dimensions
T_eval, _, H_eval, W_eval = gt_video.shape
# Center crop
start_h = (H_eval - H_gen) // 2
start_w = (W_eval - W_gen) // 2
cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
return cropped_video
def get_video_ids(gt_video_dirs, gen_video_dirs):
video_ids = []
for f in os.listdir(gt_video_dirs[0]):
if f.endswith(f".mp4"):
video_ids.append(f.replace(f".mp4", ""))
video_ids.sort()
for video_dir in gt_video_dirs + gen_video_dirs:
tmp_video_ids = []
for f in os.listdir(video_dir):
if f.endswith(f".mp4"):
tmp_video_ids.append(f.replace(f".mp4", ""))
tmp_video_ids.sort()
if tmp_video_ids != video_ids:
raise ValueError(f"Video IDs in {video_dir} are different.")
return video_ids
def get_videos(video_ids, gt_video_dirs, gen_video_dirs):
gt_videos = {}
generated_videos = {}
for gt_video_dir in gt_video_dirs:
tmp_gt_videos_tensor = []
for video_id in video_ids:
gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4"))
tmp_gt_videos_tensor.append(gt_video)
gt_videos[gt_video_dir] = tmp_gt_videos_tensor
for generated_video_dir in gen_video_dirs:
tmp_generated_videos_tensor = []
for video_id in video_ids:
generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4"))
tmp_generated_videos_tensor.append(generated_video)
generated_videos[generated_video_dir] = tmp_generated_videos_tensor
return gt_videos, generated_videos
def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs):
out_str = ""
for gt_video_dir in gt_video_dirs:
for generated_video_dir in gen_video_dirs:
if gt_video_dir == generated_video_dir:
continue
lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len(
lpips_results[gt_video_dir][generated_video_dir]
)
psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len(
psnr_results[gt_video_dir][generated_video_dir]
)
ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len(
ssim_results[gt_video_dir][generated_video_dir]
)
out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}"
return out_str
def main(args):
device = "cuda"
gt_video_dirs = args.gt_video_dirs
gen_video_dirs = args.gen_video_dirs
video_ids = get_video_ids(gt_video_dirs, gen_video_dirs)
print(f"Find {len(video_ids)} videos")
prompt_interval = 1
batch_size = 8
calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
lpips_results = {}
psnr_results = {}
ssim_results = {}
for gt_video_dir in gt_video_dirs:
lpips_results[gt_video_dir] = {}
psnr_results[gt_video_dir] = {}
ssim_results[gt_video_dir] = {}
for generated_video_dir in gen_video_dirs:
lpips_results[gt_video_dir][generated_video_dir] = []
psnr_results[gt_video_dir][generated_video_dir] = []
ssim_results[gt_video_dir][generated_video_dir] = []
total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
for idx in tqdm.tqdm(range(total_len)):
video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size]
gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs)
for gt_video_dir, gt_videos_tensor in gt_videos.items():
for generated_video_dir, generated_videos_tensor in generated_videos.items():
if gt_video_dir == generated_video_dir:
continue
if not isinstance(gt_videos_tensor, torch.Tensor):
for i in range(len(gt_videos_tensor)):
gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0])
gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
if calculate_lpips_flag:
result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
result = result["value"].values()
result = float(sum(result) / len(result))
lpips_results[gt_video_dir][generated_video_dir].append(result)
if calculate_psnr_flag:
result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
result = result["value"].values()
result = float(sum(result) / len(result))
psnr_results[gt_video_dir][generated_video_dir].append(result)
if calculate_ssim_flag:
result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
result = result["value"].values()
result = float(sum(result) / len(result))
ssim_results[gt_video_dir][generated_video_dir].append(result)
if (idx + 1) % prompt_interval == 0:
out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
print(f"Processed {idx + 1} / {total_len} videos. {out_str}")
out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
# save
with open(f"./batch_eval.txt", "w+") as f:
f.write(out_str)
print(f"Processed all videos. {out_str}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--gt_video_dirs", type=str, nargs="+")
parser.add_argument("--gen_video_dirs", type=str, nargs="+")
args = parser.parse_args()
main(args)
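
A hedged usage sketch for the script above: it pairs every ground-truth directory with every generated-video directory, averages LPIPS/PSNR/SSIM per pair, and writes the summary to ./batch_eval.txt. The directory names here are placeholders, not paths from this repository:
```
import argparse

# build the same namespace the CLI flags would produce
args = argparse.Namespace(
    gt_video_dirs=["samples/base"],
    gen_video_dirs=["samples/teacache_slow", "samples/teacache_fast"],
)
main(args)  # prints running scores per (gt, gen) pair and saves ./batch_eval.txt
```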


@ -5,12 +5,7 @@ from einops import rearrange, repeat
from torch import nn
import numpy as np
from typing import Any, Dict, Optional, Tuple
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_size,
get_data_parallel_group,
get_sequence_parallel_group,
)
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
def teacache_forward(
self,
@ -67,7 +62,7 @@ def teacache_forward(
"""
# 0. Split batch for data parallelism
if get_cfg_parallel_size() > 1:
if self.parallel_manager.cp_size > 1:
(
hidden_states,
timestep,
@ -77,7 +72,7 @@ def teacache_forward(
attention_mask,
encoder_attention_mask,
) = batch_func(
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
timestep,
encoder_hidden_states,
@ -193,14 +188,14 @@ def teacache_forward(
if not should_calc:
hidden_states += self.previous_residual
else:
if enable_sequence_parallel():
set_temporal_pad(frame + use_image_num)
set_spatial_pad(num_patches)
if self.parallel_manager.sp_size > 1:
set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
@ -323,14 +318,14 @@ def teacache_forward(
).contiguous()
self.previous_residual = hidden_states - hidden_states_origin
else:
if enable_sequence_parallel():
set_temporal_pad(frame + use_image_num)
set_spatial_pad(num_patches)
if self.parallel_manager.sp_size > 1:
set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
@ -451,7 +446,8 @@ def teacache_forward(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
if enable_sequence_parallel():
if self.parallel_manager.sp_size > 1:
if self.enable_teacache:
if should_calc:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
@ -488,8 +484,8 @@ def teacache_forward(
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
if get_cfg_parallel_size() > 1:
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
if self.parallel_manager.cp_size > 1:
output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
@ -497,10 +493,6 @@ def teacache_forward(
return Transformer3DModelOutput(sample=output)
def eval_base(prompt_list):
config = LatteConfig()
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/latte_base", loop=5)
def eval_teacache_slow(prompt_list):
config = LatteConfig()
@ -523,10 +515,18 @@ def eval_teacache_fast(prompt_list):
engine.driver_worker.transformer.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/latte_teacache_fast", loop=5)
def eval_base(prompt_list):
config = LatteConfig()
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/latte_base", loop=5)
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
# eval_base(prompt_list)
eval_base(prompt_list)
eval_teacache_slow(prompt_list)
# eval_teacache_fast(prompt_list)
eval_teacache_fast(prompt_list)
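
The hunks above show only fragments of the Latte evaluation functions, so here is a hedged sketch of how TeaCache is wired in: the transformer's forward is monkey-patched with teacache_forward and its caching state lives on the module. Attribute values and the exact initialization are assumptions where the diff elides them:
```
prompt_list = read_prompt_list("vbench/VBench_full_info.json")

config = LatteConfig()
engine = VideoSysEngine(config)
transformer = engine.driver_worker.transformer
transformer.enable_teacache = True
transformer.rel_l1_thresh = 0.1              # assumed "slow" threshold
transformer.previous_modulated_input = None  # assumed cache initialization
transformer.previous_residual = None
transformer.__class__.forward = teacache_forward  # swap in the patched forward
generate_func(engine, prompt_list, "./samples/latte_teacache_slow", loop=5)
```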


@ -3,23 +3,23 @@ from videosys import OpenSoraConfig, VideoSysEngine
import torch
from einops import rearrange
from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_temporal_pad, set_spatial_pad, set_temporal_pad, split_sequence
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
import numpy as np
from videosys.utils.utils import batch_func
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_size,
get_data_parallel_group,
get_sequence_parallel_group,
)
from functools import partial
def teacache_forward(
self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
):
# === Split batch ===
if get_cfg_parallel_size() > 1:
if self.parallel_manager.cp_size > 1:
x, timestep, y, x_mask, mask = batch_func(
partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
x,
timestep,
y,
x_mask,
mask,
)
dtype = self.x_embedder.proj.weight.dtype
@ -60,7 +60,7 @@ def teacache_forward(
# === get x embed ===
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
x = x + pos_emb
x = x + pos_emb
if self.enable_teacache:
inp = x.clone()
@ -84,7 +84,6 @@ def teacache_forward(
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
# === blocks ===
if self.enable_teacache:
if not should_calc:
@ -92,16 +91,15 @@ def teacache_forward(
x += self.previous_residual
else:
# shard over the sequence dim if sp is enabled
if enable_sequence_parallel():
set_temporal_pad(T)
set_spatial_pad(S)
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
if self.parallel_manager.sp_size > 1:
set_pad("temporal", T, self.parallel_manager.sp_group)
set_pad("spatial", S, self.parallel_manager.sp_group)
x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
T = x.shape[1]
x_mask_org = x_mask
x_mask = split_sequence(
x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
origin_x = x.clone().detach()
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
@ -135,16 +133,17 @@ def teacache_forward(
self.previous_residual = x - origin_x
else:
# shard over the sequence dim if sp is enabled
if enable_sequence_parallel():
set_temporal_pad(T)
set_spatial_pad(S)
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
if self.parallel_manager.sp_size > 1:
set_pad("temporal", T, self.parallel_manager.sp_group)
set_pad("spatial", S, self.parallel_manager.sp_group)
x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
T = x.shape[1]
x_mask_org = x_mask
x_mask = split_sequence(
x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
x = auto_grad_checkpoint(
spatial_block,
@ -174,25 +173,23 @@ def teacache_forward(
all_timesteps=all_timesteps,
)
if enable_sequence_parallel():
if self.parallel_manager.sp_size > 1:
if self.enable_teacache:
if should_calc:
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S)
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
T, S = x.shape[1], x.shape[2]
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S)
x_mask = x_mask_org
else:
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
T, S = x.shape[1], x.shape[2]
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
x_mask = x_mask_org
# === final layer ===
x = self.final_layer(x, t, x_mask, t0, T, S)
x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
@ -201,12 +198,11 @@ def teacache_forward(
x = x.to(torch.float32)
# === Gather Output ===
if get_cfg_parallel_size() > 1:
x = gather_sequence(x, get_data_parallel_group(), dim=0)
if self.parallel_manager.cp_size > 1:
x = gather_sequence(x, self.parallel_manager.cp_group, dim=0)
return x
def eval_base(prompt_list):
config = OpenSoraConfig()
engine = VideoSysEngine(config)
@ -235,9 +231,9 @@ def eval_teacache_fast(prompt_list):
generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5)
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
# eval_base(prompt_list)
eval_base(prompt_list)
eval_teacache_slow(prompt_list)
# eval_teacache_fast(prompt_list)
eval_teacache_fast(prompt_list)
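
The patched forwards above key their skip decision on should_calc, accumulated_rel_l1_distance, previous_modulated_input, and previous_residual. A hedged, self-contained sketch of that caching rule follows; the attribute names mirror this commit, but the distance metric itself is an assumption (TeaCache's rescaling of the raw L1 distance is not shown in these hunks):
```
import torch

class TeaCacheState:
    def __init__(self, rel_l1_thresh: float):
        self.rel_l1_thresh = rel_l1_thresh
        self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = None

    def should_calc(self, modulated_inp: torch.Tensor) -> bool:
        if self.previous_modulated_input is None:
            calc = True  # first step: always run the transformer blocks
        else:
            prev = self.previous_modulated_input
            rel_l1 = float((modulated_inp - prev).abs().mean() / prev.abs().mean())
            self.accumulated_rel_l1_distance += rel_l1
            # reuse previous_residual while the accumulated change stays small
            calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
        if calc:
            self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = modulated_inp
        return calc
```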


@ -5,12 +5,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from typing import Any, Dict, Optional, Tuple
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_group,
get_cfg_parallel_size,
get_sequence_parallel_group,
)
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
def teacache_forward(
self,
@ -66,7 +61,7 @@ def teacache_forward(
`tuple` where the first element is the sample tensor.
"""
# 0. Split batch
if get_cfg_parallel_size() > 1:
if self.parallel_manager.cp_size > 1:
(
hidden_states,
timestep,
@ -75,7 +70,7 @@ def teacache_forward(
attention_mask,
encoder_attention_mask,
) = batch_func(
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
timestep,
encoder_hidden_states,
@ -204,20 +199,20 @@ def teacache_forward(
if not should_calc:
hidden_states += self.previous_residual
else:
if enable_sequence_parallel():
set_temporal_pad(frame + use_image_num)
set_spatial_pad(num_patches)
if self.parallel_manager.sp_size > 1:
set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
ori_hidden_states = hidden_states.clone()
ori_hidden_states = hidden_states.clone().detach()
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
@ -359,20 +354,20 @@ def teacache_forward(
).contiguous()
self.previous_residual = hidden_states - ori_hidden_states
else:
if enable_sequence_parallel():
set_temporal_pad(frame + use_image_num)
set_spatial_pad(num_patches)
if self.parallel_manager.sp_size > 1:
set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
self.previous_residual = self.split_from_second_dim(self.previous_residual, input_batch_size) if self.previous_residual is not None else None
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
@ -512,8 +507,8 @@ def teacache_forward(
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
if enable_sequence_parallel():
if self.parallel_manager.sp_size > 1:
if self.enable_teacache:
if should_calc:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
@ -550,8 +545,8 @@ def teacache_forward(
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
if get_cfg_parallel_size() > 1:
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
if self.parallel_manager.cp_size > 1:
output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
@ -559,14 +554,8 @@ def teacache_forward(
return Transformer3DModelOutput(sample=output)
def eval_base(prompt_list):
config = OpenSoraPlanConfig()
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
def eval_teacache_slow(prompt_list):
config = OpenSoraPlanConfig()
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.1
@ -577,7 +566,7 @@ def eval_teacache_slow(prompt_list):
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5)
def eval_teacache_fast(prompt_list):
config = OpenSoraPlanConfig()
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.2
@ -587,8 +576,16 @@ def eval_teacache_fast(prompt_list):
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5)
def eval_base(prompt_list):
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", )
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
# eval_base(prompt_list)
eval_base(prompt_list)
eval_teacache_slow(prompt_list)
# eval_teacache_fast(prompt_list)
eval_teacache_fast(prompt_list)


@ -1,5 +1,7 @@
accelerate>0.17.0
bs4
click
colossalai
colossalai==0.4.0
diffusers==0.30.0
einops
fabric
@ -16,7 +18,9 @@ pydantic
ray
rich
safetensors
sentencepiece
timm
torch>=1.13
tqdm
transformers
peft==0.13.2
transformers==4.39.3


@ -1,6 +1,9 @@
from typing import List
from setuptools import find_packages, setup
from setuptools.command.develop import develop
from setuptools.command.egg_info import egg_info
from setuptools.command.install import install
def fetch_requirements(path) -> List[str]:
@ -14,7 +17,9 @@ def fetch_requirements(path) -> List[str]:
The lines in the requirements file.
"""
with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
requirements = [r.strip() for r in fd.readlines()]
# requirements.remove("colossalai")
return requirements
def fetch_readme() -> str:
@ -28,6 +33,28 @@ def fetch_readme() -> str:
return f.read()
def custom_install():
return ["pip", "install", "colossalai", "--no-deps"]
class CustomInstallCommand(install):
def run(self):
install.run(self)
self.spawn(custom_install())
class CustomDevelopCommand(develop):
def run(self):
develop.run(self)
self.spawn(custom_install())
class CustomEggInfoCommand(egg_info):
def run(self):
egg_info.run(self)
self.spawn(custom_install())
setup(
name="videosys",
version="2.0.0",
@ -39,12 +66,17 @@ setup(
"*.egg-info",
)
),
description="VideoSys",
description="TeaCache",
long_description=fetch_readme(),
long_description_content_type="text/markdown",
license="Apache Software License 2.0",
install_requires=fetch_requirements("requirements.txt"),
python_requires=">=3.6",
python_requires=">=3.7",
# cmdclass={
# "install": CustomInstallCommand,
# "develop": CustomDevelopCommand,
# "egg_info": CustomEggInfoCommand,
# },
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",


@ -3,13 +3,20 @@ from .core.parallel_mgr import initialize
from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
from .pipelines.open_sora_plan import (
OpenSoraPlanConfig,
OpenSoraPlanPipeline,
OpenSoraPlanV110PABConfig,
OpenSoraPlanV120PABConfig,
)
from .pipelines.vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline
__all__ = [
"initialize",
"VideoSysEngine",
"LattePipeline", "LatteConfig", "LattePABConfig",
"OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
"OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig",
"OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
"CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"
"CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig",
"VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"
] # fmt: skip


@ -7,8 +7,6 @@ from einops import rearrange
from torch import Tensor
from torch.distributed import ProcessGroup
from videosys.core.parallel_mgr import get_sequence_parallel_size
# ======================================================
# Model
# ======================================================
@ -369,30 +367,18 @@ def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
# Pad
# ==============================
SPTIAL_PAD = 0
TEMPORAL_PAD = 0
PAD_DICT = {}
def set_spatial_pad(dim_size: int):
sp_size = get_sequence_parallel_size()
def set_pad(name: str, dim_size: int, parallel_group: dist.ProcessGroup):
sp_size = dist.get_world_size(parallel_group)
pad = (sp_size - (dim_size % sp_size)) % sp_size
global SPTIAL_PAD
SPTIAL_PAD = pad
global PAD_DICT
PAD_DICT[name] = pad
def get_spatial_pad() -> int:
return SPTIAL_PAD
def set_temporal_pad(dim_size: int):
sp_size = get_sequence_parallel_size()
pad = (sp_size - (dim_size % sp_size)) % sp_size
global TEMPORAL_PAD
TEMPORAL_PAD = pad
def get_temporal_pad() -> int:
return TEMPORAL_PAD
def get_pad(name) -> int:
return PAD_DICT[name]
def all_to_all_with_pad(
@ -418,3 +404,17 @@ def all_to_all_with_pad(
input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
return input_
def split_from_second_dim(x, batch_size, parallel_group):
x = x.view(batch_size, -1, *x.shape[1:])
x = split_sequence(x, parallel_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
x = x.reshape(-1, *x.shape[2:])
return x
def gather_from_second_dim(x, batch_size, parallel_group):
x = x.view(batch_size, -1, *x.shape[1:])
x = gather_sequence(x, parallel_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
x = x.reshape(-1, *x.shape[2:])
return x
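
A worked example of the padding rule in set_pad above: with a sequence-parallel group of 4 ranks and a temporal dimension of 30 frames, two frames of padding make the sequence split evenly.
```
sp_size, dim_size = 4, 30
pad = (sp_size - (dim_size % sp_size)) % sp_size
assert pad == 2  # 30 + 2 = 32 frames, i.e. an even 8 frames per rank
```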


@ -66,7 +66,7 @@ class VideoSysEngine:
# TODO: add more options here for pipeline, or wrap all options into config
def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method, seed=42)
videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method)
pipeline = pipeline_cls(self.config)
return pipeline


@ -6,7 +6,6 @@ PAB_MANAGER = None
class PABConfig:
def __init__(
self,
steps: int,
cross_broadcast: bool = False,
cross_threshold: list = None,
cross_range: int = None,
@ -20,7 +19,7 @@ class PABConfig:
mlp_spatial_broadcast_config: dict = None,
mlp_temporal_broadcast_config: dict = None,
):
self.steps = steps
self.steps = None
self.cross_broadcast = cross_broadcast
self.cross_threshold = cross_threshold
@ -45,7 +44,7 @@ class PABManager:
def __init__(self, config: PABConfig):
self.config: PABConfig = config
init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
init_prompt = f"Init Pyramid Attention Broadcast."
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
@ -78,7 +77,7 @@ class PABManager:
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
def if_broadcast_spatial(self, timestep: int, count: int):
if (
self.config.spatial_broadcast
and (timestep is not None)
@ -213,10 +212,10 @@ def if_broadcast_temporal(timestep: int, count: int):
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
def if_broadcast_spatial(timestep: int, count: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
return PAB_MANAGER.if_broadcast_spatial(timestep, count)
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):


@ -1,14 +1,9 @@
from typing import Optional
import torch
import torch.distributed as dist
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from torch.distributed import ProcessGroup
from videosys.utils.logging import init_dist_logger, logger
from videosys.utils.utils import set_seed
PARALLEL_MANAGER = None
class ParallelManager(ProcessGroupMesh):
@ -21,71 +16,28 @@ class ParallelManager(ProcessGroupMesh):
self.dp_rank = dist.get_rank(self.dp_group)
self.cp_size = cp_size
self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
self.cp_rank = dist.get_rank(self.cp_group)
if cp_size > 1:
self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
self.cp_rank = dist.get_rank(self.cp_group)
else:
self.cp_group = None
self.cp_rank = None
self.sp_size = sp_size
self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
self.sp_rank = dist.get_rank(self.sp_group)
self.enable_sp = sp_size > 1
if sp_size > 1:
self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
self.sp_rank = dist.get_rank(self.sp_group)
else:
self.sp_group = None
self.sp_rank = None
logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")
def set_parallel_manager(dp_size, cp_size, sp_size):
global PARALLEL_MANAGER
PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)
def get_data_parallel_group():
return PARALLEL_MANAGER.dp_group
def get_data_parallel_size():
return PARALLEL_MANAGER.dp_size
def get_data_parallel_rank():
return PARALLEL_MANAGER.dp_rank
def get_sequence_parallel_group():
return PARALLEL_MANAGER.sp_group
def get_sequence_parallel_size():
return PARALLEL_MANAGER.sp_size
def get_sequence_parallel_rank():
return PARALLEL_MANAGER.sp_rank
def get_cfg_parallel_group():
return PARALLEL_MANAGER.cp_group
def get_cfg_parallel_size():
return PARALLEL_MANAGER.cp_size
def enable_sequence_parallel():
if PARALLEL_MANAGER is None:
return False
return PARALLEL_MANAGER.enable_sp
def get_parallel_manager():
return PARALLEL_MANAGER
def initialize(
rank=0,
world_size=1,
init_method=None,
seed: Optional[int] = None,
sp_size: Optional[int] = None,
enable_cp: bool = False,
):
if not dist.is_initialized():
try:
@ -97,24 +49,3 @@ def initialize(
init_dist_logger()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# init sequence parallel
if sp_size is None:
sp_size = dist.get_world_size()
dp_size = 1
else:
assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size"
dp_size = dist.get_world_size() // sp_size
# update cfg parallel
# NOTE: enable cp parallel will be slower. disable it for now.
if False and enable_cp and sp_size % 2 == 0:
sp_size = sp_size // 2
cp_size = 2
else:
cp_size = 1
set_parallel_manager(dp_size, cp_size, sp_size)
if seed is not None:
set_seed(seed + get_data_parallel_rank())


@ -13,9 +13,10 @@ class VideoSysPipeline(DiffusionPipeline):
@staticmethod
def set_eval_and_device(device: torch.device, *modules):
for module in modules:
module.eval()
module.to(device)
modules = list(modules)
for i in range(len(modules)):
modules[i] = modules[i].eval()
modules[i] = modules[i].to(device)
@abstractmethod
def generate(self, *args, **kwargs):


@ -694,7 +694,10 @@ class VideoAutoencoderPipeline(PreTrainedModel):
else:
return x
def forward(self, x):
def forward(self, x, decode_only=False, **kwargs):
if decode_only:
return self.decode(x, **kwargs)
assert self.cal_loss, "This method is only available when cal_loss is True"
z, posterior, x_z = self.encode(x)
x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
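
A hedged sketch of the new decode-only path (vae and latents are hypothetical names; extra kwargs such as num_frames are forwarded to decode):
```
# skips the encode + reconstruction path that requires cal_loss=True
video = vae(latents, decode_only=True, num_frames=num_frames)
# equivalent to vae.decode(latents, num_frames=num_frames)
```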

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@ -1,12 +1,21 @@
import inspect
from dataclasses import dataclass
from typing import Iterable, List, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttnProcessor
from einops import rearrange
from torch import nn
from torch.amp import autocast
from videosys.models.modules.normalization import LlamaRMSNorm
from videosys.core.comm import all_to_all_with_pad, get_pad, set_pad
from videosys.core.pab_mgr import enable_pab, if_broadcast_cross, if_broadcast_spatial, if_broadcast_temporal
from videosys.models.modules.normalization import LlamaRMSNorm, VchitectSpatialNorm
from videosys.utils.logging import logger
class OpenSoraAttention(nn.Module):
@ -203,3 +212,634 @@ class _SeqLenInfo:
seqstart=seqstart,
seqstart_py=seqstart_py,
)
class VchitectAttention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`):
The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8):
The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
upcast_attention (`bool`, *optional*, defaults to False):
Set to `True` to upcast the attention computation to `float32`.
upcast_softmax (`bool`, *optional*, defaults to False):
Set to `True` to upcast the softmax computation to `float32`.
cross_attention_norm (`str`, *optional*, defaults to `None`):
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the group norm in the cross attention.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
norm_num_groups (`int`, *optional*, defaults to `None`):
The number of groups to use for the group norm in the attention.
spatial_norm_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the spatial normalization.
out_bias (`bool`, *optional*, defaults to `True`):
Set to `True` to use a bias in the output linear layer.
scale_qk (`bool`, *optional*, defaults to `True`):
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
only_cross_attention (`bool`, *optional*, defaults to `False`):
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
`added_kv_proj_dim` is not `None`.
eps (`float`, *optional*, defaults to 1e-5):
An additional value added to the denominator in group normalization that is used for numerical stability.
rescale_output_factor (`float`, *optional*, defaults to 1.0):
A factor to rescale the output by dividing it with this value.
residual_connection (`bool`, *optional*, defaults to `False`):
Set to `True` to add the residual connection to the output.
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
Set to `True` if the attention block is loaded from a deprecated state dict.
processor (`AttnProcessor`, *optional*, defaults to `None`):
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
`AttnProcessor` otherwise.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional[AttnProcessor] = None,
out_dim: int = None,
context_pre_only: bool = None,
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
# we make use of this private variable to know whether this class is loaded
# with a deprecated state dict so that we can convert it on the fly
self._from_deprecated_attn_block = _from_deprecated_attn_block
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
else:
self.group_norm = None
if spatial_norm_dim is not None:
self.spatial_norm = VchitectSpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
else:
self.spatial_norm = None
if qk_norm is None:
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
else:
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
if cross_attention_norm is None:
self.norm_cross = None
elif cross_attention_norm == "layer_norm":
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
elif cross_attention_norm == "group_norm":
if self.added_kv_proj_dim is not None:
# The given `encoder_hidden_states` are initially of shape
# (batch_size, seq_len, added_kv_proj_dim) before being projected
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
# before the projection, so we need to use `added_kv_proj_dim` as
# the number of channels for the group norm.
norm_cross_num_channels = added_kv_proj_dim
else:
norm_cross_num_channels = self.cross_attention_dim
self.norm_cross = nn.GroupNorm(
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
)
else:
raise ValueError(
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
self.to_q_cross = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_q_temp = nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k_temp = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v_temp = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
else:
self.to_k_temp = None
self.to_v_temp = None
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
self.to_out_temporal = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
nn.init.constant_(self.to_out_temporal.weight, 0)
nn.init.constant_(self.to_out_temporal.bias, 0)
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
self.to_add_out_temporal = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
nn.init.constant_(self.to_add_out_temporal.weight, 0)
nn.init.constant_(self.to_add_out_temporal.bias, 0)
self.to_out_context = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
nn.init.constant_(self.to_out_context.weight, 0)
nn.init.constant_(self.to_out_context.bias, 0)
# set attention processor
self.set_processor(processor)
# parallel
self.parallel_manager = None
# pab
self.spatial_count = 0
self.last_spatial = None
self.temporal_count = 0
self.last_temporal = None
self.cross_count = 0
self.last_cross = None
def set_processor(self, processor: "AttnProcessor") -> None:
r"""
Set the attention processor to use.
Args:
processor (`AttnProcessor`):
The attention processor to use.
"""
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
full_seqlen: Optional[int] = None,
Frame: Optional[int] = None,
timestep: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
r"""
The forward method of the `Attention` class.
Args:
hidden_states (`torch.Tensor`):
The hidden states of the query.
encoder_hidden_states (`torch.Tensor`, *optional*):
The hidden states of the encoder.
attention_mask (`torch.Tensor`, *optional*):
The attention mask to use. If `None`, no mask is applied.
**cross_attention_kwargs:
Additional keyword arguments to pass along to the cross attention.
Returns:
`torch.Tensor`: The output of the attention layer.
"""
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
if len(unused_kwargs) > 0:
logger.warning(
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
full_seqlen=full_seqlen,
Frame=Frame,
timestep=timestep,
**cross_attention_kwargs,
)
@torch.no_grad()
def fuse_projections(self, fuse=True):
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if not self.is_cross_attention:
# fetch weight matrices.
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
# create a new single projection layer and copy over the weights.
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
self.to_kv.bias.copy_(concatenated_bias)
self.fused_projections = fuse
class VchitectAttnProcessor:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
@autocast("cuda", enabled=False)
def apply_rotary_emb(
self,
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
):
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def spatial_attn(
self,
attn,
hidden_states,
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
batch_size,
head_dim,
):
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# attention
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)
xq, xk = query.to(value.dtype), key.to(value.dtype)
query_spatial, key_spatial, value_spatial = xq, xk, value
query_spatial = query_spatial.transpose(1, 2)
key_spatial = key_spatial.transpose(1, 2)
value_spatial = value_spatial.transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query_spatial, key_spatial, value_spatial, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
return hidden_states
def temporal_attention(
self,
attn,
hidden_states,
residual,
batch_size,
batchsize,
Frame,
head_dim,
freqs_cis,
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
):
query_t = attn.to_q_temp(hidden_states)
key_t = attn.to_k_temp(hidden_states)
value_t = attn.to_v_temp(hidden_states)
query_t = torch.cat([query_t, encoder_hidden_states_query_proj], dim=1)
key_t = torch.cat([key_t, encoder_hidden_states_key_proj], dim=1)
value_t = torch.cat([value_t, encoder_hidden_states_value_proj], dim=1)
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
query_t, key_t = query_t.to(value_t.dtype), key_t.to(value_t.dtype)
if attn.parallel_manager.sp_size > 1:
func = lambda x: self.dynamic_switch(attn, x, batchsize, to_spatial_shard=True)
query_t, key_t, value_t = map(func, [query_t, key_t, value_t])
func = lambda x: rearrange(x, "(B T) S H C -> (B S) T H C", T=Frame, B=batchsize)
xq_gather_temporal, xk_gather_temporal, xv_gather_temporal = map(func, [query_t, key_t, value_t])
freqs_cis_temporal = freqs_cis[: xq_gather_temporal.shape[1], :]
xq_gather_temporal, xk_gather_temporal = self.apply_rotary_emb(
xq_gather_temporal, xk_gather_temporal, freqs_cis=freqs_cis_temporal
)
xq_gather_temporal = xq_gather_temporal.transpose(1, 2)
xk_gather_temporal = xk_gather_temporal.transpose(1, 2)
xv_gather_temporal = xv_gather_temporal.transpose(1, 2)
batch_size_temp = xv_gather_temporal.shape[0]
hidden_states_temp = F.scaled_dot_product_attention(
xq_gather_temporal, xk_gather_temporal, xv_gather_temporal, dropout_p=0.0, is_causal=False
)
hidden_states_temp = hidden_states_temp.transpose(1, 2).reshape(batch_size_temp, -1, attn.heads * head_dim)
hidden_states_temp = hidden_states_temp.to(value_t.dtype)
hidden_states_temp = rearrange(hidden_states_temp, "(B S) T C -> (B T) S C", T=Frame, B=batchsize)
if attn.parallel_manager.sp_size > 1:
hidden_states_temp = self.dynamic_switch(attn, hidden_states_temp, batchsize, to_spatial_shard=False)
hidden_states_temporal, encoder_hidden_states_temporal = (
hidden_states_temp[:, : residual.shape[1]],
hidden_states_temp[:, residual.shape[1] :],
)
hidden_states_temporal = attn.to_out_temporal(hidden_states_temporal)
return hidden_states_temporal, encoder_hidden_states_temporal
def cross_attention(
self,
attn,
hidden_states,
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
batch_size,
head_dim,
cur_frame,
batchsize,
):
query_cross = attn.to_q_cross(hidden_states)
query_cross = torch.cat([query_cross, encoder_hidden_states_query_proj], dim=1)
key_y = encoder_hidden_states_key_proj[0].unsqueeze(0)
value_y = encoder_hidden_states_value_proj[0].unsqueeze(0)
query_y = query_cross.view(batch_size, -1, attn.heads, head_dim)
key_y = key_y.view(batchsize, -1, attn.heads, head_dim)
value_y = value_y.view(batchsize, -1, attn.heads, head_dim)
query_y = rearrange(query_y, "(B T) S H C -> B (S T) H C", T=cur_frame, B=batchsize)
query_y = query_y.transpose(1, 2)
key_y = key_y.transpose(1, 2)
value_y = value_y.transpose(1, 2)
cross_output = F.scaled_dot_product_attention(query_y, key_y, value_y, dropout_p=0.0, is_causal=False)
cross_output = cross_output.transpose(1, 2).reshape(batchsize, -1, attn.heads * head_dim)
cross_output = cross_output.to(query_cross.dtype)
cross_output = rearrange(cross_output, "B (S T) C -> (B T) S C", T=cur_frame, B=batchsize)
cross_output = attn.to_out_context(cross_output)
return cross_output
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
full_seqlen: Optional[int] = None,
Frame: Optional[int] = None,
timestep: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
batch_size = encoder_hidden_states.shape[0]
inner_dim = encoder_hidden_states_key_proj.shape[-1]
head_dim = inner_dim // attn.heads
batchsize = full_seqlen // Frame
# equals Frame when sequences are not sharded; otherwise the per-rank (sharded) frame count
cur_frame = batch_size // batchsize
# temporal attention
if enable_pab():
broadcast_temporal, attn.temporal_count = if_broadcast_temporal(int(timestep[0]), attn.temporal_count)
if enable_pab() and broadcast_temporal:
hidden_states_temporal, encoder_hidden_states_temporal = attn.last_temporal
else:
hidden_states_temporal, encoder_hidden_states_temporal = self.temporal_attention(
attn,
hidden_states,
residual,
batch_size,
batchsize,
Frame,
head_dim,
freqs_cis,
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
)
if enable_pab():
attn.last_temporal = (hidden_states_temporal, encoder_hidden_states_temporal)
# cross attn
if enable_pab():
broadcast_cross, attn.cross_count = if_broadcast_cross(int(timestep[0]), attn.cross_count)
if enable_pab() and broadcast_cross:
cross_output = attn.last_cross
else:
cross_output = self.cross_attention(
attn,
hidden_states,
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
batch_size,
head_dim,
cur_frame,
batchsize,
)
if enable_pab():
attn.last_cross = cross_output
# spatial attn
if enable_pab():
broadcast_spatial, attn.spatial_count = if_broadcast_spatial(int(timestep[0]), attn.spatial_count)
if enable_pab() and broadcast_spatial:
hidden_states = attn.last_spatial
else:
hidden_states = self.spatial_attn(
attn,
hidden_states,
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
batch_size,
head_dim,
)
if enable_pab():
attn.last_spatial = hidden_states
# process attention outputs.
hidden_states = hidden_states * 1.1 + cross_output
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if cur_frame == 1:
hidden_states_temporal = hidden_states_temporal * 0
hidden_states = hidden_states + hidden_states_temporal
# encoder
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
encoder_hidden_states_temporal = attn.to_add_out_temporal(encoder_hidden_states_temporal)
if cur_frame == 1:
encoder_hidden_states_temporal = encoder_hidden_states_temporal * 0
encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_temporal
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
def dynamic_switch(self, attn, x, batchsize, to_spatial_shard: bool):
if to_spatial_shard:
scatter_dim, gather_dim = 2, 1
set_pad("spatial", x.shape[1], attn.parallel_manager.sp_group)
scatter_pad = get_pad("spatial")
gather_pad = get_pad("temporal")
else:
scatter_dim, gather_dim = 1, 2
scatter_pad = get_pad("temporal")
gather_pad = get_pad("spatial")
x = rearrange(x, "(B T) S ... -> B T S ...", B=batchsize)
x = all_to_all_with_pad(
x,
attn.parallel_manager.sp_group,
scatter_dim=scatter_dim,
gather_dim=gather_dim,
scatter_pad=scatter_pad,
gather_pad=gather_pad,
)
x = rearrange(x, "B T ... -> (B T) ...")
return x
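# Illustrative sketch (not part of this commit): what dynamic_switch re-shards. With sequence
# parallelism each rank holds either a slice of the temporal axis or a slice of the spatial
# axis; the all_to_all swaps which axis is sharded. The toy helper below emulates the
# communication on a single process and omits the padding handled by all_to_all_with_pad.
import torch

def toy_all_to_all(shards, scatter_dim, gather_dim):
    # each "rank" sends chunk i of its tensor (along scatter_dim) to rank i and
    # concatenates what it receives along gather_dim, in rank order
    sp = len(shards)
    pieces = [s.chunk(sp, dim=scatter_dim) for s in shards]
    return [torch.cat([pieces[src][dst] for src in range(sp)], dim=gather_dim) for dst in range(sp)]

sp_size, B, T, S, C = 2, 1, 4, 6, 8                        # assumed toy sizes
full = torch.randn(B, T, S, C)
temporal_shards = list(full.chunk(sp_size, dim=1))         # each rank: all tokens, T/sp frames
spatial_shards = toy_all_to_all(temporal_shards, scatter_dim=2, gather_dim=1)
assert spatial_shards[0].shape == (B, T, S // sp_size, C)  # now: all frames, S/sp tokens
back = toy_all_to_all(spatial_shards, scatter_dim=1, gather_dim=2)
assert all(torch.equal(a, b) for a, b in zip(back, temporal_shards))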

View File

@ -2,6 +2,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class LlamaRMSNorm(nn.Module):
@ -100,3 +101,32 @@ class AdaLayerNorm(nn.Module):
x = self.norm(x) * (1 + scale) + shift
return x
class VchitectSpatialNorm(nn.Module):
"""
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
Args:
f_channels (`int`):
The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`):
The number of channels for the quantized vector as described in the paper.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
f_size = f.shape[-2:]
zq = F.interpolate(zq, size=f_size, mode="nearest")
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
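# Illustrative usage sketch (not part of this commit) for the spatial norm defined above,
# written standalone so it runs without the videosys imports. Channel counts are arbitrary
# toy values; f_channels must be divisible by the 32 groups.
import torch
import torch.nn as nn
import torch.nn.functional as F

f_channels, zq_channels = 64, 4
norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1)
conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1)

f = torch.randn(1, f_channels, 32, 32)   # feature map to normalize
zq = torch.randn(1, zq_channels, 8, 8)   # quantized conditioning at lower resolution
zq_up = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
out = norm_layer(f) * conv_y(zq_up) + conv_b(zq_up)
assert out.shape == f.shape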

View File

@ -22,15 +22,9 @@ from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn
from videosys.core.comm import all_to_all_comm, gather_sequence, get_spatial_pad, set_spatial_pad, split_sequence
from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence
from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_group,
get_cfg_parallel_size,
get_sequence_parallel_group,
get_sequence_parallel_size,
)
from videosys.core.parallel_mgr import ParallelManager
from videosys.models.modules.embeddings import apply_rotary_emb
from videosys.utils.utils import batch_func
@ -48,6 +42,49 @@ class CogVideoXAttnProcessor2_0:
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def _remove_extra_encoder(self, hidden_states, text_seq_length, attn):
# current layout is [text, 1/n seq, text, 1/n seq, ...]
# we want to drop the duplicated text segments, leaving a single [text, seq]
sp_size = attn.parallel_manager.sp_size
split_seq = hidden_states.split(hidden_states.size(2) // sp_size, dim=2)
encoder_hidden_states = split_seq[0][:, :, :text_seq_length]
new_seq = [encoder_hidden_states]
for i in range(sp_size):
new_seq.append(split_seq[i][:, :, text_seq_length:])
hidden_states = torch.cat(new_seq, dim=2)
# remove the padding that was added for all_to_all;
# if the pad were removed earlier than this,
# the split size above would be wrong
pad = get_pad("pad")
if pad > 0:
hidden_states = hidden_states.narrow(2, 0, hidden_states.size(2) - pad)
return hidden_states
def _add_extra_encoder(self, hidden_states, text_seq_length, attn):
# add padding for the split and the later all_to_all;
# if the padding were added after the split,
# the chunk sizes would be wrong
pad = get_pad("pad")
if pad > 0:
pad_shape = list(hidden_states.shape)
pad_shape[1] = pad
pad_tensor = torch.zeros(pad_shape, device=hidden_states.device, dtype=hidden_states.dtype)
hidden_states = torch.cat([hidden_states, pad_tensor], dim=1)
# current layout is [text, seq]
# we want to add the extra encoder info [text, 1/n seq, text, 1/n seq, ...]
sp_size = attn.parallel_manager.sp_size
encoder = hidden_states[:, :text_seq_length]
seq = hidden_states[:, text_seq_length:]
seq = seq.split(seq.size(1) // sp_size, dim=1)
new_seq = []
for i in range(sp_size):
new_seq.append(encoder)
new_seq.append(seq[i])
hidden_states = torch.cat(new_seq, dim=1)
return hidden_states
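# Illustrative sketch (not part of this commit) of the layout handled by the two helpers
# above, using toy token ids and no padding. The real helpers operate on hidden_states
# (dim=1) and on the head-split q/k/v tensors (dim=2) respectively.
import torch

sp_size, text_len = 2, 2
tokens = torch.tensor([[9, 9, 0, 1, 2, 3]])              # [text, seq] = [9 9 | 0 1 2 3]
text, seq = tokens[:, :text_len], tokens[:, text_len:]
chunks = seq.split(seq.size(1) // sp_size, dim=1)
with_extra = torch.cat([torch.cat([text, c], dim=1) for c in chunks], dim=1)
# -> [text, 1/2 seq, text, 1/2 seq]: tensor([[9, 9, 0, 1, 9, 9, 2, 3]])
halves = with_extra.split(with_extra.size(1) // sp_size, dim=1)
restored = torch.cat([halves[0][:, :text_len]] + [h[:, text_len:] for h in halves], dim=1)
assert torch.equal(restored, tokens)                     # remove inverts add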
def __call__(
self,
attn: Attention,
@ -72,13 +109,15 @@ class CogVideoXAttnProcessor2_0:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
if enable_sequence_parallel():
if attn.parallel_manager.sp_size > 1:
assert (
attn.heads % get_sequence_parallel_size() == 0
), f"Number of heads {attn.heads} must be divisible by sequence parallel size {get_sequence_parallel_size()}"
attn_heads = attn.heads // get_sequence_parallel_size()
attn.heads % attn.parallel_manager.sp_size == 0
), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
attn_heads = attn.heads // attn.parallel_manager.sp_size
# normally we would pad for every all_to_all, but for a more convenient implementation
# we move the padding into the encoder add/remove helpers for CogVideoX
query, key, value = map(
lambda x: all_to_all_comm(x, get_sequence_parallel_group(), scatter_dim=2, gather_dim=1),
lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
[query, key, value],
)
else:
@ -96,6 +135,13 @@ class CogVideoXAttnProcessor2_0:
if attn.norm_k is not None:
key = attn.norm_k(key)
if attn.parallel_manager.sp_size > 1:
# remove extra encoder for attention
query, key, value = map(
lambda x: self._remove_extra_encoder(x, text_seq_length, attn),
[query, key, value],
)
# Apply RoPE if needed
if image_rotary_emb is not None:
emb_len = image_rotary_emb[0].shape[0]
@ -113,77 +159,10 @@ class CogVideoXAttnProcessor2_0:
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
if enable_sequence_parallel():
hidden_states = all_to_all_comm(hidden_states, get_sequence_parallel_group(), scatter_dim=1, gather_dim=2)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
if attn.parallel_manager.sp_size > 1:
# add extra encoder for all_to_all
hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn)
hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@ -266,6 +245,9 @@ class CogVideoXBlock(nn.Module):
processor=CogVideoXAttnProcessor2_0(),
)
# parallel
self.attn1.parallel_manager = None
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
@ -300,7 +282,7 @@ class CogVideoXBlock(nn.Module):
# attention
if enable_pab():
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count)
if enable_pab() and broadcast_attn:
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
else:
@ -474,6 +456,23 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
# parallel
self.parallel_manager = None
def enable_parallel(self, dp_size, sp_size, enable_cp):
# update cfg parallel
if enable_cp and sp_size % 2 == 0:
sp_size = sp_size // 2
cp_size = 2
else:
cp_size = 1
self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size)
for _, module in self.named_modules():
if hasattr(module, "parallel_manager"):
module.parallel_manager = self.parallel_manager
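# Illustrative sketch (not part of this commit): how enable_parallel splits the ranks.
# When classifier-free-guidance parallelism is requested and sp_size is even, two of the
# sequence-parallel ranks are repurposed for the cond/uncond (cp) split.
def toy_parallel_sizes(dp_size, sp_size, enable_cp):
    if enable_cp and sp_size % 2 == 0:
        return dp_size, 2, sp_size // 2      # (dp, cp, sp)
    return dp_size, 1, sp_size

assert toy_parallel_sizes(1, 8, enable_cp=True) == (1, 2, 4)
assert toy_parallel_sizes(2, 3, enable_cp=True) == (2, 1, 3)    # odd sp: cp stays disabled
assert toy_parallel_sizes(4, 2, enable_cp=False) == (4, 1, 2)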
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@ -485,9 +484,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
all_timesteps=None,
):
if get_cfg_parallel_size() > 1:
if self.parallel_manager.cp_size > 1:
(
hidden_states,
encoder_hidden_states,
@ -495,7 +493,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
timestep_cond,
image_rotary_emb,
) = batch_func(
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
encoder_hidden_states,
timestep,
@ -530,9 +528,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
if enable_sequence_parallel():
set_spatial_pad(hidden_states.shape[1])
hidden_states = split_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
if self.parallel_manager.sp_size > 1:
set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
@ -562,8 +560,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
timestep=timesteps if enable_pab() else None,
)
if enable_sequence_parallel():
hidden_states = gather_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
if self.parallel_manager.sp_size > 1:
hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
@ -583,8 +581,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if get_cfg_parallel_size() > 1:
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
if self.parallel_manager.cp_size > 1:
output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)

View File

@ -33,15 +33,7 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange, repeat
from torch import nn
from videosys.core.comm import (
all_to_all_with_pad,
gather_sequence,
get_spatial_pad,
get_temporal_pad,
set_spatial_pad,
set_temporal_pad,
split_sequence,
)
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.core.pab_mgr import (
enable_pab,
get_mlp_output,
@ -51,12 +43,7 @@ from videosys.core.pab_mgr import (
if_broadcast_temporal,
save_mlp_output,
)
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_group,
get_cfg_parallel_size,
get_sequence_parallel_group,
)
from videosys.core.parallel_mgr import ParallelManager
from videosys.utils.utils import batch_func
@ -388,9 +375,7 @@ class BasicTransformerBlock(nn.Module):
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
if enable_pab():
broadcast_spatial, self.spatial_count = if_broadcast_spatial(
int(org_timestep[0]), self.spatial_count, self.block_idx
)
broadcast_spatial, self.spatial_count = if_broadcast_spatial(int(org_timestep[0]), self.spatial_count)
if enable_pab() and broadcast_spatial:
attn_output = self.spatial_last
@ -681,6 +666,9 @@ class BasicTransformerBlock_(nn.Module):
self.block_idx = block_idx
self.count = 0
# parallel
self.parallel_manager: ParallelManager = None
def set_last_out(self, last_out: torch.Tensor):
self.last_out = last_out
@ -743,7 +731,7 @@ class BasicTransformerBlock_(nn.Module):
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
if enable_sequence_parallel():
if self.parallel_manager.sp_size > 1:
norm_hidden_states = self.dynamic_switch(norm_hidden_states, to_spatial_shard=True)
attn_output = self.attn1(
@ -753,7 +741,7 @@ class BasicTransformerBlock_(nn.Module):
**cross_attention_kwargs,
)
if enable_sequence_parallel():
if self.parallel_manager.sp_size > 1:
attn_output = self.dynamic_switch(attn_output, to_spatial_shard=False)
if self.use_ada_layer_norm_zero:
@ -838,15 +826,15 @@ class BasicTransformerBlock_(nn.Module):
def dynamic_switch(self, x, to_spatial_shard: bool):
if to_spatial_shard:
scatter_dim, gather_dim = 0, 1
scatter_pad = get_spatial_pad()
gather_pad = get_temporal_pad()
scatter_pad = get_pad("spatial")
gather_pad = get_pad("temporal")
else:
scatter_dim, gather_dim = 1, 0
scatter_pad = get_temporal_pad()
gather_pad = get_spatial_pad()
scatter_pad = get_pad("temporal")
gather_pad = get_pad("spatial")
x = all_to_all_with_pad(
x,
get_sequence_parallel_group(),
self.parallel_manager.sp_group,
scatter_dim=scatter_dim,
gather_dim=gather_dim,
scatter_pad=scatter_pad,
@ -1133,6 +1121,23 @@ class LatteT2V(ModelMixin, ConfigMixin):
temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size
self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
# parallel
self.parallel_manager = None
def enable_parallel(self, dp_size, sp_size, enable_cp):
# update cfg parallel
if enable_cp and sp_size % 2 == 0:
sp_size = sp_size // 2
cp_size = 2
else:
cp_size = 1
self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
for _, module in self.named_modules():
if hasattr(module, "parallel_manager"):
module.parallel_manager = self.parallel_manager
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@ -1191,7 +1196,7 @@ class LatteT2V(ModelMixin, ConfigMixin):
"""
# 0. Split batch for data parallelism
if get_cfg_parallel_size() > 1:
if self.parallel_manager.cp_size > 1:
(
hidden_states,
timestep,
@ -1201,7 +1206,7 @@ class LatteT2V(ModelMixin, ConfigMixin):
attention_mask,
encoder_attention_mask,
) = batch_func(
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
timestep,
encoder_hidden_states,
@ -1292,14 +1297,14 @@ class LatteT2V(ModelMixin, ConfigMixin):
timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
if enable_sequence_parallel():
set_temporal_pad(frame + use_image_num)
set_spatial_pad(num_patches)
if self.parallel_manager.sp_size > 1:
set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
@ -1420,7 +1425,7 @@ class LatteT2V(ModelMixin, ConfigMixin):
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
if enable_sequence_parallel():
if self.parallel_manager.sp_size > 1:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
if self.is_input_patches:
@ -1452,8 +1457,8 @@ class LatteT2V(ModelMixin, ConfigMixin):
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
if get_cfg_parallel_size() > 1:
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
if self.parallel_manager.cp_size > 1:
output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
@ -1466,12 +1471,12 @@ class LatteT2V(ModelMixin, ConfigMixin):
def split_from_second_dim(self, x, batch_size):
x = x.view(batch_size, -1, *x.shape[1:])
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
x = x.reshape(-1, *x.shape[2:])
return x
def gather_from_second_dim(self, x, batch_size):
x = x.view(batch_size, -1, *x.shape[1:])
x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
x = x.reshape(-1, *x.shape[2:])
return x

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -20,15 +20,7 @@ from timm.models.vision_transformer import Mlp
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from transformers import PretrainedConfig, PreTrainedModel
from videosys.core.comm import (
all_to_all_with_pad,
gather_sequence,
get_spatial_pad,
get_temporal_pad,
set_spatial_pad,
set_temporal_pad,
split_sequence,
)
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.core.pab_mgr import (
enable_pab,
get_mlp_output,
@ -38,12 +30,7 @@ from videosys.core.pab_mgr import (
if_broadcast_temporal,
save_mlp_output,
)
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_size,
get_data_parallel_group,
get_sequence_parallel_group,
)
from videosys.core.parallel_mgr import ParallelManager
from videosys.models.modules.activations import approx_gelu
from videosys.models.modules.attentions import OpenSoraAttention, OpenSoraMultiHeadCrossAttention
from videosys.models.modules.embeddings import (
@ -143,6 +130,9 @@ class STDiT3Block(nn.Module):
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
# parallel
self.parallel_manager: ParallelManager = None
# pab
self.block_idx = block_idx
self.attn_count = 0
@ -188,9 +178,7 @@ class STDiT3Block(nn.Module):
if self.temporal:
broadcast_attn, self.attn_count = if_broadcast_temporal(int(timestep[0]), self.attn_count)
else:
broadcast_attn, self.attn_count = if_broadcast_spatial(
int(timestep[0]), self.attn_count, self.block_idx
)
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count)
if enable_pab() and broadcast_attn:
x_m_s = self.last_attn
@ -203,12 +191,12 @@ class STDiT3Block(nn.Module):
# attention
if self.temporal:
if enable_sequence_parallel():
if self.parallel_manager.sp_size > 1:
x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=True)
x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
x_m = self.attn(x_m)
x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S)
if enable_sequence_parallel():
if self.parallel_manager.sp_size > 1:
x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=False)
else:
x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
@ -287,17 +275,17 @@ class STDiT3Block(nn.Module):
def dynamic_switch(self, x, s, t, to_spatial_shard: bool):
if to_spatial_shard:
scatter_dim, gather_dim = 2, 1
scatter_pad = get_spatial_pad()
gather_pad = get_temporal_pad()
scatter_pad = get_pad("spatial")
gather_pad = get_pad("temporal")
else:
scatter_dim, gather_dim = 1, 2
scatter_pad = get_temporal_pad()
gather_pad = get_spatial_pad()
scatter_pad = get_pad("temporal")
gather_pad = get_pad("spatial")
x = rearrange(x, "b (t s) d -> b t s d", t=t, s=s)
x = all_to_all_with_pad(
x,
get_sequence_parallel_group(),
self.parallel_manager.sp_group,
scatter_dim=scatter_dim,
gather_dim=gather_dim,
scatter_pad=scatter_pad,
@ -449,6 +437,24 @@ class STDiT3(PreTrainedModel):
for param in self.y_embedder.parameters():
param.requires_grad = False
# parallel
self.parallel_manager: ParallelManager = None
def enable_parallel(self, dp_size, sp_size, enable_cp):
# update cfg parallel
if enable_cp and sp_size % 2 == 0:
sp_size = sp_size // 2
cp_size = 2
else:
cp_size = 1
self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
for name, module in self.named_modules():
if "spatial_blocks" in name or "temporal_blocks" in name:
if hasattr(module, "parallel_manager"):
module.parallel_manager = self.parallel_manager
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
@ -501,9 +507,14 @@ class STDiT3(PreTrainedModel):
self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
):
# === Split batch ===
if get_cfg_parallel_size() > 1:
if self.parallel_manager.cp_size > 1:
x, timestep, y, x_mask, mask = batch_func(
partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
x,
timestep,
y,
x_mask,
mask,
)
dtype = self.x_embedder.proj.weight.dtype
@ -547,14 +558,14 @@ class STDiT3(PreTrainedModel):
x = x + pos_emb
# shard over the sequence dim if sp is enabled
if enable_sequence_parallel():
set_temporal_pad(T)
set_spatial_pad(S)
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
if self.parallel_manager.sp_size > 1:
set_pad("temporal", T, self.parallel_manager.sp_group)
set_pad("spatial", S, self.parallel_manager.sp_group)
x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
T = x.shape[1]
x_mask_org = x_mask
x_mask = split_sequence(
x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
@ -589,9 +600,9 @@ class STDiT3(PreTrainedModel):
all_timesteps=all_timesteps,
)
if enable_sequence_parallel():
if self.parallel_manager.sp_size > 1:
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
T, S = x.shape[1], x.shape[2]
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
x_mask = x_mask_org
@ -604,8 +615,8 @@ class STDiT3(PreTrainedModel):
x = x.to(torch.float32)
# === Gather Output ===
if get_cfg_parallel_size() > 1:
x = gather_sequence(x, get_data_parallel_group(), dim=0)
if self.parallel_manager.cp_size > 1:
x = gather_sequence(x, self.parallel_manager.cp_group, dim=0)
return x

View File

@ -0,0 +1,644 @@
# Adapted from Vchitect
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Vchitect: https://github.com/Vchitect/Vchitect-2.0
# diffusers: https://github.com/huggingface/diffusers
# --------------------------------------------------------
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, deprecate, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange
from torch import nn
from videosys.core.comm import gather_from_second_dim, set_pad, split_from_second_dim
from videosys.core.parallel_mgr import ParallelManager
from videosys.models.modules.attentions import VchitectAttention, VchitectAttnProcessor
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
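# Illustrative sketch (not part of this commit): chunking the feed-forward over the sequence
# dimension lowers peak activation memory while producing the same result, because the MLP
# acts on each token independently. Toy module and sizes are assumptions.
import torch
import torch.nn as nn

ff = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(2, 8, 16)                    # (batch, seq, dim); seq divisible by chunk_size
chunk_dim, chunk_size = 1, 4
num_chunks = x.shape[chunk_dim] // chunk_size
chunked = torch.cat([ff(c) for c in x.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim)
assert torch.allclose(chunked, ff(x), atol=1e-6)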
@maybe_allow_in_graph
class JointTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
super().__init__()
self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
self.norm1 = AdaLayerNormZero(dim)
if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
)
elif context_norm_type == "ada_norm_zero":
self.norm1_context = AdaLayerNormZero(dim)
else:
raise ValueError(
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
)
processor = VchitectAttnProcessor()
self.attn = VchitectAttention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim // num_attention_heads,
heads=num_attention_heads,
out_dim=attention_head_dim,
context_pre_only=context_pre_only,
bias=True,
processor=processor,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
if not context_pre_only:
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
else:
self.norm2_context = None
self.ff_context = None
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
freqs_cis: torch.Tensor,
full_seqlen: int,
Frame: int,
timestep: int,
):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
else:
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
freqs_cis=freqs_cis,
full_seqlen=full_seqlen,
Frame=Frame,
timestep=timestep,
)
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
# Process attention outputs for the `encoder_hidden_states`.
if self.context_pre_only:
encoder_hidden_states = None
else:
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
context_ff_output = _chunked_feed_forward(
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
)
else:
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return encoder_hidden_states, hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class VchitectXLTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
A Transformer model adapted from the MMDiT transformer introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
out_channels (`int`, defaults to 16): Number of output channels.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: int = 128,
patch_size: int = 2,
in_channels: int = 16,
num_layers: int = 18,
attention_head_dim: int = 64,
num_attention_heads: int = 18,
joint_attention_dim: int = 4096,
caption_projection_dim: int = 1152,
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
rope_scaling_factor: float = 1.0,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = PatchEmbed(
height=self.config.sample_size,
width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
# `attention_head_dim` is doubled to account for the mixing.
# It needs to be crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.inner_dim,
context_pre_only=i == num_layers - 1,
)
for i in range(self.config.num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
# Video param
# self.scatter_dim_zero = Identity()
self.freqs_cis = VchitectXLTransformerModel.precompute_freqs_cis(
self.inner_dim // self.config.num_attention_heads,
1000000,
theta=1e6,
rope_scaling_factor=rope_scaling_factor, # todo max pos embeds
)
# self.vid_token = nn.Parameter(torch.empty(self.inner_dim))
# parallel
self.parallel_manager = None
def enable_parallel(self, dp_size, sp_size, enable_cp):
# update cfg parallel
if enable_cp and sp_size % 2 == 0:
sp_size = sp_size // 2
cp_size = 2
else:
cp_size = 1
self.parallel_manager = ParallelManager(dp_size, cp_size, sp_size)
for _, module in self.named_modules():
if hasattr(module, "parallel_manager"):
module.parallel_manager = self.parallel_manager
@staticmethod
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, rope_scaling_factor: float = 1.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float)
t = t / rope_scaling_factor
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
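# Illustrative sketch (not part of this commit): what precompute_freqs_cis returns for small
# sizes. Each position t gets dim//2 unit complex numbers exp(i * t * theta^(-2k/dim)); RoPE
# later rotates query/key channel pairs by these values. rope_scaling_factor (=1 here)
# simply divides t before the outer product.
import torch

dim, end = 8, 4
freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))   # (dim//2,)
t = torch.arange(end, dtype=torch.float)
freqs_cis = torch.polar(torch.ones(end, dim // 2), torch.outer(t, freqs))
assert freqs_cis.shape == (end, dim // 2) and freqs_cis.dtype == torch.complex64
assert torch.allclose(freqs_cis.abs(), torch.ones(end, dim // 2))    # pure rotations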
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Enables chunked computation of the feed-forward layers ([feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)).
Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# By default chunk size is 1
chunk_size = chunk_size or 1
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, VchitectAttnProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str, module: torch.nn.Module, processors: Dict[str, VchitectAttnProcessor]
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[VchitectAttnProcessor, Dict[str, VchitectAttnProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, VchitectAttention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def patchify_and_embed(self, x):
B, F, C, H, W = x.size()
x = rearrange(x, "b f c h w -> (b f) c h w")
x = self.pos_embed(x) # [B L D]
return x, F, [(H, W)] * B
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`VchitectXLTransformerModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
height, width = hidden_states.shape[-2:]
batch_size = hidden_states.shape[0]
hidden_states, F_num, _ = self.patchify_and_embed(
hidden_states
) # takes care of adding positional embeddings too.
full_seq = batch_size * F_num
self.freqs_cis = self.freqs_cis.to(hidden_states.device)
freqs_cis = self.freqs_cis
# seq_length = hidden_states.size(1)
# freqs_cis = self.freqs_cis[:hidden_states.size(1)*F_num]
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if self.parallel_manager.sp_size > 1:
set_pad("temporal", F_num, self.parallel_manager.sp_group)
hidden_states = split_from_second_dim(hidden_states, batch_size, self.parallel_manager.sp_group)
cur_temb = temb.repeat(hidden_states.shape[0] // batch_size, 1)
else:
cur_temb = temb.repeat(F_num, 1)
for block_idx, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=cur_temb,
freqs_cis=freqs_cis,
full_seqlen=full_seq,
Frame=F_num,
timestep=timestep,
)
if self.parallel_manager.sp_size > 1:
hidden_states = gather_from_second_dim(hidden_states, batch_size, self.parallel_manager.sp_group)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# unpatchify
# hidden_states = hidden_states[:, :-1] #Drop the video token
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size
width = width // patch_size
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
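# Illustrative sketch (not part of this commit): the unpatchify step at the end of forward(),
# with toy sizes. Each token predicts a (patch, patch, out_channels) block of the latent; the
# reshape + einsum stitch the blocks back into an image-shaped tensor.
import torch

n, h, w, p, c_out = 2, 3, 4, 2, 16
tokens = torch.randn(n, h * w, p * p * c_out)
x = tokens.reshape(n, h, w, p, p, c_out)
x = torch.einsum("nhwpqc->nchpwq", x)
image = x.reshape(n, c_out, h * p, w * p)
assert image.shape == (2, 16, 6, 8)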
def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
return list(self.transformer_blocks)
@classmethod
def from_pretrained_temporal(cls, pretrained_model_path, torch_dtype, logger, subfolder=None, tp_size=1):
import json
import os
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
config_file = os.path.join(pretrained_model_path, "config.json")
with open(config_file, "r") as f:
config = json.load(f)
config["tp_size"] = tp_size
from safetensors.torch import load_file
model = cls.from_config(config)
# model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_files = [
os.path.join(pretrained_model_path, "diffusion_pytorch_model.bin"),
os.path.join(pretrained_model_path, "diffusion_pytorch_model.safetensors"),
]
model_file = None
for fp in model_files:
if os.path.exists(fp):
model_file = fp
if not model_file:
raise RuntimeError(f"{model_file} does not exist")
if not os.path.isfile(model_file):
raise RuntimeError(f"{model_file} does not exist")
state_dict = load_file(model_file, device="cpu")
m, u = model.load_state_dict(state_dict, strict=False)
model = model.to(torch_dtype)
params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
total_params = [p.numel() for n, p in model.named_parameters()]
if logger is not None:
logger.info(f"model_file: {model_file}")
logger.info(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
logger.info(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
logger.info(f"### Total Parameters: {sum(total_params) / 1e6} M")
return model

View File

@ -13,6 +13,7 @@ import math
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
@ -26,19 +27,17 @@ from videosys.models.transformers.cogvideox_transformer_3d import CogVideoXTrans
from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
from videosys.utils.utils import save_video, set_seed
class CogVideoXPABConfig(PABConfig):
def __init__(
self,
steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
@ -114,7 +113,7 @@ class CogVideoXConfig:
class CogVideoXPipeline(VideoSysPipeline):
_optional_components = ["text_encoder", "tokenizer"]
_optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
@ -158,9 +157,6 @@ class CogVideoXPipeline(VideoSysPipeline):
subfolder="scheduler",
)
# set eval and device
self.set_eval_and_device(self._device, vae, transformer)
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
@ -169,7 +165,7 @@ class CogVideoXPipeline(VideoSysPipeline):
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(self._device, text_encoder)
self.set_eval_and_device(self._device, text_encoder, vae, transformer)
# vae tiling
if config.vae_tiling:
@ -187,6 +183,31 @@ class CogVideoXPipeline(VideoSysPipeline):
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# parallel
self._set_parallel()
def _set_seed(self, seed):
if dist.get_world_size() == 1:
set_seed(seed)
else:
set_seed(seed, self.transformer.parallel_manager.dp_rank)
def _set_parallel(
self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
):
# init sequence parallel
if sp_size is None:
sp_size = dist.get_world_size()
dp_size = 1
else:
assert (
dist.get_world_size() % sp_size == 0
), f"world_size {dist.get_world_size()} must be divisible by sp_size"
dp_size = dist.get_world_size() // sp_size
# transformer parallel
self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
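# Illustrative sketch (not part of this commit): how _set_parallel derives the data-parallel
# size from the world size when an explicit sp_size is given.
def toy_set_parallel(world_size, sp_size=None):
    if sp_size is None:
        return 1, world_size                 # default: all ranks go to sequence parallelism
    assert world_size % sp_size == 0, "world_size must be divisible by sp_size"
    return world_size // sp_size, sp_size    # (dp_size, sp_size)

assert toy_set_parallel(8) == (1, 8)
assert toy_set_parallel(8, sp_size=2) == (4, 2)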
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
@ -474,6 +495,7 @@ class CogVideoXPipeline(VideoSysPipeline):
num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
seed: int = -1,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
@ -489,7 +511,6 @@ class CogVideoXPipeline(VideoSysPipeline):
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
verbose=True,
) -> Union[VideoSysPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@ -572,6 +593,7 @@ class CogVideoXPipeline(VideoSysPipeline):
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
update_steps(num_inference_steps)
self._set_seed(seed)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@ -653,11 +675,10 @@ class CogVideoXPipeline(VideoSysPipeline):
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
# with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in progress_wrap(list(enumerate(timesteps))):
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@ -674,7 +695,6 @@ class CogVideoXPipeline(VideoSysPipeline):
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
all_timesteps=timesteps,
)[0]
noise_pred = noise_pred.float()
@ -713,8 +733,8 @@ class CogVideoXPipeline(VideoSysPipeline):
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
# progress_bar.update()
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents)

View File

@ -29,13 +29,12 @@ from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.transformers.latte_transformer_3d import LatteT2V
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
from videosys.utils.utils import save_video, set_seed
class LattePABConfig(PABConfig):
def __init__(
self,
steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 800],
spatial_range: int = 2,
@ -62,7 +61,6 @@ class LattePABConfig(PABConfig):
},
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
@ -188,7 +186,7 @@ class LattePipeline(VideoSysPipeline):
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder"]
_optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
@ -238,9 +236,6 @@ class LattePipeline(VideoSysPipeline):
if config.enable_pab:
set_pab_manager(config.pab_config)
# set eval and device
self.set_eval_and_device(device, vae, transformer)
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
@ -249,11 +244,36 @@ class LattePipeline(VideoSysPipeline):
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(device, text_encoder)
self.set_eval_and_device(device, text_encoder, vae, transformer)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# parallel
self._set_parallel()
def _set_seed(self, seed):
if dist.get_world_size() == 1:
set_seed(seed)
else:
set_seed(seed, self.transformer.parallel_manager.dp_rank)
def _set_parallel(
self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
):
# init sequence parallel
if sp_size is None:
sp_size = dist.get_world_size()
dp_size = 1
else:
assert (
dist.get_world_size() % sp_size == 0
), f"world_size {dist.get_world_size()} must be divisible by sp_size"
dp_size = dist.get_world_size() // sp_size
# transformer parallel
self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
@ -658,6 +678,7 @@ class LattePipeline(VideoSysPipeline):
negative_prompt: str = "",
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
seed: int = -1,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@ -745,6 +766,7 @@ class LattePipeline(VideoSysPipeline):
width = 512
update_steps(num_inference_steps)
self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
self._set_seed(seed)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):

View File

@ -2,19 +2,22 @@ import html
import json
import os
import re
import urllib.parse as ul
from typing import Optional, Tuple, Union
import ftfy
import torch
import torch.distributed as dist
from bs4 import BeautifulSoup
from diffusers.models import AutoencoderKL
from transformers import AutoTokenizer, T5EncoderModel
from videosys.core.pab_mgr import PABConfig, set_pab_manager
from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.autoencoders.autoencoder_kl_open_sora import OpenSoraVAE_V1_2
from videosys.models.transformers.open_sora_transformer_3d import STDiT3
from videosys.schedulers.scheduling_rflow_open_sora import RFLOW
from videosys.utils.utils import save_video
from videosys.utils.utils import save_video, set_seed
from .data_process import get_image_size, get_num_frames, prepare_multi_resolution_info, read_from_path
@ -29,7 +32,6 @@ BAD_PUNCT_REGEX = re.compile(
class OpenSoraPABConfig(PABConfig):
def __init__(
self,
steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [450, 930],
spatial_range: int = 2,
@ -52,7 +54,6 @@ class OpenSoraPABConfig(PABConfig):
},
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
@ -187,11 +188,7 @@ class OpenSoraPipeline(VideoSysPipeline):
bad_punct_regex = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = [
"text_encoder",
"tokenizer",
]
_optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
@ -234,9 +231,6 @@ class OpenSoraPipeline(VideoSysPipeline):
if config.enable_pab:
set_pab_manager(config.pab_config)
# set eval and device
self.set_eval_and_device(device, vae, transformer)
self.register_modules(
text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, tokenizer=tokenizer
)
@ -245,7 +239,32 @@ class OpenSoraPipeline(VideoSysPipeline):
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(self._device, text_encoder)
self.set_eval_and_device(self._device, vae, transformer, text_encoder)
# parallel
self._set_parallel()
def _set_seed(self, seed):
if dist.get_world_size() == 1:
set_seed(seed)
else:
set_seed(seed, self.transformer.parallel_manager.dp_rank)
def _set_parallel(
self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
):
# init sequence parallel
if sp_size is None:
sp_size = dist.get_world_size()
dp_size = 1
else:
assert (
dist.get_world_size() % sp_size == 0
), f"world_size {dist.get_world_size()} must be divisible by sp_size"
dp_size = dist.get_world_size() // sp_size
# transformer parallel
self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
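With `seed` now threaded through `generate`, an Open-Sora run can be reproduced end to end. A hedged usage sketch in the style of the VideoSys examples; the exact engine and config fields are assumptions, not verified here:

```
from videosys import OpenSoraConfig, VideoSysEngine  # assumed public API

config = OpenSoraConfig(num_sampling_steps=30)
engine = VideoSysEngine(config)

prompt = "a drone shot over a snowy forest"
video = engine.generate(prompt, seed=42).video[0]     # same seed -> same video
engine.save_video(video, "./outputs/forest.mp4")
```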
def get_text_embeddings(self, texts):
text_tokens_and_mask = self.tokenizer(
@ -283,10 +302,6 @@ class OpenSoraPipeline(VideoSysPipeline):
return text.strip()
def _clean_caption(self, caption):
import urllib.parse as ul
from bs4 import BeautifulSoup
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
@ -418,6 +433,7 @@ class OpenSoraPipeline(VideoSysPipeline):
loop: int = 1,
llm_refine: bool = False,
negative_prompt: str = "",
seed: int = -1,
ms: Optional[str] = "",
refs: Optional[str] = "",
aes: float = 6.5,
@ -460,10 +476,6 @@ class OpenSoraPipeline(VideoSysPipeline):
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@ -508,6 +520,8 @@ class OpenSoraPipeline(VideoSysPipeline):
fps = 24
image_size = get_image_size(resolution, aspect_ratio)
num_frames = get_num_frames(num_frames)
self._set_seed(seed)
update_steps(self._config.num_sampling_steps)
# == prepare batch prompts ==
batch_prompts = [prompt]
@ -621,7 +635,7 @@ class OpenSoraPipeline(VideoSysPipeline):
progress=verbose,
mask=masks,
)
samples = self.vae.decode(samples.to(self._dtype), num_frames=num_frames)
samples = self.vae(samples.to(self._dtype), decode_only=True, num_frames=num_frames)
video_clips.append(samples)
for i in range(1, loop):

View File

@ -1,3 +1,8 @@
from .pipeline_open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
from .pipeline_open_sora_plan import (
OpenSoraPlanConfig,
OpenSoraPlanPipeline,
OpenSoraPlanV110PABConfig,
OpenSoraPlanV120PABConfig,
)
__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanPABConfig"]
__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig"]
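The subpackage now exports one PAB config per Open-Sora-Plan version. A hedged import sketch; the `videosys.pipelines.open_sora_plan` path is inferred from the relative imports in this diff:

```
from videosys.pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanV120PABConfig

# passing pab_config explicitly; omitting it picks the default that matches `version`
config = OpenSoraPlanConfig(version="v120", enable_pab=True, pab_config=OpenSoraPlanV120PABConfig())
```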

View File

@ -20,39 +20,27 @@ import torch.distributed as dist
import tqdm
from bs4 import BeautifulSoup
from diffusers.models import AutoencoderKL, Transformer2DModel
from diffusers.schedulers import PNDMScheduler
from diffusers.schedulers import EulerAncestralDiscreteScheduler, PNDMScheduler
from diffusers.utils.torch_utils import randn_tensor
from transformers import T5EncoderModel, T5Tokenizer
from transformers import AutoTokenizer, MT5EncoderModel, T5EncoderModel, T5Tokenizer
from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.autoencoders.autoencoder_kl_open_sora_plan_v110 import (
CausalVAEModelWrapper as CausalVAEModelWrapperV110,
)
from videosys.models.autoencoders.autoencoder_kl_open_sora_plan_v120 import (
CausalVAEModelWrapper as CausalVAEModelWrapperV120,
)
from videosys.models.transformers.open_sora_plan_v110_transformer_3d import LatteT2V
from videosys.models.transformers.open_sora_plan_v120_transformer_3d import OpenSoraT2V
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
from ...models.autoencoders.autoencoder_kl_open_sora_plan import ae_stride_config, getae_wrapper
from ...models.transformers.open_sora_plan_transformer_3d import LatteT2V
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import PixArtAlphaPipeline
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
>>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A small cactus with a happy face in the Sahara desert."
>>> image = pipe(prompt).images[0]
```
"""
from videosys.utils.utils import save_video, set_seed
class OpenSoraPlanPABConfig(PABConfig):
class OpenSoraPlanV110PABConfig(PABConfig):
def __init__(
self,
steps: int = 150,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
@ -97,7 +85,6 @@ class OpenSoraPlanPABConfig(PABConfig):
},
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
@ -113,6 +100,26 @@ class OpenSoraPlanPABConfig(PABConfig):
)
class OpenSoraPlanV120PABConfig(PABConfig):
def __init__(
self,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
cross_broadcast: bool = True,
cross_threshold: list = [100, 850],
cross_range: int = 6,
):
super().__init__(
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
cross_broadcast=cross_broadcast,
cross_threshold=cross_threshold,
cross_range=cross_range,
)
class OpenSoraPlanConfig:
"""
This config is used to instantiate an `OpenSoraPlanPipeline` class for video generation.
@ -163,10 +170,10 @@ class OpenSoraPlanConfig:
def __init__(
self,
transformer: str = "LanguageBind/Open-Sora-Plan-v1.1.0",
ae: str = "CausalVAEModel_4x8x8",
text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
num_frames: int = 65,
version: str = "v120",
transformer_type: str = "29x480p",
transformer: str = None,
text_encoder: str = None,
# ======= distributed ========
num_gpus: int = 1,
# ======= memory =======
@ -175,15 +182,32 @@ class OpenSoraPlanConfig:
tile_overlap_factor: float = 0.25,
# ======= pab ========
enable_pab: bool = False,
pab_config: PABConfig = OpenSoraPlanPABConfig(),
pab_config: PABConfig = None,
):
self.pipeline_cls = OpenSoraPlanPipeline
self.ae = ae
self.text_encoder = text_encoder
self.transformer = transformer
assert num_frames in [65, 221], "num_frames must be one of [65, 221]"
self.num_frames = num_frames
self.version = f"{num_frames}x512x512"
# get version
assert version in ["v110", "v120"], f"Unknown Open-Sora-Plan version: {version}"
self.version = version
self.transformer_type = transformer_type
# check transformer_type
if version == "v110":
assert transformer_type in ["65x512x512", "221x512x512"]
elif version == "v120":
assert transformer_type in ["93x480p", "93x720p", "29x480p", "29x720p"]
self.num_frames = int(transformer_type.split("x")[0])
# set default values according to version
if version == "v110":
transformer_default = "LanguageBind/Open-Sora-Plan-v1.1.0"
text_encoder_default = "DeepFloyd/t5-v1_1-xxl"
elif version == "v120":
transformer_default = "LanguageBind/Open-Sora-Plan-v1.2.0"
text_encoder_default = "google/mt5-xxl"
self.text_encoder = text_encoder or text_encoder_default
self.transformer = transformer or transformer_default
# ======= distributed ========
self.num_gpus = num_gpus
# ======= memory ========
@ -192,7 +216,13 @@ class OpenSoraPlanConfig:
self.tile_overlap_factor = tile_overlap_factor
# ======= pab ========
self.enable_pab = enable_pab
self.pab_config = pab_config
if self.enable_pab and pab_config is None:
if version == "v110":
self.pab_config = OpenSoraPlanV110PABConfig()
elif version == "v120":
self.pab_config = OpenSoraPlanV120PABConfig()
else:
self.pab_config = pab_config
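The constructor now derives the checkpoint defaults and frame count from `version` and `transformer_type`. A pure-Python mirror of that resolution logic, using illustrative names and only the values shown above:

```
def resolve_plan_defaults(version: str, transformer_type: str):
    defaults = {
        "v110": ("LanguageBind/Open-Sora-Plan-v1.1.0", "DeepFloyd/t5-v1_1-xxl",
                 ["65x512x512", "221x512x512"]),
        "v120": ("LanguageBind/Open-Sora-Plan-v1.2.0", "google/mt5-xxl",
                 ["93x480p", "93x720p", "29x480p", "29x720p"]),
    }
    transformer, text_encoder, allowed = defaults[version]
    assert transformer_type in allowed, f"unknown transformer_type {transformer_type} for {version}"
    num_frames = int(transformer_type.split("x")[0])  # leading number is the frame count
    return transformer, text_encoder, num_frames


assert resolve_plan_defaults("v120", "29x480p")[2] == 29
```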
class OpenSoraPlanPipeline(VideoSysPipeline):
@ -221,7 +251,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder"]
_optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
@ -240,26 +270,57 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
# init
if tokenizer is None:
tokenizer = T5Tokenizer.from_pretrained(config.text_encoder)
if config.version == "v110":
tokenizer = T5Tokenizer.from_pretrained(config.text_encoder)
elif config.version == "v120":
tokenizer = AutoTokenizer.from_pretrained(config.text_encoder)
if text_encoder is None:
text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=torch.float16)
if config.version == "v110":
text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=dtype)
elif config.version == "v120":
text_encoder = MT5EncoderModel.from_pretrained(
config.text_encoder, low_cpu_mem_usage=True, torch_dtype=dtype
)
if vae is None:
vae = getae_wrapper(config.ae)(config.transformer, subfolder="vae").to(dtype=dtype)
if config.version == "v110":
vae = CausalVAEModelWrapperV110(config.transformer, subfolder="vae").to(dtype=dtype)
elif config.version == "v120":
vae = CausalVAEModelWrapperV120(config.transformer, subfolder="vae").to(dtype=dtype)
if transformer is None:
transformer = LatteT2V.from_pretrained(config.transformer, subfolder=config.version, torch_dtype=dtype)
if config.version == "v110":
transformer = LatteT2V.from_pretrained(
config.transformer, subfolder=config.transformer_type, torch_dtype=dtype
)
elif config.version == "v120":
transformer = OpenSoraT2V.from_pretrained(
config.transformer, subfolder=config.transformer_type, torch_dtype=dtype
)
if scheduler is None:
scheduler = PNDMScheduler()
if config.version == "v110":
scheduler = PNDMScheduler()
elif config.version == "v120":
scheduler = EulerAncestralDiscreteScheduler()
# setting
if config.enable_tiling:
vae.vae.enable_tiling()
vae.vae.tile_overlap_factor = config.tile_overlap_factor
vae.vae_scale_factor = ae_stride_config[config.ae]
vae.vae.tile_sample_min_size = 512
vae.vae.tile_latent_min_size = 64
vae.vae.tile_sample_min_size_t = 29
vae.vae.tile_latent_min_size_t = 8
# if low_mem:
# vae.vae.tile_sample_min_size = 256
# vae.vae.tile_latent_min_size = 32
# vae.vae.tile_sample_min_size_t = 29
# vae.vae.tile_latent_min_size_t = 8
vae.vae_scale_factor = [4, 8, 8]
transformer.force_images = False
# set eval and device
self.set_eval_and_device(device, vae, transformer)
# pab
if config.enable_pab:
set_pab_manager(config.pab_config)
@ -272,10 +333,35 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(device, text_encoder)
self.set_eval_and_device(device, text_encoder, vae, transformer)
# self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# parallel
self._set_parallel()
def _set_seed(self, seed):
if dist.get_world_size() == 1:
set_seed(seed)
else:
set_seed(seed, self.transformer.parallel_manager.dp_rank)
def _set_parallel(
self, dp_size: Optional[int] = None, sp_size: Optional[int] = None, enable_cp: Optional[bool] = False
):
# init sequence parallel
if sp_size is None:
sp_size = dist.get_world_size()
dp_size = 1
else:
assert (
dist.get_world_size() % sp_size == 0
), f"world_size {dist.get_world_size()} must be divisible by sp_size"
dp_size = dist.get_world_size() // sp_size
# transformer parallel
self.transformer.enable_parallel(dp_size, sp_size, enable_cp)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
@ -449,6 +535,143 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
return prompt_embeds, negative_prompt_embeds
def encode_prompt_v120(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 120,
**kwargs,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
PixArt-Alpha, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha, it should be the embeddings of the ""
string.
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
"""
if device is None:
device = getattr(self, "_execution_device", None) or getattr(self, "device", None) or torch.device("cuda")
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# See Section 3.1. of the paper.
max_length = max_sequence_length
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0]
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
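`encode_prompt_v120` duplicates embeddings per prompt with `repeat` followed by `view` rather than `repeat_interleave` (the "mps friendly method" noted in the comments). A short sketch confirming the two are equivalent for this layout:

```
import torch

num_images_per_prompt = 2
emb = torch.randn(3, 120, 8)                              # (batch, seq_len, dim)
bs, seq_len, dim = emb.shape

dup = emb.repeat(1, num_images_per_prompt, 1)             # (batch, seq_len * n, dim)
dup = dup.view(bs * num_images_per_prompt, seq_len, dim)  # (batch * n, seq_len, dim)

assert torch.equal(dup, emb.repeat_interleave(num_images_per_prompt, dim=0))
```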
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@ -476,6 +699,8 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@ -519,6 +744,12 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
@ -685,10 +916,13 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
seed: int = -1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@ -697,6 +931,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
mask_feature: bool = True,
enable_temporal_attentions: bool = True,
verbose: bool = True,
max_sequence_length: int = 512,
) -> Union[VideoSysPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@ -768,11 +1003,22 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
returned where the first element is a list with the generated images
"""
# 1. Check inputs. Raise error if not correct
height = 512
width = 512
height = self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]
width = self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]
num_frames = self._config.num_frames
update_steps(num_inference_steps)
self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
self._set_seed(seed)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
@ -782,7 +1028,7 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
device = getattr(self, "_execution_device", None) or getattr(self, "device", None) or torch.device("cuda")
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@ -790,23 +1036,50 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clean_caption=clean_caption,
mask_feature=mask_feature,
)
if self._config.version == "v110":
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clean_caption=clean_caption,
mask_feature=mask_feature,
)
prompt_attention_mask = None
negative_prompt_attention_mask = None
else:
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt_v120(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
if prompt_attention_mask is not None and negative_prompt_attention_mask is not None:
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
if self._config.version == "v110":
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
else:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
@ -827,15 +1100,10 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
# if self.transformer.config.sample_size == 128:
# resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
# aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
# resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
# aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
# added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
for i, t in progress_wrap(list(enumerate(timesteps))):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@ -856,9 +1124,15 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
if prompt_embeds.ndim == 3:
if prompt_embeds is not None and prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
if prompt_attention_mask is not None and prompt_attention_mask.ndim == 2:
prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l
# prepare attention_mask.
# b c t h w -> b t h w
attention_mask = torch.ones_like(latent_model_input)[:, 0]
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
@ -867,6 +1141,8 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
enable_temporal_attentions=enable_temporal_attentions,
attention_mask=attention_mask,
encoder_attention_mask=prompt_attention_mask,
return_dict=False,
)[0]
@ -906,10 +1182,57 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
return VideoSysPipelineOutput(video=video)
def decode_latents(self, latents):
video = self.vae.decode(latents) # b t c h w
# b t c h w -> b t h w c
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
video = self.vae(latents.to(self.vae.vae.dtype))
video = (
((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
) # b t h w c
# the conversion above adds negligible overhead and is compatible with bfloat16 VAE outputs
return video
def save_video(self, video, output_path):
save_video(video, output_path, fps=24)
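`decode_latents` now calls the VAE wrapper directly and converts the `[-1, 1]` float output into channel-last `uint8` frames. A standalone sketch of that conversion, with a random tensor standing in for the decoded video:

```
import torch

video = torch.rand(1, 8, 3, 64, 64) * 2 - 1               # stand-in VAE output: b t c h w in [-1, 1]
frames = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(torch.uint8)
frames = frames.permute(0, 1, 3, 4, 2).contiguous()       # b t h w c, ready for imageio
assert frames.dtype == torch.uint8 and frames.shape[-1] == 3
```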
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
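The v120 branch above uses `retrieve_timesteps` to drive `EulerAncestralDiscreteScheduler`. A minimal usage sketch, assuming the function as defined above is in scope:

```
from diffusers.schedulers import EulerAncestralDiscreteScheduler

scheduler = EulerAncestralDiscreteScheduler()
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
assert len(timesteps) == num_inference_steps == 30
```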

View File

@ -0,0 +1,3 @@
from .pipeline_vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline
__all__ = ["VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"]

File diff suppressed because it is too large

View File

12
videosys/utils/test.py Normal file
View File

@ -0,0 +1,12 @@
import functools
import torch
def empty_cache(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch.cuda.empty_cache()
return func(*args, **kwargs)
return wrapper
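A hedged usage sketch for the `empty_cache` decorator above; the decorated function and pipeline API are illustrative only:

```
@empty_cache
def run_inference(pipeline, prompt: str):
    # torch.cuda.empty_cache() runs right before the wrapped call executes
    return pipeline.generate(prompt)
```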

View File

@ -16,7 +16,17 @@ def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
p.requires_grad = flag
def set_seed(seed):
def set_seed(seed, dp_rank=None):
if seed == -1:
seed = random.randint(0, 1000000)
if dp_rank is not None:
seed = torch.tensor(seed, dtype=torch.int64).cuda()
if dist.get_world_size() > 1:
dist.broadcast(seed, 0)
seed = seed + dp_rank
seed = int(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)