first commit

This commit is contained in:
LiewFeng 2024-12-06 20:33:30 +08:00
parent 5576dc4575
commit 1cf683a0eb
112 changed files with 31014 additions and 2 deletions

37
CONTRIBUTING.md Normal file

@@ -0,0 +1,37 @@
## Coding Standards
### Unit Tests
We use [PyTest](https://docs.pytest.org/en/latest/) to run tests. You can install pytest with `pip install pytest`. Some of the tests require initializing the distributed backend, so GPUs are needed to run them.
To set up the environment for unit testing, first change your current directory to the root directory of your local repository, then run
```bash
pip install -r requirements/requirements-test.txt
```
If you encounter an error saying "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your Python version to 3.8 or 3.9 and try again.
If you only want to run CPU tests, you can run
```bash
pytest -m cpu tests/
```
If you have 8 GPUs on your machine, you can run the full test suite
```bash
pytest tests/
```
If you do not have 8 GPUs on your machine, do not worry: unit tests will be run automatically when you open a pull request against the main branch.
### Code Style
We run several static checks when you commit your code changes. Please make sure all of them pass and that your code meets our style requirements. We use pre-commit hooks to keep the code aligned with the project's style standard. To set up code style checking, follow the steps below.
```shell
# these commands are executed under the Colossal-AI directory
pip install pre-commit
pre-commit install
```
Code format checking will be automatically executed when you commit your changes.
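If you want to run the same hooks without making a commit, pre-commit can also be invoked manually (standard pre-commit usage, shown here as an optional extra check):
```shell
# run every configured hook against all files in the repository
pre-commit run --all-files
```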

2140
LICENSE Normal file

File diff suppressed because it is too large

136
README.md

@@ -1,2 +1,134 @@
# TeaCache
Coming soon.
# Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model
<div class="is-size-5 publication-authors", align="center",>
<span class="author-block">
<a href="https://liewfeng.github.io" target="_blank">Feng Liu</a><sup>1</sup><sup>*</sup>,&nbsp;
</span>
<span class="author-block">
<a href="https://scholar.google.com.hk/citations?user=ZO3OQ-8AAAAJ" target="_blank">Shiwei Zhang</a><sup>2</sup>,&nbsp;
</span>
<span class="author-block">
<a href="https://jeffwang987.github.io" target="_blank">Xiaofeng Wang</a><sup>1,3</sup>,&nbsp;
</span>
<span class="author-block">
<a href="https://weilllllls.github.io" target="_blank">Yujie Wei</a><sup>4</sup>,&nbsp;
</span>
<span class="author-block">
<a href="http://haonanqiu.com" target="_blank">Haonan Qiu</a><sup>5</sup>
</span>
<br>
<span class="author-block">
<a href="https://callsys.github.io/zhaoyuzhong.github.io-main" target="_blank">Yuzhong Zhao</a><sup>1</sup>,&nbsp;
</span>
<span class="author-block">
<a href="https://scholar.google.com.sg/citations?user=16RDSEUAAAAJ" target="_blank">Yingya Zhang</a><sup>2</sup>,&nbsp;
</span>
<span class="author-block">
<a href="https://scholar.google.com/citations?user=tjEfgsEAAAAJ&hl=en&oi=ao" target="_blank">Qixiang Ye</a><sup>1</sup>,&nbsp;
</span>
<span class="author-block">
<a href="https://scholar.google.com/citations?user=0IKavloAAAAJ&hl=en&oi=ao" target="_blank">Fang Wan</a><sup>1</sup><sup></sup>
</span>
</div>
<div class="is-size-5 publication-authors", align="center">
<span class="author-block"><sup>1</sup>University of Chinese Academy of Sciences,&nbsp;</span>
<span class="author-block"><sup>2</sup>Alibaba Group</span>
<br>
<span class="author-block"><sup>3</sup>Institute of Automation, Chinese Academy of Sciences</span>
<br>
<span class="author-block"><sup>4</sup>Fudan University,&nbsp;</span>
<span class="author-block"><sup>5</sup>Nanyang Technological University</span>
</div>
<div class="is-size-5 publication-authors", align="center">
(* Work was done during internship at Alibaba Group. † Corresponding author.)
</div>
<div class="is-size-5 publication-authors", align="center">
<a href="https://arxiv.org/abs/2411.19108">Paper</a> |
<a href="https://github.com/LiewFeng/TeaCache/">Project Page</a>
</div>
![visualization](./assets/tisser.png)
## Introduction
We introduce Timestep Embedding Aware Cache (TeaCache), a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps. For more details and visual results, please visit our [project page](https://github.com/LiewFeng/TeaCache).
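The caching rule is easiest to see in code. Below is a minimal, illustrative sketch of the decision logic used by the experiment scripts in `eval/teacache/experiments/`; the function name and the placeholder polynomial are ours, each model uses its own fitted coefficients and `rel_l1_thresh`, and the scripts additionally force a full recomputation at the first and last timesteps:
```python
import numpy as np
import torch


def teacache_should_calc(modulated_inp: torch.Tensor,
                         prev_modulated_inp: torch.Tensor,
                         accumulated: float,
                         rel_l1_thresh: float,
                         coefficients=(1.0, 0.0)):
    """Decide whether to recompute the transformer blocks at this timestep."""
    # relative L1 change of the timestep-modulated input since the last step
    rel_l1 = ((modulated_inp - prev_modulated_inp).abs().mean()
              / prev_modulated_inp.abs().mean()).item()
    # a model-specific polynomial rescales this input-side change into an
    # estimated output-side change, accumulated across skipped steps
    accumulated += float(np.poly1d(coefficients)(rel_l1))
    if accumulated < rel_l1_thresh:
        return False, accumulated  # reuse the cached residual from the last full pass
    return True, 0.0               # recompute the blocks and reset the estimate
```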
## Installation
Prerequisites:
- Python >= 3.10
- PyTorch >= 1.13 (we recommend version 2.0 or later)
- CUDA >= 11.6
We strongly recommend using Anaconda to create a new environment (Python >= 3.10) to run our examples:
```shell
conda create -n teacache python=3.10 -y
conda activate teacache
```
Install VideoSys:
```shell
git clone https://github.com/LiewFeng/TeaCache
cd TeaCache
pip install -e .
```
## Evaluation of TeaCache
We first generate videos according to VBench's prompts, and then compute the VBench, PSNR, LPIPS and SSIM scores on the generated videos.
1. Generate videos
```bash
cd eval/teacache
python experiments/latte.py
python experiments/opensora.py
python experiments/open_sora_plan.py
```
2. Calculate the VBench score
```bash
# VBench is calculated independently
# get scores for all metrics
python vbench/run_vbench.py --video_path aaa --save_path bbb
# calculate the final score
python vbench/cal_vbench.py --score_dir bbb
```
3. Calculate the other metrics
```bash
# these metrics are computed against the original model
# the gt videos are the original model's outputs
# the generated videos are our method's results
python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
```
## Citation
```
@misc{liu2024timestep,
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
author={Feng Liu and Shiwei Zhang and Xiaofeng Wang and Yujie Wei and Haonan Qiu and Yuzhong Zhao and Yingya Zhang and Qixiang Ye and Fang Wan},
year={2024},
eprint={2411.19108},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2411.19108}
}
```
## Acknowledgement
This repository is built on top of [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys). Thanks for their contributions!

BIN
assets/tisser.png Normal file

Binary file not shown. Size: 6.4 MiB


@@ -0,0 +1,6 @@
Common metrics
Includes LPIPS, PSNR and SSIM.
The code is adapted from [common_metrics_on_video_quality](https://github.com/JunyaoHu/common_metrics_on_video_quality).
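The three modules here each expose a `calculate_*` function that operates on `[batch, frames, channels, height, width]` tensors with values in `[0, 1]`, plus a small self-test `main()`. A minimal usage sketch (assuming this directory is on `PYTHONPATH`; LPIPS additionally needs a GPU and the `lpips` package):
```python
import torch

from calculate_lpips import calculate_lpips
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim

# two toy "videos": batch of 2, 8 frames, RGB, 64x64, values in [0, 1]
videos_a = torch.rand(2, 8, 3, 64, 64)
videos_b = torch.rand(2, 8, 3, 64, 64)

print(calculate_psnr(videos_a, videos_b)["value"])                  # per-frame PSNR
print(calculate_ssim(videos_a, videos_b)["value"])                  # per-frame SSIM
print(calculate_lpips(videos_a, videos_b, device="cuda")["value"])  # per-frame LPIPS
```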


@@ -0,0 +1,97 @@
import lpips
import numpy as np
import torch
spatial = True # Return a spatial map of perceptual distance.
# Linearly calibrated models (LPIPS)
loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
def trans(x):
# if greyscale images add channel
if x.shape[-3] == 1:
x = x.repeat(1, 1, 3, 1, 1)
# value range [0, 1] -> [-1, 1]
x = x * 2 - 1
return x
def calculate_lpips(videos1, videos2, device):
# image should be RGB, IMPORTANT: normalized to [-1,1]
assert videos1.shape == videos2.shape
# videos [batch_size, timestamps, channel, h, w]
# support grayscale input, if grayscale -> channel*3
# value range [0, 1] -> [-1, 1]
videos1 = trans(videos1)
videos2 = trans(videos2)
lpips_results = []
for video_num in range(videos1.shape[0]):
# get a video
# video [timestamps, channel, h, w]
video1 = videos1[video_num]
video2 = videos2[video_num]
lpips_results_of_a_video = []
for clip_timestamp in range(len(video1)):
# get a img
# img [timestamps[x], channel, h, w]
# img [channel, h, w] tensor
img1 = video1[clip_timestamp].unsqueeze(0).to(device)
img2 = video2[clip_timestamp].unsqueeze(0).to(device)
loss_fn.to(device)
# calculate lpips of a video
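# spatial=True makes loss_fn return a per-pixel distance map; mean() reduces it to one scalar per frame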
lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
lpips_results.append(lpips_results_of_a_video)
lpips_results = np.array(lpips_results)
lpips = {}
lpips_std = {}
for clip_timestamp in range(len(video1)):
lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])
result = {
"value": lpips,
"value_std": lpips_std,
"video_setting": video1.shape,
"video_setting_name": "time, channel, heigth, width",
}
return result
# test code / using example
def main():
NUMBER_OF_VIDEOS = 8
VIDEO_LENGTH = 50
CHANNEL = 3
SIZE = 64
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
device = torch.device("cuda")
# device = torch.device("cpu")
import json
result = calculate_lpips(videos1, videos2, device)
print(json.dumps(result, indent=4))
if __name__ == "__main__":
main()


@@ -0,0 +1,90 @@
import math
import numpy as np
import torch
def img_psnr(img1, img2):
# [0,1]
# compute mse
# mse = np.mean((img1-img2)**2)
mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
# compute psnr
if mse < 1e-10:
return 100
psnr = 20 * math.log10(1 / math.sqrt(mse))
return psnr
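# worked example for frames in [0, 1]: mse = 0.01 gives psnr = 20 * log10(1 / 0.1) = 20 dB;
# near-identical frames (mse < 1e-10) are clamped to 100 dB above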
def trans(x):
return x
def calculate_psnr(videos1, videos2):
# videos [batch_size, timestamps, channel, h, w]
assert videos1.shape == videos2.shape
videos1 = trans(videos1)
videos2 = trans(videos2)
psnr_results = []
for video_num in range(videos1.shape[0]):
# get a video
# video [timestamps, channel, h, w]
video1 = videos1[video_num]
video2 = videos2[video_num]
psnr_results_of_a_video = []
for clip_timestamp in range(len(video1)):
# get a img
# img [timestamps[x], channel, h, w]
# img [channel, h, w] numpy
img1 = video1[clip_timestamp].numpy()
img2 = video2[clip_timestamp].numpy()
# calculate psnr of a video
psnr_results_of_a_video.append(img_psnr(img1, img2))
psnr_results.append(psnr_results_of_a_video)
psnr_results = np.array(psnr_results)
psnr = {}
psnr_std = {}
for clip_timestamp in range(len(video1)):
psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])
result = {
"value": psnr,
"value_std": psnr_std,
"video_setting": video1.shape,
"video_setting_name": "time, channel, heigth, width",
}
return result
# test code / using example
def main():
NUMBER_OF_VIDEOS = 8
VIDEO_LENGTH = 50
CHANNEL = 3
SIZE = 64
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
import json
result = calculate_psnr(videos1, videos2)
print(json.dumps(result, indent=4))
if __name__ == "__main__":
main()


@@ -0,0 +1,116 @@
import cv2
import numpy as np
import torch
def ssim(img1, img2):
C1 = 0.01**2
C2 = 0.03**2
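# standard SSIM constants C1 = (k1*L)**2, C2 = (k2*L)**2 with k1 = 0.01, k2 = 0.03 and dynamic range L = 1,
# i.e. inputs are expected in [0, 1]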
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def calculate_ssim_function(img1, img2):
# [0,1]
# ssim is the only metric extremely sensitive to gray being compared to b/w
if not img1.shape == img2.shape:
raise ValueError("Input images must have the same dimensions.")
if img1.ndim == 2:
return ssim(img1, img2)
elif img1.ndim == 3:
if img1.shape[0] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(img1[i], img2[i]))
return np.array(ssims).mean()
elif img1.shape[0] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError("Wrong input image dimensions.")
def trans(x):
return x
def calculate_ssim(videos1, videos2):
# videos [batch_size, timestamps, channel, h, w]
assert videos1.shape == videos2.shape
videos1 = trans(videos1)
videos2 = trans(videos2)
ssim_results = []
for video_num in range(videos1.shape[0]):
# get a video
# video [timestamps, channel, h, w]
video1 = videos1[video_num]
video2 = videos2[video_num]
ssim_results_of_a_video = []
for clip_timestamp in range(len(video1)):
# get a img
# img [timestamps[x], channel, h, w]
# img [channel, h, w] numpy
img1 = video1[clip_timestamp].numpy()
img2 = video2[clip_timestamp].numpy()
# calculate ssim of a video
ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
ssim_results.append(ssim_results_of_a_video)
ssim_results = np.array(ssim_results)
ssim = {}
ssim_std = {}
for clip_timestamp in range(len(video1)):
ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])
result = {
"value": ssim,
"value_std": ssim_std,
"video_setting": video1.shape,
"video_setting_name": "time, channel, heigth, width",
}
return result
# test code / using example
def main():
NUMBER_OF_VIDEOS = 8
VIDEO_LENGTH = 50
CHANNEL = 3
SIZE = 64
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
torch.device("cuda")
import json
result = calculate_ssim(videos1, videos2)
print(json.dumps(result, indent=4))
if __name__ == "__main__":
main()


@@ -0,0 +1,160 @@
import argparse
import os
import imageio
import torch
import torchvision.transforms.functional as F
import tqdm
from calculate_lpips import calculate_lpips
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim
def load_videos(directory, video_ids, file_extension):
videos = []
for video_id in video_ids:
video_path = os.path.join(directory, f"{video_id}.{file_extension}")
if os.path.exists(video_path):
video = load_video(video_path) # Define load_video based on how videos are stored
videos.append(video)
else:
raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
return videos
def load_video(video_path):
"""
Load a video from the given path and convert it to a PyTorch tensor.
"""
# Read the video using imageio
reader = imageio.get_reader(video_path, "ffmpeg")
# Extract frames and convert to a list of tensors
frames = []
for frame in reader:
# Convert the frame to a tensor and permute the dimensions to match (C, H, W)
frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
frames.append(frame_tensor)
# Stack the list of tensors into a single tensor with shape (T, C, H, W)
video_tensor = torch.stack(frames)
return video_tensor
def resize_video(video, target_height, target_width):
resized_frames = []
for frame in video:
resized_frame = F.resize(frame, [target_height, target_width])
resized_frames.append(resized_frame)
return torch.stack(resized_frames)
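# Match the reference (gt) video to the generated video's shape: upscale if it is smaller,
# then truncate temporally and center-crop spatially to (T_gen, H_gen, W_gen).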
def preprocess_eval_video(eval_video, generated_video_shape):
T_gen, _, H_gen, W_gen = generated_video_shape
T_eval, _, H_eval, W_eval = eval_video.shape
if T_eval < T_gen:
raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
if H_eval < H_gen or W_eval < W_gen:
# Resize the video maintaining the aspect ratio
resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
eval_video = resize_video(eval_video, resize_height, resize_width)
# Recalculate the dimensions
T_eval, _, H_eval, W_eval = eval_video.shape
# Center crop
start_h = (H_eval - H_gen) // 2
start_w = (W_eval - W_gen) // 2
cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
return cropped_video
def main(args):
device = "cuda"
gt_video_dir = args.gt_video_dir
generated_video_dir = args.generated_video_dir
video_ids = []
file_extension = "mp4"
for f in os.listdir(generated_video_dir):
if f.endswith(f".{file_extension}"):
video_ids.append(f.replace(f".{file_extension}", ""))
if not video_ids:
raise ValueError("No videos found in the generated video dataset. Exiting.")
print(f"Find {len(video_ids)} videos")
prompt_interval = 1
batch_size = 16
calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
lpips_results = []
psnr_results = []
ssim_results = []
total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
for idx in tqdm.tqdm(range(total_len)):
gt_videos_tensor = []
generated_videos_tensor = []
for i in range(batch_size):
video_idx = idx * batch_size + i
if video_idx >= len(video_ids):
break
video_id = video_ids[video_idx]
generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}"))
generated_videos_tensor.append(generated_video)
eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}"))
gt_videos_tensor.append(eval_video)
gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
if calculate_lpips_flag:
result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
result = result["value"].values()
result = sum(result) / len(result)
lpips_results.append(result)
if calculate_psnr_flag:
result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
result = result["value"].values()
result = sum(result) / len(result)
psnr_results.append(result)
if calculate_ssim_flag:
result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
result = result["value"].values()
result = sum(result) / len(result)
ssim_results.append(result)
if (idx + 1) % prompt_interval == 0:
out_str = ""
for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
result = sum(results) / len(results)
out_str += f"{name}: {result:.4f}, "
print(f"Processed {idx + 1} videos. {out_str[:-2]}")
out_str = ""
for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
result = sum(results) / len(results)
out_str += f"{name}: {result:.4f}, "
out_str = out_str[:-2]
# save
with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f:
f.write(out_str)
print(f"Processed all videos. {out_str}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--gt_video_dir", type=str)
parser.add_argument("--generated_video_dir", type=str)
args = parser.parse_args()
main(args)


@@ -0,0 +1,532 @@
from functools import partial
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from utils import generate_func, read_prompt_list
from videosys import LatteConfig, VideoSysEngine
# helper imports below follow the sibling opensora.py / open_sora_plan.py scripts
from videosys.core.comm import gather_sequence, get_temporal_pad, set_spatial_pad, set_temporal_pad, split_sequence
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_group,
get_cfg_parallel_size,
get_data_parallel_group,
get_sequence_parallel_group,
)
from videosys.utils.utils import batch_func
def teacache_forward(
self,
hidden_states: torch.Tensor,
timestep: Optional[torch.LongTensor] = None,
all_timesteps=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_image_num: int = 0,
enable_temporal_attentions: bool = True,
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# 0. Split batch for data parallelism
if get_cfg_parallel_size() > 1:
(
hidden_states,
timestep,
encoder_hidden_states,
added_cond_kwargs,
class_labels,
attention_mask,
encoder_attention_mask,
) = batch_func(
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
hidden_states,
timestep,
encoder_hidden_states,
added_cond_kwargs,
class_labels,
attention_mask,
encoder_attention_mask,
)
input_batch_size, c, frame, h, w = hidden_states.shape
frame = frame - use_image_num
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
org_timestep = timestep
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
encoder_attention_mask_video = repeat(
encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
).contiguous()
encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
# Retrieve lora scale.
cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_patches: # here
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
num_patches = height * width
hidden_states = self.pos_embed(hidden_states) # already adds positional embeddings
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
# batch_size = hidden_states.shape[0]
batch_size = input_batch_size
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
if use_image_num != 0 and self.training:
encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
encoder_hidden_states_video = repeat(
encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
).contiguous()
encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
else:
encoder_hidden_states_spatial = repeat(
encoder_hidden_states, "b t d -> (b f) t d", f=frame
).contiguous()
# prepare timesteps for spatial and temporal block
timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
if self.enable_teacache:
inp = hidden_states.clone()
batch_size = inp.shape[0]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.transformer_blocks[0].scale_shift_table[None] + timestep_spatial.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
modulated_inp = self.transformer_blocks[0].norm1(inp) * (1 + scale_msa) + shift_msa
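# TeaCache decision: always recompute at the first and last timesteps; otherwise accumulate a
# polynomial-rescaled relative L1 change of the modulated input and only recompute the
# transformer blocks once the accumulated estimate crosses rel_l1_thresh; when should_calc is
# False the cached residual is added to hidden_states instead.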
if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [-2.46434137e+03, 3.08044764e+02, 8.07447667e+01, -4.11385132e+00, 1.11001402e-01]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
if self.enable_teacache:
if not should_calc:
hidden_states += self.previous_residual
else:
if enable_sequence_parallel():
set_temporal_pad(frame + use_image_num)
set_spatial_pad(num_patches)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
)
else:
temp_pos_embed = self.temp_pos_embed
hidden_states_origin = hidden_states.clone().detach()
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0: # image-video joint training
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
if i == 0:
hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
hidden_states = spatial_block(
hidden_states,
attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
None,
org_timestep,
all_timesteps=all_timesteps,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0 and self.training:
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
hidden_states_video = temp_block(
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
org_timestep,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
if i == 0 and frame > 1:
hidden_states = hidden_states + temp_pos_embed
hidden_states = temp_block(
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
org_timestep,
all_timesteps=all_timesteps,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
self.previous_residual = hidden_states - hidden_states_origin
else:
if enable_sequence_parallel():
set_temporal_pad(frame + use_image_num)
set_spatial_pad(num_patches)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
)
else:
temp_pos_embed = self.temp_pos_embed
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0: # image-video joint training
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
if i == 0:
hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
hidden_states = spatial_block(
hidden_states,
attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
None,
org_timestep,
all_timesteps=all_timesteps,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0 and self.training:
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
hidden_states_video = temp_block(
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
org_timestep,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
if i == 0 and frame > 1:
hidden_states = hidden_states + temp_pos_embed
hidden_states = temp_block(
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
org_timestep,
all_timesteps=all_timesteps,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
if enable_sequence_parallel():
if self.enable_teacache:
if should_calc:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
self.previous_residual = self.gather_from_second_dim(self.previous_residual, input_batch_size)
else:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
if self.is_input_patches:
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
if get_cfg_parallel_size() > 1:
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
def eval_base(prompt_list):
config = LatteConfig()
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/latte_base", loop=5)
def eval_teacache_slow(prompt_list):
config = LatteConfig()
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.1
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.previous_modulated_input = None
engine.driver_worker.transformer.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/latte_teacache_slow", loop=5)
def eval_teacache_fast(prompt_list):
config = LatteConfig()
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.2
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.previous_modulated_input = None
engine.driver_worker.transformer.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/latte_teacache_fast", loop=5)
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
# eval_base(prompt_list)
eval_teacache_slow(prompt_list)
# eval_teacache_fast(prompt_list)


@@ -0,0 +1,243 @@
from functools import partial
from utils import generate_func, read_prompt_list
from videosys import OpenSoraConfig, VideoSysEngine
import torch
from einops import rearrange
from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_temporal_pad, set_spatial_pad, set_temporal_pad, split_sequence
import numpy as np
from videosys.utils.utils import batch_func
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_size,
get_data_parallel_group,
get_sequence_parallel_group,
)
def teacache_forward(
self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
):
# === Split batch ===
if get_cfg_parallel_size() > 1:
x, timestep, y, x_mask, mask = batch_func(
partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
)
dtype = self.x_embedder.proj.weight.dtype
B = x.size(0)
x = x.to(dtype)
timestep = timestep.to(dtype)
y = y.to(dtype)
# === get pos embed ===
_, _, Tx, Hx, Wx = x.size()
T, H, W = self.get_dynamic_size(x)
S = H * W
base_size = round(S**0.5)
resolution_sq = (height[0].item() * width[0].item()) ** 0.5
scale = resolution_sq / self.input_sq_size
pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)
# === get timestep embed ===
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
fps = self.fps_embedder(fps.unsqueeze(1), B)
t = t + fps
t_mlp = self.t_block(t)
t0 = t0_mlp = None
if x_mask is not None:
t0_timestep = torch.zeros_like(timestep)
t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
t0 = t0 + fps
t0_mlp = self.t_block(t0)
# === get y embed ===
if self.config.skip_y_embedder:
y_lens = mask
if isinstance(y_lens, torch.Tensor):
y_lens = y_lens.long().tolist()
else:
y, y_lens = self.encode_text(y, mask)
# === get x embed ===
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
x = x + pos_emb
if self.enable_teacache:
inp = x.clone()
inp = rearrange(inp, "B T S C -> B (T S) C", T=T, S=S)
B, N, C = inp.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.spatial_blocks[0].scale_shift_table[None] + t_mlp.reshape(B, 6, -1)
).chunk(6, dim=1)
modulated_inp = t2i_modulate(self.spatial_blocks[0].norm1(inp), shift_msa, scale_msa)
if timestep[0] == all_timesteps[0] or timestep[0] == all_timesteps[-1]:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
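# model-specific polynomial (fitted for OpenSora) that rescales the relative L1 change of the
# modulated input into an estimated change of the model output; each experiment script ships its own coefficients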
coefficients = [2.17546007e+02, -1.18329252e+02, 2.68662585e+01, -4.59364272e-02, 4.84426240e-02]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
# === blocks ===
if self.enable_teacache:
if not should_calc:
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
x += self.previous_residual
else:
# shard over the sequence dim if sp is enabled
if enable_sequence_parallel():
set_temporal_pad(T)
set_spatial_pad(S)
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
T = x.shape[1]
x_mask_org = x_mask
x_mask = split_sequence(
x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
)
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
origin_x = x.clone().detach()
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
x = auto_grad_checkpoint(
spatial_block,
x,
y,
t_mlp,
y_lens,
x_mask,
t0_mlp,
T,
S,
timestep,
all_timesteps=all_timesteps,
)
x = auto_grad_checkpoint(
temporal_block,
x,
y,
t_mlp,
y_lens,
x_mask,
t0_mlp,
T,
S,
timestep,
all_timesteps=all_timesteps,
)
self.previous_residual = x - origin_x
else:
# shard over the sequence dim if sp is enabled
if enable_sequence_parallel():
set_temporal_pad(T)
set_spatial_pad(S)
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
T = x.shape[1]
x_mask_org = x_mask
x_mask = split_sequence(
x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
)
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
x = auto_grad_checkpoint(
spatial_block,
x,
y,
t_mlp,
y_lens,
x_mask,
t0_mlp,
T,
S,
timestep,
all_timesteps=all_timesteps,
)
x = auto_grad_checkpoint(
temporal_block,
x,
y,
t_mlp,
y_lens,
x_mask,
t0_mlp,
T,
S,
timestep,
all_timesteps=all_timesteps,
)
if enable_sequence_parallel():
if self.enable_teacache:
if should_calc:
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S)
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
T, S = x.shape[1], x.shape[2]
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S)
x_mask = x_mask_org
else:
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
T, S = x.shape[1], x.shape[2]
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
x_mask = x_mask_org
# === final layer ===
x = self.final_layer(x, t, x_mask, t0, T, S)
x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
# cast to float32 for better accuracy
x = x.to(torch.float32)
# === Gather Output ===
if get_cfg_parallel_size() > 1:
x = gather_sequence(x, get_data_parallel_group(), dim=0)
return x
def eval_base(prompt_list):
config = OpenSoraConfig()
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/opensora_base", loop=5)
def eval_teacache_slow(prompt_list):
config = OpenSoraConfig()
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.1
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.previous_modulated_input = None
engine.driver_worker.transformer.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensora_teacache_slow", loop=5)
def eval_teacache_fast(prompt_list):
config = OpenSoraConfig()
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.2
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.previous_modulated_input = None
engine.driver_worker.transformer.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5)
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
# eval_base(prompt_list)
eval_teacache_slow(prompt_list)
# eval_teacache_fast(prompt_list)


@@ -0,0 +1,594 @@
from functools import partial
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from utils import generate_func, read_prompt_list
from videosys import OpenSoraPlanConfig, VideoSysEngine
# helper imports below follow the sibling opensora.py script
from videosys.core.comm import gather_sequence, get_temporal_pad, set_spatial_pad, set_temporal_pad, split_sequence
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_group,
get_cfg_parallel_size,
get_sequence_parallel_group,
)
from videosys.utils.utils import batch_func
def teacache_forward(
self,
hidden_states: torch.Tensor,
timestep: Optional[torch.LongTensor] = None,
all_timesteps=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_image_num: int = 0,
enable_temporal_attentions: bool = True,
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# 0. Split batch
if get_cfg_parallel_size() > 1:
(
hidden_states,
timestep,
encoder_hidden_states,
class_labels,
attention_mask,
encoder_attention_mask,
) = batch_func(
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
hidden_states,
timestep,
encoder_hidden_states,
class_labels,
attention_mask,
encoder_attention_mask,
)
input_batch_size, c, frame, h, w = hidden_states.shape
frame = frame - use_image_num # 20-4=16
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
org_timestep = timestep
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is None:
attention_mask = torch.ones(
(input_batch_size, frame + use_image_num, h, w), device=hidden_states.device, dtype=hidden_states.dtype
)
attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num)
dtype = attention_mask.dtype
attention_mask_compress = F.max_pool2d(
attention_mask.float(), kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor
)
attention_mask_compress = attention_mask_compress.to(dtype)
attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype)
attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype)
# 1 + 4, 1 -> video condition, 4 -> image condition
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
encoder_attention_mask = encoder_attention_mask.to(self.dtype)
elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
encoder_attention_mask_video = repeat(
encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
).contiguous()
encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
encoder_attention_mask = encoder_attention_mask.to(self.dtype)
# Retrieve lora scale.
cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_patches: # here
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hw = (height, width)
num_patches = height * width
hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # already adds positional embeddings
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
# batch_size = hidden_states.shape[0]
batch_size = input_batch_size
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype)) # 3 120 1152
if use_image_num != 0 and self.training:
encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
encoder_hidden_states_video = repeat(
encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
).contiguous()
encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
else:
encoder_hidden_states_spatial = repeat(
encoder_hidden_states, "b 1 t d -> (b f) t d", f=frame
).contiguous()
# prepare timesteps for spatial and temporal block
timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
pos_hw, pos_t = None, None
if self.use_rope:
pos_hw, pos_t = self.make_position(
input_batch_size, frame, use_image_num, height, width, hidden_states.device
)
if self.enable_teacache:
inp = hidden_states.clone()
batch_size = hidden_states.shape[0]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.transformer_blocks[0].scale_shift_table[None] + timestep_spatial.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
modulated_inp = self.transformer_blocks[0].norm1(inp) * (1 + scale_msa) + shift_msa
if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [2.05943668e+05, -1.48759286e+04, 3.06085986e+02, 1.31418080e+00, 2.39658469e-03]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
if self.enable_teacache:
if not should_calc:
hidden_states += self.previous_residual
else:
if enable_sequence_parallel():
set_temporal_pad(frame + use_image_num)
set_spatial_pad(num_patches)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
)
else:
temp_pos_embed = self.temp_pos_embed
ori_hidden_states = hidden_states.clone()
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
pos_hw,
pos_hw,
hw,
use_reentrant=False,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0: # image-video joint training
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
use_reentrant=False,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
use_reentrant=False,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
hidden_states = spatial_block(
hidden_states,
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
pos_hw,
pos_hw,
hw,
org_timestep,
all_timesteps=all_timesteps,
)
if enable_temporal_attentions:
# b c f h w, f = 16 + 4
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0 and self.training:
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
# if i == 0 and not self.use_rope:
# hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = temp_block(
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
org_timestep,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = temp_block(
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
org_timestep,
all_timesteps=all_timesteps,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
self.previous_residual = hidden_states - ori_hidden_states
else:
if enable_sequence_parallel():
set_temporal_pad(frame + use_image_num)
set_spatial_pad(num_patches)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
)
else:
temp_pos_embed = self.temp_pos_embed
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
pos_hw,
pos_hw,
hw,
use_reentrant=False,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0: # image-video joint training
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
use_reentrant=False,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
use_reentrant=False,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
hidden_states = spatial_block(
hidden_states,
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
pos_hw,
pos_hw,
hw,
org_timestep,
all_timesteps=all_timesteps,
)
if enable_temporal_attentions:
# b c f h w, f = 16 + 4
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0 and self.training:
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
# if i == 0 and not self.use_rope:
# hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = temp_block(
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
org_timestep,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = temp_block(
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
org_timestep,
all_timesteps=all_timesteps,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
if enable_sequence_parallel():
if self.enable_teacache:
if should_calc:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
self.previous_residual = self.gather_from_second_dim(self.previous_residual, input_batch_size)
else:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
if self.is_input_patches:
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
if get_cfg_parallel_size() > 1:
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
def eval_base(prompt_list):
config = OpenSoraPlanConfig()
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
def eval_teacache_slow(prompt_list):
config = OpenSoraPlanConfig()
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.1
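# Threshold 0.1 ("slow" setting): the accumulated change crosses the threshold
# sooner, so fewer steps are skipped; smaller speedup, output closest to baseline.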
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.previous_modulated_input = None
engine.driver_worker.transformer.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5)
def eval_teacache_fast(prompt_list):
config = OpenSoraPlanConfig()
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.2
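# Threshold 0.2 ("fast" setting): more consecutive steps reuse the cached
# residual; larger speedup at some cost in fidelity.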
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.previous_modulated_input = None
engine.driver_worker.transformer.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5)
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
# eval_base(prompt_list)
eval_teacache_slow(prompt_list)
# eval_teacache_fast(prompt_list)

View File

@@ -0,0 +1,22 @@
import json
import os
import tqdm
from videosys.utils.utils import set_seed
def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = {}):
kwargs["verbose"] = False
for prompt in tqdm.tqdm(prompt_list):
for l in range(loop):
set_seed(l)
video = pipeline.generate(prompt, **kwargs).video[0]
pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))
def read_prompt_list(prompt_list_path):
with open(prompt_list_path, "r") as f:
prompt_list = json.load(f)
prompt_list = [prompt["prompt_en"] for prompt in prompt_list]
return prompt_list

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,154 @@
import argparse
import json
import os
SEMANTIC_WEIGHT = 1
QUALITY_WEIGHT = 4
QUALITY_LIST = [
"subject consistency",
"background consistency",
"temporal flickering",
"motion smoothness",
"aesthetic quality",
"imaging quality",
"dynamic degree",
]
SEMANTIC_LIST = [
"object class",
"multiple objects",
"human action",
"color",
"spatial relationship",
"scene",
"appearance style",
"temporal style",
"overall consistency",
]
NORMALIZE_DIC = {
"subject consistency": {"Min": 0.1462, "Max": 1.0},
"background consistency": {"Min": 0.2615, "Max": 1.0},
"temporal flickering": {"Min": 0.6293, "Max": 1.0},
"motion smoothness": {"Min": 0.706, "Max": 0.9975},
"dynamic degree": {"Min": 0.0, "Max": 1.0},
"aesthetic quality": {"Min": 0.0, "Max": 1.0},
"imaging quality": {"Min": 0.0, "Max": 1.0},
"object class": {"Min": 0.0, "Max": 1.0},
"multiple objects": {"Min": 0.0, "Max": 1.0},
"human action": {"Min": 0.0, "Max": 1.0},
"color": {"Min": 0.0, "Max": 1.0},
"spatial relationship": {"Min": 0.0, "Max": 1.0},
"scene": {"Min": 0.0, "Max": 0.8222},
"appearance style": {"Min": 0.0009, "Max": 0.2855},
"temporal style": {"Min": 0.0, "Max": 0.364},
"overall consistency": {"Min": 0.0, "Max": 0.364},
}
DIM_WEIGHT = {
"subject consistency": 1,
"background consistency": 1,
"temporal flickering": 1,
"motion smoothness": 1,
"aesthetic quality": 1,
"imaging quality": 1,
"dynamic degree": 0.5,
"object class": 1,
"multiple objects": 1,
"human action": 1,
"color": 1,
"spatial relationship": 1,
"scene": 1,
"appearance style": 1,
"temporal style": 1,
"overall consistency": 1,
}
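# Aggregation (following VBench): each raw dimension score is min-max normalized
# with NORMALIZE_DIC and weighted by DIM_WEIGHT, the quality and semantic groups
# are averaged separately, and the total score mixes them as
# (4 * quality + 1 * semantic) / 5.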
ordered_scaled_res = [
"total score",
"quality score",
"semantic score",
"subject consistency",
"background consistency",
"temporal flickering",
"motion smoothness",
"dynamic degree",
"aesthetic quality",
"imaging quality",
"object class",
"multiple objects",
"human action",
"color",
"spatial relationship",
"scene",
"appearance style",
"temporal style",
"overall consistency",
]
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--score_dir", required=True, type=str)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
res_postfix = "_eval_results.json"
info_postfix = "_full_info.json"
files = os.listdir(args.score_dir)
res_files = [x for x in files if res_postfix in x]
info_files = [x for x in files if info_postfix in x]
assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
full_results = {}
for res_file in res_files:
# first check that the results look valid
info_file = res_file.split(res_postfix)[0] + info_postfix
with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f:
info = json.load(f)
assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
# read results
with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f:
data = json.load(f)
for key, val in data.items():
full_results[key] = format(val[0], ".4f")
scaled_results = {}
dims = set()
for key, val in full_results.items():
dim = key.replace("_", " ") if "_" in key else key
scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / (
NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"]
)
scaled_score *= DIM_WEIGHT[dim]
scaled_results[dim] = scaled_score
dims.add(dim)
assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet"
quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST])
semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST])
scaled_results["quality score"] = quality_score
scaled_results["semantic score"] = semantic_score
scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / (
QUALITY_WEIGHT + SEMANTIC_WEIGHT
)
formated_scaled_results = {"items": []}
for key in ordered_scaled_res:
formated_score = format(scaled_results[key] * 100, ".2f") + "%"
formated_scaled_results["items"].append({key: formated_score})
output_file_path = os.path.join(args.score_dir, "all_results.json")
with open(output_file_path, "w") as outfile:
json.dump(full_results, outfile, indent=4, sort_keys=True)
print(f"results saved to: {output_file_path}")
scaled_file_path = os.path.join(args.score_dir, "scaled_results.json")
with open(scaled_file_path, "w") as outfile:
json.dump(formated_scaled_results, outfile, indent=4, sort_keys=True)
print(f"results saved to: {scaled_file_path}")

View File

@@ -0,0 +1,52 @@
import argparse
import torch
from vbench import VBench
full_info_path = "./vbench/VBench_full_info.json"
dimensions = [
"subject_consistency",
"imaging_quality",
"background_consistency",
"motion_smoothness",
"overall_consistency",
"human_action",
"multiple_objects",
"spatial_relationship",
"object_class",
"color",
"aesthetic_quality",
"appearance_style",
"temporal_flickering",
"scene",
"temporal_style",
"dynamic_degree",
]
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--video_path", required=True, type=str)
parser.add_argument("--save_path", required=True, type=str)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
kwargs = {}
kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default
for dimension in dimensions:
my_VBench = VBench(torch.device("cuda"), full_info_path, args.save_path)
my_VBench.evaluate(
videos_path=args.video_path,
name=dimension,
local=False,
read_frame=False,
dimension_list=[dimension],
mode="vbench_standard",
**kwargs,
)
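# Each dimension is evaluated in its own VBench run. The aggregation script above
# expects the resulting per-dimension "*_eval_results.json" / "*_full_info.json"
# files in its --score_dir when computing the normalized total score.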

22
requirements.txt Normal file
View File

@@ -0,0 +1,22 @@
click
colossalai
diffusers==0.30.0
einops
fabric
ftfy
imageio
imageio-ffmpeg
matplotlib
ninja
numpy<2.0.0
omegaconf
packaging
psutil
pydantic
ray
rich
safetensors
timm
torch>=1.13
tqdm
transformers

55
setup.py Normal file
View File

@@ -0,0 +1,55 @@
from typing import List
from setuptools import find_packages, setup
def fetch_requirements(path) -> List[str]:
"""
This function reads the requirements file.
Args:
path (str): the path to the requirements file.
Returns:
The lines in the requirements file.
"""
with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
def fetch_readme() -> str:
"""
This function reads the README.md file in the current directory.
Returns:
The lines in the README file.
"""
with open("README.md", encoding="utf-8") as f:
return f.read()
setup(
name="videosys",
version="2.0.0",
packages=find_packages(
exclude=(
"videos",
"tests",
"figure",
"*.egg-info",
)
),
description="VideoSys",
long_description=fetch_readme(),
long_description_content_type="text/markdown",
license="Apache Software License 2.0",
install_requires=fetch_requirements("requirements.txt"),
python_requires=">=3.6",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Environment :: GPU :: NVIDIA CUDA",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: System :: Distributed Computing",
],
)

223
videosys.egg-info/PKG-INFO Normal file
View File

@@ -0,0 +1,223 @@
Metadata-Version: 2.1
Name: videosys
Version: 2.0.0
Summary: VideoSys
License: Apache Software License 2.0
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Environment :: GPU :: NVIDIA CUDA
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: System :: Distributed Computing
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: click
Requires-Dist: colossalai
Requires-Dist: contexttimer
Requires-Dist: diffusers==0.30.0
Requires-Dist: einops
Requires-Dist: fabric
Requires-Dist: ftfy
Requires-Dist: imageio
Requires-Dist: imageio-ffmpeg
Requires-Dist: matplotlib
Requires-Dist: ninja
Requires-Dist: numpy<2.0.0
Requires-Dist: omegaconf
Requires-Dist: packaging
Requires-Dist: psutil
Requires-Dist: pydantic
Requires-Dist: ray
Requires-Dist: rich
Requires-Dist: safetensors
Requires-Dist: timm
Requires-Dist: torch>=1.13
Requires-Dist: tqdm
Requires-Dist: transformers
<p align="center">
<img width="55%" alt="VideoSys" src="./assets/figures/logo.png?raw=true">
</p>
<h3 align="center">
An easy and efficient system for video generation
</h3>
### Latest News 🔥
- [2024/08] 🔥 Evolve from [OpenDiT](https://github.com/NUS-HPC-AI-Lab/VideoSys/tree/v1.0.0) to <b>VideoSys: An easy and efficient system for video generation.</b>
- [2024/08] 🔥 <b>Release PAB paper: [Real-Time Video Generation with Pyramid Attention Broadcast](https://arxiv.org/abs/2408.12588).</b>
- [2024/06] Propose Pyramid Attention Broadcast (PAB)[[paper](https://arxiv.org/abs/2408.12588)][[blog](https://oahzxl.github.io/PAB/)][[doc](./docs/pab.md)], the first approach to achieve <b>real-time</b> DiT-based video generation, delivering <b>negligible quality loss</b> without <b>requiring any training</b>.
- [2024/06] Support [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) and [Latte](https://github.com/Vchitect/Latte).
- [2024/03] Propose Dynamic Sequence Parallelism (DSP) [[paper](https://arxiv.org/abs/2403.10266)][[doc](./docs/dsp.md)], which achieves **3x** speed for training and **2x** speed for inference in Open-Sora compared with state-of-the-art sequence parallelism.
- [2024/03] Support [Open-Sora: Democratizing Efficient Video Production for All](https://github.com/hpcaitech/Open-Sora).
- [2024/02] 🎉 Release [OpenDiT](https://github.com/NUS-HPC-AI-Lab/VideoSys/tree/v1.0.0): An Easy, Fast and Memory-Efficient System for DiT Training and Inference.
# About
VideoSys is an open-source project that provides a user-friendly and high-performance infrastructure for video generation. This comprehensive toolkit will support the entire pipeline from training and inference to serving and compression.
We are committed to continually integrating cutting-edge open-source video models and techniques. Stay tuned for exciting enhancements and new features on the horizon!
## Installation
Prerequisites:
- Python >= 3.10
- PyTorch >= 1.13 (we recommend a version >= 2.0)
- CUDA >= 11.6
We strongly recommend using Anaconda to create a new environment (Python >= 3.10) to run our examples:
```shell
conda create -n videosys python=3.10 -y
conda activate videosys
```
Install VideoSys:
```shell
git clone https://github.com/NUS-HPC-AI-Lab/VideoSys
cd VideoSys
pip install -e .
```
## Usage
VideoSys supports many diffusion models with our various acceleration techniques, enabling these models to run faster and consume less memory.
<b>You can find all available models and their supported acceleration techniques in the following table. Click `Code` to see how to use them; a minimal inference sketch also follows the table.</b>
<table>
<tr>
<th rowspan="2">Model</th>
<th rowspan="2">Train</th>
<th rowspan="2">Infer</th>
<th colspan="2">Acceleration Techniques</th>
<th rowspan="2">Usage</th>
</tr>
<tr>
<th><a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#dyanmic-sequence-parallelism-dsp-paperdoc">DSP</a></th>
<th><a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#pyramid-attention-broadcast-pab-blogdoc">PAB</a></th>
</tr>
<tr>
<td>Open-Sora [<a href="https://github.com/hpcaitech/Open-Sora">source</a>]</td>
<td align="center">🟡</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center"><a href="./examples/open_sora/sample.py">Code</a></td>
</tr>
<tr>
<td>Open-Sora-Plan [<a href="https://github.com/PKU-YuanGroup/Open-Sora-Plan">source</a>]</td>
<td align="center">/</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center"><a href="./examples/open_sora_plan/sample.py">Code</a></td>
</tr>
<tr>
<td>Latte [<a href="https://github.com/Vchitect/Latte">source</a>]</td>
<td align="center">/</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center"><a href="./examples/latte/sample.py">Code</a></td>
</tr>
<tr>
<td>CogVideoX [<a href="https://github.com/THUDM/CogVideo">source</a>]</td>
<td align="center">/</td>
<td align="center">✅</td>
<td align="center">/</td>
<td align="center">✅</td>
<td align="center"><a href="./examples/cogvideox/sample.py">Code</a></td>
</tr>
</table>
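The snippet below is a minimal inference sketch assembled from the APIs this commit ships (`VideoSysEngine`, the per-model configs, `generate`, `save_video`). Treat it as orientation rather than canonical usage: the prompt and output path are placeholders, and the per-model examples linked in the table remain the reference.

```python
from videosys import OpenSoraPlanConfig, VideoSysEngine

# Build a default config and wrap it in the engine.
config = OpenSoraPlanConfig()
engine = VideoSysEngine(config)

# Generate one clip and save it; prompt and path are placeholders.
video = engine.generate("a cat playing with a ball of yarn").video[0]
engine.save_video(video, "./outputs/cat.mp4")
```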
## Acceleration Techniques
### Pyramid Attention Broadcast (PAB) [[paper](https://arxiv.org/abs/2408.12588)][[blog](https://arxiv.org/abs/2403.10266)][[doc](./docs/pab.md)]
Real-Time Video Generation with Pyramid Attention Broadcast
Authors: [Xuanlei Zhao](https://oahzxl.github.io/)<sup>1*</sup>, [Xiaolong Jin]()<sup>2*</sup>, [Kai Wang](https://kaiwang960112.github.io/)<sup>1*</sup>, and [Yang You](https://www.comp.nus.edu.sg/~youy/)<sup>1</sup> (* indicates equal contribution)
<sup>1</sup>National University of Singapore, <sup>2</sup>Purdue University
![method](./assets/figures/pab_method.png)
PAB is the first approach to achieve <b>real-time</b> DiT-based video generation, delivering <b>lossless quality</b> without <b>requiring any training</b>. By mitigating redundant attention computation, PAB achieves up to 21.6 FPS with 10.6x acceleration, without sacrificing quality across popular DiT-based video generation models including [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte) and [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan).
See its details [here](./docs/pab.md).
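As a pointer into the code added in this commit, the low-level PAB switches live in `videosys/core/pab_mgr.py` (`PABConfig` / `set_pab_manager`); the per-model `*PABConfig` classes exported from `videosys` (e.g. `OpenSoraPlanPABConfig`) appear to build on the same manager. The sketch below only illustrates those knobs; the ranges and timestep thresholds are placeholder values, not tuned settings.

```python
from videosys.core.pab_mgr import PABConfig, set_pab_manager

# Placeholder numbers: reuse (broadcast) spatial / temporal / cross attention
# outputs on steps that are not multiples of the given range, while the timestep
# lies inside the threshold window.
pab_config = PABConfig(
    steps=50,
    spatial_broadcast=True, spatial_threshold=[100, 850], spatial_range=2,
    temporal_broadcast=True, temporal_threshold=[100, 850], temporal_range=4,
    cross_broadcast=True, cross_threshold=[100, 850], cross_range=6,
)
set_pab_manager(pab_config)  # model code then queries if_broadcast_* every step
```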
----
### Dynamic Sequence Parallelism (DSP) [[paper](https://arxiv.org/abs/2403.10266)][[doc](./docs/dsp.md)]
![dsp_overview](./assets/figures/dsp_overview.png)
DSP is a novel, elegant and super efficient sequence parallelism for [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte) and other multi-dimensional transformer architecture.
It achieves **3x** speed for training and **2x** speed for inference in Open-Sora compared with state-of-the-art sequence parallelism ([DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509)). For a 10s (80-frame) 512x512 video, the inference latency of Open-Sora is:
| Method | 1xH800 | 8xH800 (DS Ulysses) | 8xH800 (DSP) |
| ------ | ------ | ------ | ------ |
| Latency(s) | 106 | 45 | 22 |
See its details [here](./docs/dsp.md).
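Multi-GPU inference, which the DSP path builds on, goes through the same `VideoSysEngine`: the engine reads `config.num_gpus` and spawns one worker process per extra GPU (see `videosys/core/engine.py` in this commit). A minimal sketch, assuming `num_gpus` is accepted as a constructor argument of the config:

```python
from videosys import OpenSoraConfig, VideoSysEngine

# Assumption: num_gpus is a constructor argument; the engine only reads
# config.num_gpus, so any way of setting it works.
config = OpenSoraConfig(num_gpus=2)
engine = VideoSysEngine(config)
video = engine.generate("a drone shot over a snowy mountain range").video[0]
engine.save_video(video, "./outputs/mountain.mp4")
```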
## Contributing
We welcome and value any contributions and collaborations. Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
## Contributors
<a href="https://github.com/NUS-HPC-AI-Lab/VideoSys/graphs/contributors">
<img src="https://contrib.rocks/image?repo=NUS-HPC-AI-Lab/VideoSys"/>
</a>
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=NUS-HPC-AI-Lab/VideoSys&type=Date)](https://star-history.com/#NUS-HPC-AI-Lab/VideoSys&Date)
## Citation
```
@misc{videosys2024,
author={VideoSys Team},
title={VideoSys: An Easy and Efficient System for Video Generation},
year={2024},
publisher={GitHub},
url = {https://github.com/NUS-HPC-AI-Lab/VideoSys},
}
@misc{zhao2024pab,
title={Real-Time Video Generation with Pyramid Attention Broadcast},
author={Xuanlei Zhao and Xiaolong Jin and Kai Wang and Yang You},
year={2024},
eprint={2408.12588},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2408.12588},
}
@misc{zhao2024dsp,
title={DSP: Dynamic Sequence Parallelism for Multi-Dimensional Transformers},
author={Xuanlei Zhao and Shenggan Cheng and Chang Chen and Zangwei Zheng and Ziming Liu and Zheming Yang and Yang You},
year={2024},
eprint={2403.10266},
archivePrefix={arXiv},
primaryClass={cs.DC},
url={https://arxiv.org/abs/2403.10266},
}
@misc{zhao2024opendit,
author={Xuanlei Zhao and Zhongkai Zhao and Ziming Liu and Haotian Zhou and Qianli Ma and Yang You},
title={OpenDiT: An Easy, Fast and Memory-Efficient System for DiT Training and Inference},
year={2024},
publisher={GitHub},
url={https://github.com/NUS-HPC-AI-Lab/VideoSys/tree/v1.0.0},
}
```

View File

@@ -0,0 +1,60 @@
LICENSE
README.md
setup.py
tests/pipelines/__init__.py
tests/pipelines/cogvideox/__init__.py
tests/pipelines/cogvideox/test_cogvideox.py
tests/pipelines/latte/__init__.py
tests/pipelines/latte/test_latte.py
tests/pipelines/open_sora/__init__.py
tests/pipelines/open_sora/test_open_sora.py
tests/pipelines/open_sora_plan/__init__.py
tests/pipelines/open_sora_plan/test_open_sora_plan.py
videosys/__init__.py
videosys.egg-info/PKG-INFO
videosys.egg-info/SOURCES.txt
videosys.egg-info/dependency_links.txt
videosys.egg-info/requires.txt
videosys.egg-info/top_level.txt
videosys/core/__init__.py
videosys/core/comm.py
videosys/core/engine.py
videosys/core/mp_utils.py
videosys/core/pab_mgr.py
videosys/core/parallel_mgr.py
videosys/core/pipeline.py
videosys/core/shardformer/__init__.py
videosys/core/shardformer/t5/__init__.py
videosys/core/shardformer/t5/modeling.py
videosys/core/shardformer/t5/policy.py
videosys/models/__init__.py
videosys/models/autoencoders/__init__.py
videosys/models/autoencoders/autoencoder_kl_cogvideox.py
videosys/models/autoencoders/autoencoder_kl_open_sora.py
videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py
videosys/models/modules/__init__.py
videosys/models/modules/activations.py
videosys/models/modules/attentions.py
videosys/models/modules/downsampling.py
videosys/models/modules/embeddings.py
videosys/models/modules/normalization.py
videosys/models/modules/upsampling.py
videosys/models/transformers/__init__.py
videosys/models/transformers/cogvideox_transformer_3d.py
videosys/models/transformers/latte_transformer_3d.py
videosys/models/transformers/open_sora_plan_transformer_3d.py
videosys/models/transformers/open_sora_transformer_3d.py
videosys/pipelines/__init__.py
videosys/pipelines/cogvideox/__init__.py
videosys/pipelines/cogvideox/pipeline_cogvideox.py
videosys/pipelines/latte/__init__.py
videosys/pipelines/latte/pipeline_latte.py
videosys/pipelines/open_sora/__init__.py
videosys/pipelines/open_sora/data_process.py
videosys/pipelines/open_sora/pipeline_open_sora.py
videosys/pipelines/open_sora_plan/__init__.py
videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
videosys/schedulers/__init__.py
videosys/schedulers/scheduling_ddim_cogvideox.py
videosys/schedulers/scheduling_dpm_cogvideox.py
videosys/schedulers/scheduling_rflow_open_sora.py

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,23 @@
click
colossalai
contexttimer
diffusers==0.30.0
einops
fabric
ftfy
imageio
imageio-ffmpeg
matplotlib
ninja
numpy<2.0.0
omegaconf
packaging
psutil
pydantic
ray
rich
safetensors
timm
torch>=1.13
tqdm
transformers

View File

@@ -0,0 +1,2 @@
tests
videosys

15
videosys/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
from .core.engine import VideoSysEngine
from .core.parallel_mgr import initialize
from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
__all__ = [
"initialize",
"VideoSysEngine",
"LattePipeline", "LatteConfig", "LattePABConfig",
"OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
"OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
"CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"
] # fmt: skip

Binary file not shown.

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

420
videosys/core/comm.py Normal file
View File

@@ -0,0 +1,420 @@
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch.distributed import ProcessGroup
from videosys.core.parallel_mgr import get_sequence_parallel_size
# ======================================================
# Model
# ======================================================
def model_sharding(model: torch.nn.Module):
global_rank = dist.get_rank()
world_size = dist.get_world_size()
for _, param in model.named_parameters():
padding_size = (world_size - param.numel() % world_size) % world_size
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
else:
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // world_size)
splited_params = splited_params[global_rank]
param.data = splited_params
# ======================================================
# AllGather & ReduceScatter
# ======================================================
class AsyncAllGatherForTwo(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
weight: Tensor,
bias: Tensor,
sp_rank: int,
sp_size: int,
group: Optional[ProcessGroup] = None,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
from torch.distributed._functional_collectives import all_gather_tensor
ctx.group = group
ctx.sp_rank = sp_rank
ctx.sp_size = sp_size
# all gather inputs
all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
# compute local qkv
local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)
# remote compute
remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
# compute remote qkv
remote_qkv = F.linear(remote_inputs, weight, bias)
# concat local and remote qkv
if sp_rank == 0:
qkv = torch.cat([local_qkv, remote_qkv], dim=0)
else:
qkv = torch.cat([remote_qkv, local_qkv], dim=0)
qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
ctx.save_for_backward(inputs, weight, remote_inputs)
return qkv
@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
from torch.distributed._functional_collectives import reduce_scatter_tensor
group = ctx.group
sp_rank = ctx.sp_rank
sp_size = ctx.sp_size
inputs, weight, remote_inputs = ctx.saved_tensors
# split qkv_grad
qkv_grad = grad_outputs[0]
qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
if sp_rank == 0:
local_qkv_grad, remote_qkv_grad = qkv_grad
else:
remote_qkv_grad, local_qkv_grad = qkv_grad
# compute remote grad
remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)
# launch async reduce scatter
remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
if sp_rank == 0:
remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
else:
remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)
# compute local grad and wait for reduce scatter
local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)
# sum remote and local grad
inputs_grad = remote_inputs_grad + local_input_grad
return inputs_grad, weight_grad, bias_grad, None, None, None
class AllGather(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
assert ctx is not None or not overlap
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0), None
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
if not overlap:
dist.all_gather(buffer_list, inputs, group=group)
return outputs, None
else:
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
return outputs, handle
@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
return (
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
group: ProcessGroup,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
assert ctx is not None or not overlap
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0), None
if not inputs.is_contiguous():
inputs = inputs.contiguous()
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
if not overlap:
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs, None
else:
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
return outputs, handle
@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
# TODO: support async backward
return (
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
# ======================================================
# AlltoAll
# ======================================================
def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
class _AllToAll(torch.autograd.Function):
"""All-to-all communication.
Args:
input_: input matrix
process_group: communication group
scatter_dim: scatter dimension
gather_dim: gather dimension
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
world_size = dist.get_world_size(process_group)
return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)
@staticmethod
def backward(ctx, *grad_output):
process_group = ctx.process_group
scatter_dim = ctx.gather_dim
gather_dim = ctx.scatter_dim
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
return (return_grad, None, None, None)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
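# all_to_all_comm splits the tensor into world_size chunks along scatter_dim,
# exchanges one chunk with every rank, and concatenates the received chunks along
# gather_dim; the backward pass applies the same exchange with the two dims swapped.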
# ======================================================
# Sequence Gather & Split
# ======================================================
def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
# skip if only one rank involved
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
if world_size == 1:
return input_
if pad > 0:
pad_size = list(input_.shape)
pad_size[dim] = pad
input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)
dim_size = input_.size(dim)
assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
output = tensor_list[rank].contiguous()
return output
def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
# skip if only one rank involved
input_ = input_.contiguous()
world_size = dist.get_world_size(pg)
dist.get_rank(pg)
if world_size == 1:
return input_
# all gather
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
assert input_.device.type == "cuda"
torch.distributed.all_gather(tensor_list, input_, group=pg)
# concat
output = torch.cat(tensor_list, dim=dim)
if pad > 0:
output = output.narrow(dim, 0, output.size(dim) - pad)
return output
class _GatherForwardSplitBackward(torch.autograd.Function):
"""
Gather the input sequence.
Args:
input_: input matrix.
process_group: process group.
dim: dimension
"""
@staticmethod
def symbolic(graph, input_):
return _gather_sequence_func(input_)
@staticmethod
def forward(ctx, input_, process_group, dim, grad_scale, pad):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
ctx.pad = pad
return _gather_sequence_func(input_, process_group, dim, pad)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale == "up":
grad_output = grad_output * dist.get_world_size(ctx.process_group)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.process_group)
return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split sequence.
Args:
input_: input matrix.
process_group: process group.
dim: dimension
"""
@staticmethod
def symbolic(graph, input_):
return _split_sequence_func(input_)
@staticmethod
def forward(ctx, input_, process_group, dim, grad_scale, pad):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
ctx.pad = pad
return _split_sequence_func(input_, process_group, dim, pad)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale == "up":
grad_output = grad_output * dist.get_world_size(ctx.process_group)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.process_group)
return _gather_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)
def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
# ==============================
# Pad
# ==============================
SPTIAL_PAD = 0
TEMPORAL_PAD = 0
def set_spatial_pad(dim_size: int):
sp_size = get_sequence_parallel_size()
pad = (sp_size - (dim_size % sp_size)) % sp_size
global SPTIAL_PAD
SPTIAL_PAD = pad
def get_spatial_pad() -> int:
return SPTIAL_PAD
def set_temporal_pad(dim_size: int):
sp_size = get_sequence_parallel_size()
pad = (sp_size - (dim_size % sp_size)) % sp_size
global TEMPORAL_PAD
TEMPORAL_PAD = pad
def get_temporal_pad() -> int:
return TEMPORAL_PAD
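# The pad helpers make the spatial / temporal sequence length divisible by the
# sequence-parallel size: set_*_pad records the padding, split_sequence pads and
# shards with it, and gather_sequence trims the same amount after the all-gather.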
def all_to_all_with_pad(
input_: torch.Tensor,
process_group: dist.ProcessGroup,
scatter_dim: int = 2,
gather_dim: int = 1,
scatter_pad: int = 0,
gather_pad: int = 0,
):
if scatter_pad > 0:
pad_shape = list(input_.shape)
pad_shape[scatter_dim] = scatter_pad
pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)
assert (
input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
if gather_pad > 0:
input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
return input_

128
videosys/core/engine.py Normal file
View File

@@ -0,0 +1,128 @@
import os
from functools import partial
from typing import Any, Optional
import torch
import torch.distributed as dist
import videosys
from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port
class VideoSysEngine:
"""
Top-level engine that creates the pipeline on a driver process (and extra worker processes for multi-GPU runs); partly inspired by vLLM.
"""
def __init__(self, config):
self.config = config
self.parallel_worker_tasks = None
self._init_worker(config.pipeline_cls)
def _init_worker(self, pipeline_cls):
world_size = self.config.num_gpus
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
# contention amongst the shards
if "OMP_NUM_THREADS" not in os.environ:
os.environ["OMP_NUM_THREADS"] = "1"
# NOTE: the following two lines need adaptation for multi-node setups
assert world_size <= torch.cuda.device_count()
# change addr for multi-node
distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())
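# Rank 0 (the "driver") runs its pipeline in this process; ranks 1..world_size-1
# are spawned as daemon worker processes and receive tasks via execute_method.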
if world_size == 1:
self.workers = []
self.worker_monitor = None
else:
result_handler = ResultHandler()
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_pipeline,
pipeline_cls=pipeline_cls,
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
),
)
for rank in range(1, world_size)
]
self.worker_monitor = WorkerMonitor(self.workers, result_handler)
result_handler.start()
self.worker_monitor.start()
self.driver_worker = self._create_pipeline(
pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
)
# TODO: add more options here for pipeline, or wrap all options into config
def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method, seed=42)
pipeline = pipeline_cls(self.config)
return pipeline
def _run_workers(
self,
method: str,
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
# Start the workers first.
worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]
if async_run_tensor_parallel_workers_only:
# Just return futures
return worker_outputs
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*args, **kwargs)
# Get the results of the workers.
return [driver_worker_output] + [output.get() for output in worker_outputs]
def _driver_execute_model(self, *args, **kwargs):
return self.driver_worker.generate(*args, **kwargs)
def generate(self, *args, **kwargs):
return self._run_workers("generate", *args, **kwargs)[0]
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for result in parallel_worker_tasks:
result.get()
def save_video(self, video, output_path):
return self.driver_worker.save_video(video, output_path)
def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
worker_monitor.close()
dist.destroy_process_group()
def __del__(self):
self.shutdown()

270
videosys/core/mp_utils.py Normal file
View File

@@ -0,0 +1,270 @@
# adapted from vllm
# https://github.com/vllm-project/vllm/blob/main/vllm/executor/multiproc_worker_utils.py
import asyncio
import multiprocessing
import os
import socket
import sys
import threading
import traceback
import uuid
from dataclasses import dataclass
from multiprocessing import Queue
from multiprocessing.connection import wait
from typing import Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union
from videosys.utils.logging import create_logger
T = TypeVar("T")
_TERMINATE = "TERMINATE" # sentinel
# ANSI color codes
CYAN = "\033[1;36m"
RESET = "\033[0;0m"
JOIN_TIMEOUT_S = 2
mp_method = "spawn"  # fork can't work here
mp = multiprocessing.get_context(mp_method)
logger = create_logger()
def get_distributed_init_method(ip: str, port: int) -> str:
# Brackets are not permitted in ipv4 addresses,
# see https://github.com/python/cpython/issues/103848
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
def get_open_port() -> int:
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
@dataclass
class Result(Generic[T]):
"""Result of task dispatched to worker"""
task_id: uuid.UUID
value: Optional[T] = None
exception: Optional[BaseException] = None
class ResultFuture(threading.Event, Generic[T]):
"""Synchronous future for non-async case"""
def __init__(self):
super().__init__()
self.result: Optional[Result[T]] = None
def set_result(self, result: Result[T]):
self.result = result
self.set()
def get(self) -> T:
self.wait()
assert self.result is not None
if self.result.exception is not None:
raise self.result.exception
return self.result.value # type: ignore[return-value]
def _set_future_result(future: Union[ResultFuture, asyncio.Future], result: Result):
if isinstance(future, ResultFuture):
future.set_result(result)
return
loop = future.get_loop()
if not loop.is_closed():
if result.exception is not None:
loop.call_soon_threadsafe(future.set_exception, result.exception)
else:
loop.call_soon_threadsafe(future.set_result, result.value)
class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
def __init__(self) -> None:
super().__init__(daemon=True)
self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
def run(self):
for result in iter(self.result_queue.get, _TERMINATE):
future = self.tasks.pop(result.task_id)
_set_future_result(future, result)
# Ensure that all waiters will receive an exception
for task_id, future in self.tasks.items():
_set_future_result(future, Result(task_id=task_id, exception=ChildProcessError("worker died")))
def close(self):
self.result_queue.put(_TERMINATE)
class WorkerMonitor(threading.Thread):
"""Monitor worker status (in background thread)"""
def __init__(self, workers: List["ProcessWorkerWrapper"], result_handler: ResultHandler):
super().__init__(daemon=True)
self.workers = workers
self.result_handler = result_handler
self._close = False
def run(self) -> None:
# Blocks until any worker exits
dead_sentinels = wait([w.process.sentinel for w in self.workers])
if not self._close:
self._close = True
# Kill / cleanup all workers
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode)
# Cleanup any remaining workers
logger.info("Killing local worker processes")
for worker in self.workers:
worker.kill_worker()
# Must be done after worker task queues are all closed
self.result_handler.close()
for worker in self.workers:
worker.process.join(JOIN_TIMEOUT_S)
def close(self):
if self._close:
return
self._close = True
logger.info("Terminating local worker processes")
for worker in self.workers:
worker.terminate_worker()
# Must be done after worker task queues are all closed
self.result_handler.close()
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
"""Prepend each output line with process-specific prefix"""
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
file_write = file.write
def write_with_prefix(s: str):
if not s:
return
if file.start_new_line: # type: ignore[attr-defined]
file_write(prefix)
idx = 0
while (next_idx := s.find("\n", idx)) != -1:
next_idx += 1
file_write(s[idx:next_idx])
if next_idx == len(s):
file.start_new_line = True # type: ignore[attr-defined]
return
file_write(prefix)
idx = next_idx
file_write(s[idx:])
file.start_new_line = False # type: ignore[attr-defined]
file.start_new_line = True # type: ignore[attr-defined]
file.write = write_with_prefix # type: ignore[method-assign]
def _run_worker_process(
worker_factory: Callable[[], Any],
task_queue: Queue,
result_queue: Queue,
) -> None:
"""Worker process event loop"""
# Add process-specific prefix to stdout and stderr
process_name = mp.current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
# Initialize worker
worker = worker_factory()
del worker_factory
# Accept tasks from the engine in task_queue
# and return task output in result_queue
logger.info("Worker ready; awaiting tasks")
try:
for items in iter(task_queue.get, _TERMINATE):
output = None
exception = None
task_id, method, args, kwargs = items
try:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
except BaseException as e:
tb = traceback.format_exc()
logger.error("Exception in worker %s while processing method %s: %s, %s", process_name, method, e, tb)
exception = e
result_queue.put(Result(task_id=task_id, value=output, exception=exception))
except KeyboardInterrupt:
pass
except Exception:
logger.exception("Worker failed")
logger.info("Worker exiting")
class ProcessWorkerWrapper:
"""Local process wrapper for handling single-node multi-GPU."""
def __init__(self, result_handler: ResultHandler, worker_factory: Callable[[], Any]) -> None:
self._task_queue = mp.Queue()
self.result_queue = result_handler.result_queue
self.tasks = result_handler.tasks
self.process = mp.Process( # type: ignore[attr-defined]
target=_run_worker_process,
name="VideoSysWorkerProcess",
kwargs=dict(
worker_factory=worker_factory,
task_queue=self._task_queue,
result_queue=self.result_queue,
),
daemon=True,
)
self.process.start()
def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], method: str, args, kwargs):
task_id = uuid.uuid4()
self.tasks[task_id] = future
try:
self._task_queue.put((task_id, method, args, kwargs))
except BaseException as e:
del self.tasks[task_id]
raise ChildProcessError("worker died") from e
def execute_method(self, method: str, *args, **kwargs):
future: ResultFuture = ResultFuture()
self._enqueue_task(future, method, args, kwargs)
return future
async def execute_method_async(self, method: str, *args, **kwargs):
future = asyncio.get_running_loop().create_future()
self._enqueue_task(future, method, args, kwargs)
return await future
def terminate_worker(self):
try:
self._task_queue.put(_TERMINATE)
except ValueError:
self.process.kill()
self._task_queue.close()
def kill_worker(self):
self._task_queue.close()
self.process.kill()

233
videosys/core/pab_mgr.py Normal file
View File

@@ -0,0 +1,233 @@
from videosys.utils.logging import logger
PAB_MANAGER = None
class PABConfig:
def __init__(
self,
steps: int,
cross_broadcast: bool = False,
cross_threshold: list = None,
cross_range: int = None,
spatial_broadcast: bool = False,
spatial_threshold: list = None,
spatial_range: int = None,
temporal_broadcast: bool = False,
temporal_threshold: list = None,
temporal_range: int = None,
mlp_broadcast: bool = False,
mlp_spatial_broadcast_config: dict = None,
mlp_temporal_broadcast_config: dict = None,
):
self.steps = steps
self.cross_broadcast = cross_broadcast
self.cross_threshold = cross_threshold
self.cross_range = cross_range
self.spatial_broadcast = spatial_broadcast
self.spatial_threshold = spatial_threshold
self.spatial_range = spatial_range
self.temporal_broadcast = temporal_broadcast
self.temporal_threshold = temporal_threshold
self.temporal_range = temporal_range
self.mlp_broadcast = mlp_broadcast
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
self.mlp_temporal_outputs = {}
self.mlp_spatial_outputs = {}
class PABManager:
def __init__(self, config: PABConfig):
self.config: PABConfig = config
init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
logger.info(init_prompt)
def if_broadcast_cross(self, timestep: int, count: int):
if (
self.config.cross_broadcast
and (timestep is not None)
and (count % self.config.cross_range != 0)
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_temporal(self, timestep: int, count: int):
if (
self.config.temporal_broadcast
and (timestep is not None)
and (count % self.config.temporal_range != 0)
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
if (
self.config.spatial_broadcast
and (timestep is not None)
and (count % self.config.spatial_range != 0)
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
):
flag = True
else:
flag = False
count = (count + 1) % self.config.steps
return flag, count
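# Shared rule for the three checks above: an attention output is broadcast
# (reused from an earlier step) only when broadcasting is enabled, the running
# count is not a multiple of the corresponding *_range, and the timestep lies
# strictly inside the (low, high) threshold window.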
@staticmethod
def _is_t_in_skip_config(all_timesteps, timestep, config):
is_t_in_skip_config = False
skip_range = None
for key in config:
if key not in all_timesteps:
continue
index = all_timesteps.index(key)
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
if timestep in skip_range:
is_t_in_skip_config = True
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
break
return is_t_in_skip_config, skip_range
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
if not self.config.mlp_broadcast:
return False, None, False, None
if is_temporal:
cur_config = self.config.mlp_temporal_broadcast_config
else:
cur_config = self.config.mlp_spatial_broadcast_config
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
next_flag = False
if (
self.config.mlp_broadcast
and (timestep is not None)
and (timestep in cur_config)
and (block_idx in cur_config[timestep]["block"])
):
flag = False
next_flag = True
count = count + 1
elif (
self.config.mlp_broadcast
and (timestep is not None)
and (is_t_in_skip_config)
and (block_idx in cur_config[skip_range[0]]["block"])
):
flag = True
count = 0
else:
flag = False
return flag, count, next_flag, skip_range
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
if is_temporal:
self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
else:
self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
skip_start_t = skip_range[0]
if is_temporal:
skip_output = (
self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
if self.config.mlp_temporal_outputs is not None
else None
)
else:
skip_output = (
self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
if self.config.mlp_spatial_outputs is not None
else None
)
if skip_output is not None:
if timestep == skip_range[-1]:
# TODO: save memory
if is_temporal:
del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
else:
del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
else:
raise ValueError(
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
)
return skip_output
def get_spatial_mlp_outputs(self):
return self.config.mlp_spatial_outputs
def get_temporal_mlp_outputs(self):
return self.config.mlp_temporal_outputs
def set_pab_manager(config: PABConfig):
global PAB_MANAGER
PAB_MANAGER = PABManager(config)
def enable_pab():
if PAB_MANAGER is None:
return False
return (
PAB_MANAGER.config.cross_broadcast
or PAB_MANAGER.config.spatial_broadcast
or PAB_MANAGER.config.temporal_broadcast
)
def update_steps(steps: int):
if PAB_MANAGER is not None:
PAB_MANAGER.config.steps = steps
def if_broadcast_cross(timestep: int, count: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_cross(timestep, count)
def if_broadcast_temporal(timestep: int, count: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
if not enable_pab():
return False, count
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
if not enable_pab():
return False, count
return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
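
# --- Usage sketch (added for illustration; not part of the original commit) ---
# A transformer block typically keeps a per-block counter and a cached attention
# output, and consults the helpers above on every denoising step. `attn_fn` and the
# attribute names below are hypothetical; real integrations appear later in this
# commit (e.g. the CogVideoX transformer imports `enable_pab` / `if_broadcast_spatial`
# from this module).
class _SpatialPABSketch:
    def __init__(self, block_idx: int):
        self.block_idx = block_idx
        self.attn_count = 0  # rolling counter threaded through if_broadcast_spatial
        self.last_attn = None  # cached attention output reused on broadcast steps

    def forward(self, x, attn_fn, timestep: int):
        broadcast = False
        if enable_pab():
            broadcast, self.attn_count = if_broadcast_spatial(timestep, self.attn_count, self.block_idx)
        if broadcast and self.last_attn is not None:
            attn_out = self.last_attn  # skip attention and reuse the cached result
        else:
            attn_out = attn_fn(x)  # recompute and refresh the cache
            self.last_attn = attn_out
        return x + attn_out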

120
videosys/core/parallel_mgr.py Normal file
View File

@ -0,0 +1,120 @@
from typing import Optional
import torch
import torch.distributed as dist
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from torch.distributed import ProcessGroup
from videosys.utils.logging import init_dist_logger, logger
from videosys.utils.utils import set_seed
PARALLEL_MANAGER = None
class ParallelManager(ProcessGroupMesh):
def __init__(self, dp_size, cp_size, sp_size):
super().__init__(dp_size, cp_size, sp_size)
dp_axis, cp_axis, sp_axis = 0, 1, 2
self.dp_size = dp_size
self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis)
self.dp_rank = dist.get_rank(self.dp_group)
self.cp_size = cp_size
self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
self.cp_rank = dist.get_rank(self.cp_group)
self.sp_size = sp_size
self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
self.sp_rank = dist.get_rank(self.sp_group)
self.enable_sp = sp_size > 1
logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")
def set_parallel_manager(dp_size, cp_size, sp_size):
global PARALLEL_MANAGER
PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)
def get_data_parallel_group():
return PARALLEL_MANAGER.dp_group
def get_data_parallel_size():
return PARALLEL_MANAGER.dp_size
def get_data_parallel_rank():
return PARALLEL_MANAGER.dp_rank
def get_sequence_parallel_group():
return PARALLEL_MANAGER.sp_group
def get_sequence_parallel_size():
return PARALLEL_MANAGER.sp_size
def get_sequence_parallel_rank():
return PARALLEL_MANAGER.sp_rank
def get_cfg_parallel_group():
return PARALLEL_MANAGER.cp_group
def get_cfg_parallel_size():
return PARALLEL_MANAGER.cp_size
def enable_sequence_parallel():
if PARALLEL_MANAGER is None:
return False
return PARALLEL_MANAGER.enable_sp
def get_parallel_manager():
return PARALLEL_MANAGER
def initialize(
rank=0,
world_size=1,
init_method=None,
seed: Optional[int] = None,
sp_size: Optional[int] = None,
enable_cp: bool = False,
):
if not dist.is_initialized():
try:
dist.destroy_process_group()
except Exception:
pass
dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank)
torch.cuda.set_device(rank)
init_dist_logger()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# init sequence parallel
if sp_size is None:
sp_size = dist.get_world_size()
dp_size = 1
else:
assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size {sp_size}"
dp_size = dist.get_world_size() // sp_size
# update cfg parallel
# NOTE: enable cp parallel will be slower. disable it for now.
if False and enable_cp and sp_size % 2 == 0:
sp_size = sp_size // 2
cp_size = 2
else:
cp_size = 1
set_parallel_manager(dp_size, cp_size, sp_size)
if seed is not None:
set_seed(seed + get_data_parallel_rank())
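
# --- Usage sketch (added for illustration; not part of the original commit) ---
# A typical per-process entry point launched with `torchrun`: RANK and WORLD_SIZE
# follow the standard torchrun environment-variable convention; the actual work done
# after initialization is omitted here.
def _example_worker():
    import os

    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    initialize(rank=rank, world_size=world_size, init_method="env://", seed=42, sp_size=world_size)
    if enable_sequence_parallel():
        logger.info(f"sequence parallel size: {get_sequence_parallel_size()}")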

52
videosys/core/pipeline.py Normal file
View File

@ -0,0 +1,52 @@
import inspect
from abc import abstractmethod
from dataclasses import dataclass
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput
class VideoSysPipeline(DiffusionPipeline):
def __init__(self):
super().__init__()
@staticmethod
def set_eval_and_device(device: torch.device, *modules):
for module in modules:
module.eval()
module.to(device)
@abstractmethod
def generate(self, *args, **kwargs):
pass
def __call__(self, *args, **kwargs):
"""
In diffusers, it is a convention to call the pipeline object directly.
In VideoSys, `generate` is the preferred entry point; this wrapper forwards
`__call__` to `generate` so the diffusers-style usage still works.
"""
return self.generate(*args, **kwargs)
@classmethod
def _get_signature_keys(cls, obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
# modify: remove the config module from the expected modules
expected_modules = expected_modules - {"config"}
optional_names = list(optional_parameters)
for name in optional_names:
if name in cls._optional_components:
expected_modules.add(name)
optional_parameters.remove(name)
return expected_modules, optional_parameters
@dataclass
class VideoSysPipelineOutput(BaseOutput):
video: torch.Tensor
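
# --- Minimal subclass sketch (added for illustration; not part of the original commit) ---
# `generate` is the single abstract entry point; `__call__` above simply forwards to it
# so diffusers-style invocation keeps working. `_ToyPipeline` is a hypothetical example.
class _ToyPipeline(VideoSysPipeline):
    def generate(self, num_frames: int = 8, height: int = 16, width: int = 16, **kwargs):
        video = torch.zeros(num_frames, 3, height, width)
        return VideoSysPipelineOutput(video=video)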

39
videosys/core/shardformer/t5/modeling.py Normal file
View File

@ -0,0 +1,39 @@
import torch
import torch.nn as nn
class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
# T5 uses a layer norm that only scales and does not shift, also known as Root Mean
# Square Layer Normalization (https://arxiv.org/abs/1910.07467), so the variance is
# computed without subtracting the mean and there is no bias. The accumulation for
# half-precision inputs is done in fp32.
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
@staticmethod
def from_native_module(module, *args, **kwargs):
assert module.__class__.__name__ == "FusedRMSNorm", (
"Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm."
"Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
)
layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
layer_norm.weight.data.copy_(module.weight.data)
layer_norm = layer_norm.to(module.weight.device)
return layer_norm
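
# --- Numerical sketch (added for illustration; not part of the original commit) ---
# T5LayerNorm only rescales by the root mean square of the last dimension (no mean
# subtraction, no bias), so with unit weights it matches the manual computation below.
def _t5_layernorm_check(hidden_size: int = 8):
    x = torch.randn(2, 4, hidden_size)
    norm = T5LayerNorm(hidden_size)
    manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
    assert torch.allclose(norm(x), manual, atol=1e-6)
    return norm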

View File

@ -0,0 +1,68 @@
from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
class T5EncoderPolicy(Policy):
def config_sanity_check(self):
assert not self.shard_config.enable_tensor_parallelism
assert not self.shard_config.enable_flash_attention
def preprocess(self):
return self.model
def module_policy(self):
from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack
policy = {}
# check whether apex is installed
try:
from apex.normalization import FusedRMSNorm # noqa
from videosys.core.shardformer.t5.modeling import T5LayerNorm
# recover hf from fused rms norm to T5 norm which is faster
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=T5LayerNorm,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm),
policy=policy,
target_key=T5LayerSelfAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm),
policy=policy,
target_key=T5Stack,
)
except (ImportError, ModuleNotFoundError):
pass
# use jit operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_T5_layer_ff_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_method_replacement(
description={
"forward": get_T5_layer_self_attention_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerSelfAttention,
)
return policy
def postprocess(self):
return self.model
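
# --- Usage sketch (added for illustration; not part of the original commit) ---
# Assumed Colossal-AI ShardFormer API: build a ShardConfig that satisfies
# `config_sanity_check` above (no tensor parallelism, no flash attention) and let
# ShardFormer rewrite a HuggingFace T5 encoder with this policy.
def _shard_t5_encoder_sketch(t5_encoder):
    from colossalai.shardformer import ShardConfig, ShardFormer

    shard_config = ShardConfig(
        enable_tensor_parallelism=False,
        enable_flash_attention=False,
        enable_jit_fused=True,
    )
    sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(t5_encoder, policy=T5EncoderPolicy())
    return sharded_model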

View File

Binary file not shown.

View File

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,758 @@
# Adapted from OpenSora
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# OpenSora: https://github.com/hpcaitech/Open-Sora
# --------------------------------------------------------
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange
from transformers import PretrainedConfig, PreTrainedModel
class DiagonalGaussianDistribution(object):
def __init__(
self,
parameters,
deterministic=False,
):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device, dtype=self.mean.dtype)
def sample(self):
# torch.randn: standard normal distribution
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None: # SCH: assumes other is a standard normal distribution
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3, 4],
)
def nll(self, sample, dims=[1, 2, 3, 4]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
def mode(self):
return self.mean
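
# --- Usage sketch (added for illustration; not part of the original commit) ---
# `parameters` packs mean and logvar along dim=1, so a latent with C channels needs a
# (B, 2*C, T, H, W) tensor; zero mean and zero logvar give a standard normal whose KL
# against the standard normal prior is zero.
def _gaussian_demo():
    params = torch.zeros(1, 8, 2, 4, 4)  # mean = 0, logvar = 0 -> std = 1
    posterior = DiagonalGaussianDistribution(params)
    z = posterior.sample()  # (1, 4, 2, 4, 4) latent draw
    kl = posterior.kl()  # zeros, one value per batch element
    return z, kl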
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def pad_at_dim(t, pad, dim=-1):
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = (0, 0) * dims_from_right
return F.pad(t, (*zeros, *pad), mode="constant")
def exists(v):
return v is not None
class CausalConv3d(nn.Module):
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode="constant",
strides=None, # allow custom stride
**kwargs,
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop("dilation", 1)
stride = strides[0] if strides is not None else kwargs.pop("stride", 1)
self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
stride = strides if strides is not None else (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
x = self.conv(x)
return x
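
# --- Shape sketch (added for illustration; not part of the original commit) ---
# The asymmetric temporal padding above only looks backwards in time, so with stride 1
# the frame count is preserved: (B, C, T, H, W) -> (B, C_out, T, H, W).
def _causal_conv_demo():
    conv = CausalConv3d(3, 8, kernel_size=3)
    x = torch.randn(1, 3, 5, 16, 16)
    y = conv(x)
    assert y.shape == (1, 8, 5, 16, 16)
    return y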
class ResBlock(nn.Module):
def __init__(
self,
in_channels, # SCH: added
filters,
conv_fn,
activation_fn=nn.SiLU,
use_conv_shortcut=False,
num_groups=32,
):
super().__init__()
self.in_channels = in_channels
self.filters = filters
self.activate = activation_fn()
self.use_conv_shortcut = use_conv_shortcut
# SCH: MAGVIT uses GroupNorm by default
self.norm1 = nn.GroupNorm(num_groups, in_channels)
self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
self.norm2 = nn.GroupNorm(num_groups, self.filters)
self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
if in_channels != filters:
if self.use_conv_shortcut:
self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
else:
self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)
def forward(self, x):
residual = x
x = self.norm1(x)
x = self.activate(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.activate(x)
x = self.conv2(x)
if self.in_channels != self.filters: # SCH: ResBlock X->Y
residual = self.conv3(residual)
return x + residual
def get_activation_fn(activation):
if activation == "relu":
activation_fn = nn.ReLU
elif activation == "swish":
activation_fn = nn.SiLU
else:
raise NotImplementedError
return activation_fn
class Encoder(nn.Module):
"""Encoder Blocks."""
def __init__(
self,
in_out_channels=4,
latent_embed_dim=512, # num channels for latent vector
filters=128,
num_res_blocks=4,
channel_multipliers=(1, 2, 2, 4),
temporal_downsample=(False, True, True),
num_groups=32, # for nn.GroupNorm
activation_fn="swish",
):
super().__init__()
self.filters = filters
self.num_res_blocks = num_res_blocks
self.num_blocks = len(channel_multipliers)
self.channel_multipliers = channel_multipliers
self.temporal_downsample = temporal_downsample
self.num_groups = num_groups
self.embedding_dim = latent_embed_dim
self.activation_fn = get_activation_fn(activation_fn)
self.activate = self.activation_fn()
self.conv_fn = CausalConv3d
self.block_args = dict(
conv_fn=self.conv_fn,
activation_fn=self.activation_fn,
use_conv_shortcut=False,
num_groups=self.num_groups,
)
# first layer conv
self.conv_in = self.conv_fn(
in_out_channels,
filters,
kernel_size=(3, 3, 3),
bias=False,
)
# ResBlocks and conv downsample
self.block_res_blocks = nn.ModuleList([])
self.conv_blocks = nn.ModuleList([])
filters = self.filters
prev_filters = filters # record for in_channels
for i in range(self.num_blocks):
filters = self.filters * self.channel_multipliers[i]
block_items = nn.ModuleList([])
for _ in range(self.num_res_blocks):
block_items.append(ResBlock(prev_filters, filters, **self.block_args))
prev_filters = filters # update in_channels
self.block_res_blocks.append(block_items)
if i < self.num_blocks - 1:
if self.temporal_downsample[i]:
t_stride = 2 if self.temporal_downsample[i] else 1
s_stride = 1
self.conv_blocks.append(
self.conv_fn(
prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride)
)
)
prev_filters = filters # update in_channels
else:
# no temporal downsampling at this stage: append an identity so conv_blocks stays index-aligned with the blocks
self.conv_blocks.append(nn.Identity(prev_filters)) # Identity
prev_filters = filters # update in_channels
# last layer res block
self.res_blocks = nn.ModuleList([])
for _ in range(self.num_res_blocks):
self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args))
prev_filters = filters # update in_channels
# MAGVIT uses Group Normalization
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same")
def forward(self, x):
x = self.conv_in(x)
for i in range(self.num_blocks):
for j in range(self.num_res_blocks):
x = self.block_res_blocks[i][j](x)
if i < self.num_blocks - 1:
x = self.conv_blocks[i](x)
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
x = self.norm1(x)
x = self.activate(x)
x = self.conv2(x)
return x
class Decoder(nn.Module):
"""Decoder Blocks."""
def __init__(
self,
in_out_channels=4,
latent_embed_dim=512,
filters=128,
num_res_blocks=4,
channel_multipliers=(1, 2, 2, 4),
temporal_downsample=(False, True, True),
num_groups=32, # for nn.GroupNorm
activation_fn="swish",
):
super().__init__()
self.filters = filters
self.num_res_blocks = num_res_blocks
self.num_blocks = len(channel_multipliers)
self.channel_multipliers = channel_multipliers
self.temporal_downsample = temporal_downsample
self.num_groups = num_groups
self.embedding_dim = latent_embed_dim
self.s_stride = 1
self.activation_fn = get_activation_fn(activation_fn)
self.activate = self.activation_fn()
self.conv_fn = CausalConv3d
self.block_args = dict(
conv_fn=self.conv_fn,
activation_fn=self.activation_fn,
use_conv_shortcut=False,
num_groups=self.num_groups,
)
filters = self.filters * self.channel_multipliers[-1]
prev_filters = filters
# last conv
self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True)
# last layer res block
self.res_blocks = nn.ModuleList([])
for _ in range(self.num_res_blocks):
self.res_blocks.append(ResBlock(filters, filters, **self.block_args))
# ResBlocks and conv upsample
self.block_res_blocks = nn.ModuleList([])
self.num_blocks = len(self.channel_multipliers)
self.conv_blocks = nn.ModuleList([])
# reverse to keep track of the in_channels, but append also in a reverse direction
for i in reversed(range(self.num_blocks)):
filters = self.filters * self.channel_multipliers[i]
# resblock handling
block_items = nn.ModuleList([])
for _ in range(self.num_res_blocks):
block_items.append(ResBlock(prev_filters, filters, **self.block_args))
prev_filters = filters # SCH: update in_channels
self.block_res_blocks.insert(0, block_items) # SCH: append in front
# conv blocks with upsampling
if i > 0:
if self.temporal_downsample[i - 1]:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
self.conv_blocks.insert(
0,
self.conv_fn(
prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3)
),
)
else:
self.conv_blocks.insert(
0,
nn.Identity(prev_filters),
)
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
self.conv_out = self.conv_fn(filters, in_out_channels, 3)
def forward(self, x):
x = self.conv1(x)
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
for i in reversed(range(self.num_blocks)):
for j in range(self.num_res_blocks):
x = self.block_res_blocks[i][j](x)
if i > 0:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
x = self.conv_blocks[i - 1](x)
x = rearrange(
x,
"B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)",
ts=t_stride,
hs=self.s_stride,
ws=self.s_stride,
)
x = self.norm1(x)
x = self.activate(x)
x = self.conv_out(x)
return x
class VAE_Temporal(nn.Module):
def __init__(
self,
in_out_channels=4,
latent_embed_dim=4,
embed_dim=4,
filters=128,
num_res_blocks=4,
channel_multipliers=(1, 2, 2, 4),
temporal_downsample=(True, True, False),
num_groups=32, # for nn.GroupNorm
activation_fn="swish",
):
super().__init__()
self.time_downsample_factor = 2 ** sum(temporal_downsample)
# self.time_padding = self.time_downsample_factor - 1
self.patch_size = (self.time_downsample_factor, 1, 1)
self.out_channels = in_out_channels
# NOTE: following MAGVIT, conv in bias=False in encoder first conv
self.encoder = Encoder(
in_out_channels=in_out_channels,
latent_embed_dim=latent_embed_dim * 2,
filters=filters,
num_res_blocks=num_res_blocks,
channel_multipliers=channel_multipliers,
temporal_downsample=temporal_downsample,
num_groups=num_groups, # for nn.GroupNorm
activation_fn=activation_fn,
)
self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)
self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
self.decoder = Decoder(
in_out_channels=in_out_channels,
latent_embed_dim=latent_embed_dim,
filters=filters,
num_res_blocks=num_res_blocks,
channel_multipliers=channel_multipliers,
temporal_downsample=temporal_downsample,
num_groups=num_groups, # for nn.GroupNorm
activation_fn=activation_fn,
)
def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
if input_size[i] is None:
lsize = None
elif i == 0:
time_padding = (
0
if (input_size[i] % self.time_downsample_factor == 0)
else self.time_downsample_factor - input_size[i] % self.time_downsample_factor
)
lsize = (input_size[i] + time_padding) // self.patch_size[i]
else:
lsize = input_size[i] // self.patch_size[i]
latent_size.append(lsize)
return latent_size
def encode(self, x):
time_padding = (
0
if (x.shape[2] % self.time_downsample_factor == 0)
else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor
)
x = pad_at_dim(x, (time_padding, 0), dim=2)
encoded_feature = self.encoder(x)
moments = self.quant_conv(encoded_feature).to(x.dtype)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z, num_frames=None):
time_padding = (
0
if (num_frames % self.time_downsample_factor == 0)
else self.time_downsample_factor - num_frames % self.time_downsample_factor
)
z = self.post_quant_conv(z)
x = self.decoder(z)
x = x[:, :, time_padding:]
return x
def forward(self, x, sample_posterior=True):
posterior = self.encode(x)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
recon_video = self.decode(z, num_frames=x.shape[2])
return recon_video, posterior, z
def VAE_Temporal_SD(**kwargs):
model = VAE_Temporal(
in_out_channels=4,
latent_embed_dim=4,
embed_dim=4,
filters=128,
num_res_blocks=4,
channel_multipliers=(1, 2, 2, 4),
temporal_downsample=(False, True, True),
**kwargs,
)
return model
class VideoAutoencoderKL(nn.Module):
def __init__(
self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None
):
super().__init__()
self.module = AutoencoderKL.from_pretrained(
from_pretrained,
cache_dir=cache_dir,
local_files_only=local_files_only,
subfolder=subfolder,
)
self.out_channels = self.module.config.latent_channels
self.patch_size = (1, 8, 8)
self.micro_batch_size = micro_batch_size
def encode(self, x):
# x: (B, C, T, H, W)
B = x.shape[0]
x = rearrange(x, "B C T H W -> (B T) C H W")
if self.micro_batch_size is None:
x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
else:
# NOTE: cannot be used for training
bs = self.micro_batch_size
x_out = []
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
x_out.append(x_bs)
x = torch.cat(x_out, dim=0)
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
return x
def decode(self, x, **kwargs):
# x: (B, C, T, H, W)
B = x.shape[0]
x = rearrange(x, "B C T H W -> (B T) C H W")
if self.micro_batch_size is None:
x = self.module.decode(x / 0.18215).sample
else:
# NOTE: cannot be used for training
bs = self.micro_batch_size
x_out = []
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.decode(x_bs / 0.18215).sample
x_out.append(x_bs)
x = torch.cat(x_out, dim=0)
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
return x
def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
# assert (
# input_size[i] is None or input_size[i] % self.patch_size[i] == 0
# ), "Input size must be divisible by patch size"
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
return latent_size
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class VideoAutoencoderKLTemporalDecoder(nn.Module):
def __init__(self, from_pretrained=None, cache_dir=None, local_files_only=False):
super().__init__()
self.module = AutoencoderKLTemporalDecoder.from_pretrained(
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
)
self.out_channels = self.module.config.latent_channels
self.patch_size = (1, 8, 8)
def encode(self, x):
raise NotImplementedError
def decode(self, x, **kwargs):
B, _, T = x.shape[:3]
x = rearrange(x, "B C T H W -> (B T) C H W")
x = self.module.decode(x / 0.18215, num_frames=T).sample
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
return x
def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
# assert (
# input_size[i] is None or input_size[i] % self.patch_size[i] == 0
# ), "Input size must be divisible by patch size"
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
return latent_size
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class VideoAutoencoderPipelineConfig(PretrainedConfig):
model_type = "VideoAutoencoderPipeline"
def __init__(
self,
vae_2d=None,
vae_temporal=None,
from_pretrained=None,
freeze_vae_2d=False,
cal_loss=False,
micro_frame_size=None,
shift=0.0,
scale=1.0,
**kwargs,
):
self.vae_2d = vae_2d
self.vae_temporal = vae_temporal
self.from_pretrained = from_pretrained
self.freeze_vae_2d = freeze_vae_2d
self.cal_loss = cal_loss
self.micro_frame_size = micro_frame_size
self.shift = shift
self.scale = scale
super().__init__(**kwargs)
class VideoAutoencoderPipeline(PreTrainedModel):
config_class = VideoAutoencoderPipelineConfig
def __init__(self, config: VideoAutoencoderPipelineConfig):
super().__init__(config=config)
self.spatial_vae = VideoAutoencoderKL(
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
local_files_only=False,
micro_batch_size=4,
subfolder="vae",
)
self.temporal_vae = VAE_Temporal_SD()
self.cal_loss = config.cal_loss
self.micro_frame_size = config.micro_frame_size
self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
if config.freeze_vae_2d:
for param in self.spatial_vae.parameters():
param.requires_grad = False
self.out_channels = self.temporal_vae.out_channels
# normalization parameters
scale = torch.tensor(config.scale)
shift = torch.tensor(config.shift)
if len(scale.shape) > 0:
scale = scale[None, :, None, None, None]
if len(shift.shape) > 0:
shift = shift[None, :, None, None, None]
self.register_buffer("scale", scale)
self.register_buffer("shift", shift)
def encode(self, x):
x_z = self.spatial_vae.encode(x)
if self.micro_frame_size is None:
posterior = self.temporal_vae.encode(x_z)
z = posterior.sample()
else:
z_list = []
for i in range(0, x_z.shape[2], self.micro_frame_size):
x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
posterior = self.temporal_vae.encode(x_z_bs)
z_list.append(posterior.sample())
z = torch.cat(z_list, dim=2)
if self.cal_loss:
return z, posterior, x_z
else:
return (z - self.shift) / self.scale
def decode(self, z, num_frames=None):
device = z.device
self.scale = self.scale.to(device)
self.shift = self.shift.to(device)
if not self.cal_loss:
z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)
if self.micro_frame_size is None:
x_z = self.temporal_vae.decode(z, num_frames=num_frames)
x = self.spatial_vae.decode(x_z)
else:
x_z_list = []
for i in range(0, z.size(2), self.micro_z_frame_size):
z_bs = z[:, :, i : i + self.micro_z_frame_size]
x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames))
x_z_list.append(x_z_bs)
num_frames -= self.micro_frame_size
x_z = torch.cat(x_z_list, dim=2)
x = self.spatial_vae.decode(x_z)
if self.cal_loss:
return x, x_z
else:
return x
def forward(self, x):
assert self.cal_loss, "This method is only available when cal_loss is True"
z, posterior, x_z = self.encode(x)
x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
return x_rec, x_z_rec, z, posterior, x_z
def get_latent_size(self, input_size):
if self.micro_frame_size is None or input_size[0] is None:
return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
else:
sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]]
sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size))
sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size)
remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None]
if remain_temporal_size[0] > 0:
remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
sub_latent_size[0] += remain_size[0]
return sub_latent_size
def get_temporal_last_layer(self):
return self.temporal_vae.decoder.conv_out.conv.weight
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def OpenSoraVAE_V1_2(
micro_batch_size=4,
micro_frame_size=17,
from_pretrained=None,
freeze_vae_2d=False,
cal_loss=False,
):
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=micro_batch_size,
)
vae_temporal = dict(
type="VAE_Temporal_SD",
from_pretrained=None,
)
shift = (-0.10, 0.34, 0.27, 0.98)
scale = (3.85, 2.32, 2.33, 3.06)
kwargs = dict(
vae_2d=vae_2d,
vae_temporal=vae_temporal,
freeze_vae_2d=freeze_vae_2d,
cal_loss=cal_loss,
micro_frame_size=micro_frame_size,
shift=shift,
scale=scale,
)
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
return model
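
# --- Latent-size sketch (added for illustration; not part of the original commit) ---
# With temporal_downsample=(False, True, True) the temporal VAE has a time
# downsampling factor of 4, so 17 frames are padded to 20 and compressed to 5 latent
# frames, while height and width are untouched by this module.
def _temporal_latent_size_demo():
    vae = VAE_Temporal_SD()
    assert vae.get_latent_size([17, 32, 32]) == [5, 32, 32]
    return vae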

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,3 @@
import torch.nn as nn
approx_gelu = lambda: nn.GELU(approximate="tanh")

View File

@ -0,0 +1,205 @@
from dataclasses import dataclass
from typing import Iterable, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from videosys.models.modules.normalization import LlamaRMSNorm
class OpenSoraAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = LlamaRMSNorm,
enable_flash_attn: bool = False,
rope=None,
qk_norm_legacy: bool = False,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flash_attn = enable_flash_attn
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.qk_norm_legacy = qk_norm_legacy
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = False
if rope is not None:
self.rope = True
self.rotary_emb = rope
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
# flash attn is not memory efficient for small sequences, this is empirical
enable_flash_attn = self.enable_flash_attn and (N > B)
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
if self.qk_norm_legacy:
# WARNING: this may be a bug
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
q, k = self.q_norm(q), self.k_norm(k)
else:
q, k = self.q_norm(q), self.k_norm(k)
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
if enable_flash_attn:
from flash_attn import flash_attn_func
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
x = flash_attn_func(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
else:
x = F.scaled_dot_product_attention(q, k, v)
x_output_shape = (B, N, C)
if not enable_flash_attn:
x = x.transpose(1, 2)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
return x
class OpenSoraMultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flash_attn=False):
super(OpenSoraMultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
self.enable_flash_attn = enable_flash_attn
def forward(self, x, cond, mask=None):
# query: image tokens; key/value: condition (text) tokens; mask: valid lengths of the padded condition sequence
B, N, C = x.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
if self.enable_flash_attn:
x = self.flash_attn_impl(q, k, v, mask, B, N, C)
else:
x = self.torch_impl(q, k, v, mask, B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def flash_attn_impl(self, q, k, v, mask, B, N, C):
from flash_attn import flash_attn_varlen_func
q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
k_seqinfo = _SeqLenInfo.from_seqlens(mask)
x = flash_attn_varlen_func(
q.view(-1, self.num_heads, self.head_dim),
k.view(-1, self.num_heads, self.head_dim),
v.view(-1, self.num_heads, self.head_dim),
cu_seqlens_q=q_seqinfo.seqstart.cuda(),
cu_seqlens_k=k_seqinfo.seqstart.cuda(),
max_seqlen_q=q_seqinfo.max_seqlen,
max_seqlen_k=k_seqinfo.max_seqlen,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
x = x.view(B, N, C)
return x
def torch_impl(self, q, k, v, mask, B, N, C):
q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
# Boolean SDPA mask: True = attend, False = masked. Assigning -1e9 into a bool
# tensor stores True, so the first m (valid) condition tokens are attended and the
# padded tail stays masked out.
attn_mask = torch.zeros(B, 1, N, k.shape[2], dtype=torch.bool, device=q.device)
for i, m in enumerate(mask):
attn_mask[i, :, :, :m] = -1e9
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = out.transpose(1, 2).contiguous().view(B, N, C)
return x
@dataclass
class _SeqLenInfo:
"""
from xformers
(Internal) Represents the division of a dimension into blocks.
For example, to represents a dimension of length 7 divided into
three blocks of lengths 2, 3 and 2, use `from_seqlens([2, 3, 2])`.
The members will be:
max_seqlen: 3
min_seqlen: 2
seqstart_py: [0, 2, 5, 7]
seqstart: torch.IntTensor([0, 2, 5, 7])
"""
seqstart: torch.Tensor
max_seqlen: int
min_seqlen: int
seqstart_py: List[int]
def to(self, device: torch.device) -> None:
self.seqstart = self.seqstart.to(device, non_blocking=True)
def intervals(self) -> Iterable[Tuple[int, int]]:
yield from zip(self.seqstart_py, self.seqstart_py[1:])
@classmethod
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
"""
`seqlens` is an iterable of per-sample sequence lengths (one int per block).
"""
assert not isinstance(seqlens, torch.Tensor)
seqstart_py = [0]
max_seqlen = -1
min_seqlen = -1
for seqlen in seqlens:
min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
max_seqlen = max(max_seqlen, seqlen)
seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
return cls(
max_seqlen=max_seqlen,
min_seqlen=min_seqlen,
seqstart=seqstart,
seqstart_py=seqstart_py,
)
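
# --- Worked example (added for illustration; not part of the original commit) ---
# Matches the _SeqLenInfo docstring above: blocks of lengths 2, 3 and 2 give the
# cumulative offsets [0, 2, 5, 7] that flash_attn_varlen_func expects.
def _seqleninfo_demo():
    info = _SeqLenInfo.from_seqlens([2, 3, 2])
    assert info.seqstart_py == [0, 2, 5, 7]
    assert (info.max_seqlen, info.min_seqlen) == (3, 2)
    return list(info.intervals())  # [(0, 2), (2, 5), (5, 7)]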

View File

@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class CogVideoXDownsample3D(nn.Module):
# TODO: wait for the paper release.
r"""
A 3D downsampling layer used in CogVideoX by Tsinghua University & ZhipuAI.
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `2`):
Stride of the convolution.
padding (`int`, defaults to `0`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
Whether or not to compress the time dimension.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
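
# --- Shape sketch (added for illustration; not part of the original commit) ---
# With compress_time=True an odd frame count keeps the first frame and average-pools
# the rest, so 9 frames become 5, while the strided conv halves height and width.
def _downsample_demo():
    down = CogVideoXDownsample3D(4, 8, compress_time=True)
    x = torch.randn(1, 4, 9, 16, 16)
    y = down(x)
    assert y.shape == (1, 8, 5, 8, 8)
    return y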

412
videosys/models/modules/embeddings.py Normal file
View File

@ -0,0 +1,412 @@
import functools
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from timm.models.vision_transformer import Mlp
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
return embeds
class OpenSoraPatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
norm_layer=None,
flatten=True,
):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
freqs = freqs.to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
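
# --- Usage sketch (added for illustration; not part of the original commit) ---
# A batch of scalar timesteps becomes an (N, frequency_embedding_size) sinusoidal
# table, which the MLP projects to (N, hidden_size).
def _timestep_embedder_demo():
    embedder = TimestepEmbedder(hidden_size=64, frequency_embedding_size=16)
    t = torch.tensor([0.0, 250.0, 999.0])
    table = TimestepEmbedder.timestep_embedding(t, 16)  # (3, 16)
    emb = embedder(t, dtype=torch.float32)  # (3, 64)
    assert table.shape == (3, 16) and emb.shape == (3, 64)
    return emb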
class SizeEmbedder(TimestepEmbedder):
"""
Embeds scalar sizes (e.g. image height and width) into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
self.outdim = hidden_size
def forward(self, s, bs):
if s.ndim == 1:
s = s[:, None]
assert s.ndim == 2
if s.shape[0] != bs:
s = s.repeat(bs // s.shape[0], 1)
assert s.shape[0] == bs
b, dims = s.shape[0], s.shape[1]
s = rearrange(s, "b d -> (b d)")
s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
s_emb = self.mlp(s_freq)
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
@property
def dtype(self):
return next(self.parameters()).dtype
class OpenSoraCaptionEmbedder(nn.Module):
"""
Projects caption (text) embeddings into the model dimension. Also handles caption dropout for classifier-free guidance.
"""
def __init__(
self,
in_channels,
hidden_size,
uncond_prob,
act_layer=nn.GELU(approximate="tanh"),
token_num=120,
):
super().__init__()
self.y_proj = Mlp(
in_features=in_channels,
hidden_features=hidden_size,
out_features=hidden_size,
act_layer=act_layer,
drop=0,
)
self.register_buffer(
"y_embedding",
torch.randn(token_num, in_channels) / in_channels**0.5,
)
self.uncond_prob = uncond_prob
def token_drop(self, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
def forward(self, caption, train, force_drop_ids=None):
if train:
assert caption.shape[2:] == self.y_embedding.shape
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
return caption
class OpenSoraPositionEmbedding2D(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
assert dim % 4 == 0, "dim must be divisible by 4"
half_dim = dim // 2
inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def _get_sin_cos_emb(self, t: torch.Tensor):
out = torch.einsum("i,d->id", t, self.inv_freq)
emb_cos = torch.cos(out)
emb_sin = torch.sin(out)
return torch.cat((emb_sin, emb_cos), dim=-1)
@functools.lru_cache(maxsize=512)
def _get_cached_emb(
self,
device: torch.device,
dtype: torch.dtype,
h: int,
w: int,
scale: float = 1.0,
base_size: Optional[int] = None,
):
grid_h = torch.arange(h, device=device) / scale
grid_w = torch.arange(w, device=device) / scale
if base_size is not None:
grid_h *= base_size / h
grid_w *= base_size / w
grid_h, grid_w = torch.meshgrid(
grid_w,
grid_h,
indexing="ij",
) # here w goes first
grid_h = grid_h.t().reshape(-1)
grid_w = grid_w.t().reshape(-1)
emb_h = self._get_sin_cos_emb(grid_h)
emb_w = self._get_sin_cos_emb(grid_w)
return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype)
def forward(
self,
x: torch.Tensor,
h: int,
w: int,
scale: Optional[float] = 1.0,
base_size: Optional[int] = None,
) -> torch.Tensor:
return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor` or `Tuple[torch.Tensor, torch.Tensor]`: positional embedding (cosine and sine parts when `use_real` is True) with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings to, with shape [B, H, S, D].
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensors for complex exponentials, each of shape [S, D].
Returns:
`torch.Tensor`: The input tensor with rotary embeddings applied, same shape as `x`.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Used, for example, in Lumina
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used, for example, in Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
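
# --- Shape sketch (added for illustration; not part of the original commit) ---
# get_3d_rotary_pos_embed returns per-position cosine/sine tables of size
# (T * H * W, embed_dim); apply_rotary_emb then rotates a (B, heads, S, head_dim)
# query or key without changing its shape.
def _rope_3d_demo():
    cos, sin = get_3d_rotary_pos_embed(
        embed_dim=64, crops_coords=((0, 0), (8, 8)), grid_size=(4, 4), temporal_size=2
    )
    q = torch.randn(1, 8, 2 * 4 * 4, 64)
    q_rot = apply_rotary_emb(q, (cos, sin))
    assert cos.shape == (32, 64) and q_rot.shape == q.shape
    return q_rot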

102
videosys/models/modules/normalization.py Normal file
View File

@ -0,0 +1,102 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,
conditioning_dim: int,
embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
class AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
output_dim (`int`, *optional*): Output dimension; defaults to `2 * embedding_dim`.
norm_elementwise_affine (`bool`, defaults to `False`): Whether the layer norm uses learnable affine parameters.
norm_eps (`float`, defaults to `1e-5`): Epsilon used by the layer norm.
chunk_dim (`int`, defaults to `0`): Dimension along which the projected conditioning is split into scale and shift.
"""
def __init__(
self,
embedding_dim: int,
num_embeddings: Optional[int] = None,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = output_dim or embedding_dim * 2
if num_embeddings is not None:
self.emb = nn.Embedding(num_embeddings, embedding_dim)
else:
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.emb is not None:
temb = self.emb(timestep)
temb = self.linear(self.silu(temb))
if self.chunk_dim == 1:
# Note the chunk order: "shift, scale" here versus "scale, shift" in the other
# branch. This branch is specific to CogVideoX for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
else:
scale, shift = temb.chunk(2, dim=0)
x = self.norm(x) * (1 + scale) + shift
return x
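
# --- Shape sketch (added for illustration; not part of the original commit) ---
# CogVideoXLayerNormZero projects the conditioning vector into six chunks: shift/scale
# modulate the normalized video and text streams, and the two gates are returned for
# the residual connections.
def _layernorm_zero_demo():
    norm = CogVideoXLayerNormZero(conditioning_dim=32, embedding_dim=16)
    hidden = torch.randn(2, 10, 16)
    text = torch.randn(2, 5, 16)
    temb = torch.randn(2, 32)
    h, t, gate, enc_gate = norm(hidden, text, temb)
    assert h.shape == hidden.shape and t.shape == text.shape and gate.shape == (2, 1, 16)
    return h, t, gate, enc_gate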

View File

@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class CogVideoXUpsample3D(nn.Module):
r"""
A 3D upsampling layer used in CogVideoX by Tsinghua University & ZhipuAI. # TODO: wait for the paper release.
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `1`):
Stride of the convolution.
padding (`int`, defaults to `1`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
Whether or not to compress the time dimension.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = F.interpolate(x_first, scale_factor=2.0)
x_rest = F.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = F.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs
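
# --- Shape sketch (added for illustration; not part of the original commit) ---
# With compress_time=False only the spatial dimensions are upsampled: height and
# width double while the frame count stays the same.
def _upsample_demo():
    up = CogVideoXUpsample3D(4, 4, compress_time=False)
    x = torch.randn(1, 4, 2, 8, 8)
    y = up(x)
    assert y.shape == (1, 4, 2, 16, 16)
    return y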

View File

View File

@ -0,0 +1,591 @@
# Adapted from CogVideo
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# CogVideo: https://github.com/THUDM/CogVideo
# diffusers: https://github.com/huggingface/diffusers
# --------------------------------------------------------
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn
from videosys.core.comm import all_to_all_comm, gather_sequence, get_spatial_pad, set_spatial_pad, split_sequence
from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_group,
get_cfg_parallel_size,
get_sequence_parallel_group,
get_sequence_parallel_size,
)
from videosys.models.modules.embeddings import apply_rotary_emb
from videosys.utils.utils import batch_func
from ..modules.embeddings import CogVideoXPatchEmbed
from ..modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
if enable_sequence_parallel():
assert (
attn.heads % get_sequence_parallel_size() == 0
), f"Number of heads {attn.heads} must be divisible by sequence parallel size {get_sequence_parallel_size()}"
attn_heads = attn.heads // get_sequence_parallel_size()
query, key, value = map(
lambda x: all_to_all_comm(x, get_sequence_parallel_group(), scatter_dim=2, gather_dim=1),
[query, key, value],
)
else:
attn_heads = attn.heads
inner_dim = key.shape[-1]
head_dim = inner_dim // attn_heads
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
emb_len = image_rotary_emb[0].shape[0]
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
if not attn.is_cross_attention:
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
if enable_sequence_parallel():
hidden_states = all_to_all_comm(hidden_states, get_sequence_parallel_group(), scatter_dim=1, gather_dim=2)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
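
A framework-free sketch of the joint text-video attention pattern used by the processor above: text tokens are prepended to the video tokens, attention runs once over the combined sequence, and the result is split back into the two streams. Head counts and sequence lengths below are illustrative.

```python
import torch
import torch.nn.functional as F

batch, heads, head_dim = 1, 4, 32
text_len, video_len = 8, 64

text = torch.randn(batch, text_len, heads * head_dim)
video = torch.randn(batch, video_len, heads * head_dim)

# text tokens are prepended so both modalities attend to each other in one pass
h = torch.cat([text, video], dim=1)
q = k = v = h.view(batch, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

# split the joint sequence back into its text and video parts
text_out, video_out = out.split([text_len, video_len], dim=1)
print(text_out.shape, video_out.shape)  # (1, 8, 128) (1, 64, 128)
```
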
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
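
The fused variant differs only in the projection: a single `to_qkv` matmul whose output is split into three equal chunks, as in this small sketch (sizes are illustrative).

```python
import torch
import torch.nn as nn

# illustrative fused-QKV projection: one matmul, then split into three equal chunks
dim = 128
to_qkv = nn.Linear(dim, 3 * dim)

h = torch.randn(1, 72, dim)
qkv = to_qkv(h)
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
print(q.shape, k.shape, v.shape)  # each torch.Size([1, 72, 128])
```
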
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool`, defaults to `True`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
block_idx: int = 0,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
)
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# pab
self.attn_count = 0
self.last_attn = None
self.block_idx = block_idx
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep=None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
if enable_pab():
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
if enable_pab() and broadcast_attn:
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if enable_pab():
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
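
The caching behaviour wired into the block above can be summarized by this schematic sketch: on steps where broadcasting is enabled, the block reuses the attention output cached from the previous step instead of recomputing it. The `should_reuse` rule and `expensive_attention` below are stand-ins for illustration, not the repository's `if_broadcast_spatial` logic.

```python
import torch

# hypothetical stand-in for if_broadcast_spatial: reuse on every other step
def should_reuse(step_idx: int) -> bool:
    return step_idx % 2 == 1

def expensive_attention(x: torch.Tensor) -> torch.Tensor:
    return x @ x.transpose(-1, -2) @ x  # placeholder for real attention

last_attn = None
x = torch.randn(1, 16, 32)
for step_idx in range(4):
    if should_reuse(step_idx) and last_attn is not None:
        attn_out = last_attn            # reuse the cached result
    else:
        attn_out = expensive_attention(x)
        last_attn = attn_out            # refresh the cache
    x = x + attn_out                    # residual, as in the block above
```
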
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, defaults to `30`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
temporal_compression_ratio (`int`, defaults to `4`):
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
max_text_seq_length (`int`, defaults to `226`):
The maximum sequence length of the input text embeddings.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: int = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
all_timesteps=None,
):
if get_cfg_parallel_size() > 1:
(
hidden_states,
encoder_hidden_states,
timestep,
timestep_cond,
image_rotary_emb,
) = batch_func(
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
hidden_states,
encoder_hidden_states,
timestep,
timestep_cond,
image_rotary_emb,
)
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
text_seq_length = encoder_hidden_states.shape[1]
if not self.config.use_rotary_positional_embeddings:
seq_length = height * width * num_frames // (self.config.patch_size**2)
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
if enable_sequence_parallel():
set_spatial_pad(hidden_states.shape[1])
hidden_states = split_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
timestep=timesteps if enable_pab() else None,
)
if enable_sequence_parallel():
hidden_states = gather_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if get_cfg_parallel_size() > 1:
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
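
A shape-only sketch of the unpatchify step at the end of the forward pass above: projected tokens of size `p * p * C` are folded back into per-frame latent maps. The latent sizes below are illustrative defaults.

```python
import torch

# illustrative sizes: p=2 patches, 60 x 90 latent, 13 latent frames, 16 channels
batch, frames, channels, height, width, p = 1, 13, 16, 60, 90, 2

tokens = torch.randn(batch, frames * (height // p) * (width // p), p * p * channels)

out = tokens.reshape(batch, frames, height // p, width // p, channels, p, p)
out = out.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
print(out.shape)  # torch.Size([1, 13, 16, 60, 90])
```
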

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,636 @@
# Adapted from OpenSora
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# OpenSora: https://github.com/hpcaitech/Open-Sora
# --------------------------------------------------------
from collections.abc import Iterable
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from transformers import PretrainedConfig, PreTrainedModel
from videosys.core.comm import (
all_to_all_with_pad,
gather_sequence,
get_spatial_pad,
get_temporal_pad,
set_spatial_pad,
set_temporal_pad,
split_sequence,
)
from videosys.core.pab_mgr import (
enable_pab,
get_mlp_output,
if_broadcast_cross,
if_broadcast_mlp,
if_broadcast_spatial,
if_broadcast_temporal,
save_mlp_output,
)
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_size,
get_data_parallel_group,
get_sequence_parallel_group,
)
from videosys.models.modules.activations import approx_gelu
from videosys.models.modules.attentions import OpenSoraAttention, OpenSoraMultiHeadCrossAttention
from videosys.models.modules.embeddings import (
OpenSoraCaptionEmbedder,
OpenSoraPatchEmbed3D,
OpenSoraPositionEmbedding2D,
SizeEmbedder,
TimestepEmbedder,
)
from videosys.utils.utils import batch_func
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
class T2IFinalLayer(nn.Module):
"""
The final layer of PixArt.
"""
def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
self.out_channels = out_channels
self.d_t = d_t
self.d_s = d_s
def t_mask_select(self, x_mask, x, masked_x, T, S):
# x: [B, (T, S), C]
# masked_x: [B, (T, S), C]
# x_mask: [B, T]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
x = torch.where(x_mask[:, :, None, None], x, masked_x)
x = rearrange(x, "B T S C -> B (T S) C")
return x
def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
if T is None:
T = self.d_t
if S is None:
S = self.d_s
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
if x_mask is not None:
shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
x = self.t_mask_select(x_mask, x, x_zero, T, S)
x = self.linear(x)
return x
def auto_grad_checkpoint(module, *args, **kwargs):
if getattr(module, "grad_checkpointing", False):
if not isinstance(module, Iterable):
return checkpoint(module, *args, use_reentrant=False, **kwargs)
gc_step = module[0].grad_checkpointing_step
return checkpoint_sequential(module, gc_step, *args, use_reentrant=False, **kwargs)
return module(*args, **kwargs)
class STDiT3Block(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
mlp_ratio=4.0,
drop_path=0.0,
rope=None,
qk_norm=False,
temporal=False,
enable_flash_attn=False,
block_idx=None,
):
super().__init__()
self.temporal = temporal
self.hidden_size = hidden_size
self.enable_flash_attn = enable_flash_attn
self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
self.attn = OpenSoraAttention(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
qk_norm=qk_norm,
rope=rope,
enable_flash_attn=enable_flash_attn,
)
self.cross_attn = OpenSoraMultiHeadCrossAttention(hidden_size, num_heads, enable_flash_attn=enable_flash_attn)
self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
# pab
self.block_idx = block_idx
self.attn_count = 0
self.last_attn = None
self.cross_count = 0
self.last_cross = None
self.mlp_count = 0
def t_mask_select(self, x_mask, x, masked_x, T, S):
# x: [B, (T, S), C]
# masked_x: [B, (T, S), C]
# x_mask: [B, T]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
x = torch.where(x_mask[:, :, None, None], x, masked_x)
x = rearrange(x, "B T S C -> B (T S) C")
return x
def forward(
self,
x,
y,
t,
mask=None, # text mask
x_mask=None, # temporal mask
t0=None, # t with timestamp=0
T=None, # number of frames
S=None, # number of pixel patches
timestep=None,
all_timesteps=None,
):
# prepare modulate parameters
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
if x_mask is not None:
shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = (
self.scale_shift_table[None] + t0.reshape(B, 6, -1)
).chunk(6, dim=1)
if enable_pab():
if self.temporal:
broadcast_attn, self.attn_count = if_broadcast_temporal(int(timestep[0]), self.attn_count)
else:
broadcast_attn, self.attn_count = if_broadcast_spatial(
int(timestep[0]), self.attn_count, self.block_idx
)
if enable_pab() and broadcast_attn:
x_m_s = self.last_attn
else:
# modulate (attention)
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
if x_mask is not None:
x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero)
x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)
# attention
if self.temporal:
if enable_sequence_parallel():
x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=True)
x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
x_m = self.attn(x_m)
x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S)
if enable_sequence_parallel():
x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=False)
else:
x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
x_m = self.attn(x_m)
x_m = rearrange(x_m, "(B T) S C -> B (T S) C", T=T, S=S)
# modulate (attention)
x_m_s = gate_msa * x_m
if x_mask is not None:
x_m_s_zero = gate_msa_zero * x_m
x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)
if enable_pab():
self.last_attn = x_m_s
# residual
x = x + self.drop_path(x_m_s)
# cross attention
if enable_pab():
broadcast_cross, self.cross_count = if_broadcast_cross(int(timestep[0]), self.cross_count)
if enable_pab() and broadcast_cross:
x = x + self.last_cross
else:
x_cross = self.cross_attn(x, y, mask)
if enable_pab():
self.last_cross = x_cross
x = x + x_cross
if enable_pab():
broadcast_mlp, self.mlp_count, broadcast_next, skip_range = if_broadcast_mlp(
int(timestep[0]),
self.mlp_count,
self.block_idx,
all_timesteps,
is_temporal=self.temporal,
)
if enable_pab() and broadcast_mlp:
x_m_s = get_mlp_output(
skip_range,
timestep=int(timestep[0]),
block_idx=self.block_idx,
is_temporal=self.temporal,
)
else:
# modulate (MLP)
x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
if x_mask is not None:
x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero)
x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)
# MLP
x_m = self.mlp(x_m)
# modulate (MLP)
x_m_s = gate_mlp * x_m
if x_mask is not None:
x_m_s_zero = gate_mlp_zero * x_m
x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)
if enable_pab() and broadcast_next:
save_mlp_output(
timestep=int(timestep[0]),
block_idx=self.block_idx,
ff_output=x_m_s,
is_temporal=self.temporal,
)
# residual
x = x + self.drop_path(x_m_s)
return x
def dynamic_switch(self, x, s, t, to_spatial_shard: bool):
if to_spatial_shard:
scatter_dim, gather_dim = 2, 1
scatter_pad = get_spatial_pad()
gather_pad = get_temporal_pad()
else:
scatter_dim, gather_dim = 1, 2
scatter_pad = get_temporal_pad()
gather_pad = get_spatial_pad()
x = rearrange(x, "b (t s) d -> b t s d", t=t, s=s)
x = all_to_all_with_pad(
x,
get_sequence_parallel_group(),
scatter_dim=scatter_dim,
gather_dim=gather_dim,
scatter_pad=scatter_pad,
gather_pad=gather_pad,
)
new_s, new_t = x.shape[2], x.shape[1]
x = rearrange(x, "b t s d -> b (t s) d")
return x, new_s, new_t
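
A small einops sketch of how the block above alternates attention axes: tokens laid out as `B (T S) C` are regrouped so that temporal blocks attend over the `T` frames at each spatial location, while spatial blocks attend over the `S` patches within each frame (shapes are illustrative).

```python
import torch
from einops import rearrange

B, T, S, C = 1, 8, 16, 32
x = torch.randn(B, T * S, C)

# temporal attention: every spatial location attends across its T frames
x_t = rearrange(x, "B (T S) C -> (B S) T C", T=T, S=S)   # (16, 8, 32)

# spatial attention: every frame attends across its S patches
x_s = rearrange(x, "B (T S) C -> (B T) S C", T=T, S=S)   # (8, 16, 32)

print(x_t.shape, x_s.shape)
```
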
class STDiT3Config(PretrainedConfig):
model_type = "STDiT3"
def __init__(
self,
input_size=(None, None, None),
input_sq_size=512,
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path=0.0,
caption_channels=4096,
model_max_length=300,
qk_norm=True,
enable_flash_attn=False,
only_train_temporal=False,
freeze_y_embedder=False,
skip_y_embedder=False,
**kwargs,
):
self.input_size = input_size
self.input_sq_size = input_sq_size
self.in_channels = in_channels
self.patch_size = patch_size
self.hidden_size = hidden_size
self.depth = depth
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.class_dropout_prob = class_dropout_prob
self.pred_sigma = pred_sigma
self.drop_path = drop_path
self.caption_channels = caption_channels
self.model_max_length = model_max_length
self.qk_norm = qk_norm
self.enable_flash_attn = enable_flash_attn
self.only_train_temporal = only_train_temporal
self.freeze_y_embedder = freeze_y_embedder
self.skip_y_embedder = skip_y_embedder
super().__init__(**kwargs)
class STDiT3(PreTrainedModel):
config_class = STDiT3Config
def __init__(self, config):
super().__init__(config)
self.pred_sigma = config.pred_sigma
self.in_channels = config.in_channels
self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels
# model size related
self.depth = config.depth
self.mlp_ratio = config.mlp_ratio
self.hidden_size = config.hidden_size
self.num_heads = config.num_heads
# computation related
self.drop_path = config.drop_path
self.enable_flash_attn = config.enable_flash_attn
# input size related
self.patch_size = config.patch_size
self.input_sq_size = config.input_sq_size
self.pos_embed = OpenSoraPositionEmbedding2D(config.hidden_size)
from rotary_embedding_torch import RotaryEmbedding
self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads)
# embedding
self.x_embedder = OpenSoraPatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
self.t_embedder = TimestepEmbedder(config.hidden_size)
self.fps_embedder = SizeEmbedder(self.hidden_size)
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True),
)
self.y_embedder = OpenSoraCaptionEmbedder(
in_channels=config.caption_channels,
hidden_size=config.hidden_size,
uncond_prob=config.class_dropout_prob,
act_layer=approx_gelu,
token_num=config.model_max_length,
)
# spatial blocks
drop_path = [x.item() for x in torch.linspace(0, self.drop_path, config.depth)]
self.spatial_blocks = nn.ModuleList(
[
STDiT3Block(
hidden_size=config.hidden_size,
num_heads=config.num_heads,
mlp_ratio=config.mlp_ratio,
drop_path=drop_path[i],
qk_norm=config.qk_norm,
enable_flash_attn=config.enable_flash_attn,
block_idx=i,
)
for i in range(config.depth)
]
)
# temporal blocks
drop_path = [x.item() for x in torch.linspace(0, self.drop_path, config.depth)]
self.temporal_blocks = nn.ModuleList(
[
STDiT3Block(
hidden_size=config.hidden_size,
num_heads=config.num_heads,
mlp_ratio=config.mlp_ratio,
drop_path=drop_path[i],
qk_norm=config.qk_norm,
enable_flash_attn=config.enable_flash_attn,
# temporal
temporal=True,
rope=self.rope.rotate_queries_or_keys,
block_idx=i,
)
for i in range(config.depth)
]
)
# final layer
self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels)
self.initialize_weights()
if config.only_train_temporal:
for param in self.parameters():
param.requires_grad = False
for block in self.temporal_blocks:
for param in block.parameters():
param.requires_grad = True
if config.freeze_y_embedder:
for param in self.y_embedder.parameters():
param.requires_grad = False
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize fps_embedder
nn.init.normal_(self.fps_embedder.mlp[0].weight, std=0.02)
nn.init.constant_(self.fps_embedder.mlp[0].bias, 0)
nn.init.constant_(self.fps_embedder.mlp[2].weight, 0)
nn.init.constant_(self.fps_embedder.mlp[2].bias, 0)
# Initialize temporal blocks
for block in self.temporal_blocks:
nn.init.constant_(block.attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.mlp.fc2.weight, 0)
def get_dynamic_size(self, x):
_, _, T, H, W = x.size()
if T % self.patch_size[0] != 0:
T += self.patch_size[0] - T % self.patch_size[0]
if H % self.patch_size[1] != 0:
H += self.patch_size[1] - H % self.patch_size[1]
if W % self.patch_size[2] != 0:
W += self.patch_size[2] - W % self.patch_size[2]
T = T // self.patch_size[0]
H = H // self.patch_size[1]
W = W // self.patch_size[2]
return (T, H, W)
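
A quick arithmetic check of `get_dynamic_size` above under the default `patch_size=(1, 2, 2)`: each dimension is rounded up to a multiple of its patch size before dividing (the latent size below is an illustrative example).

```python
def ceil_div(n: int, p: int) -> int:
    return (n + p - 1) // p

# illustrative latent of 13 frames at 60 x 90, patch_size = (1, 2, 2)
T, H, W = 13, 60, 90
print(ceil_div(T, 1), ceil_div(H, 2), ceil_div(W, 2))  # 13 30 45
```
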
def encode_text(self, y, mask=None):
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, self.hidden_size)
return y, y_lens
def forward(
self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
):
# === Split batch ===
if get_cfg_parallel_size() > 1:
x, timestep, y, x_mask, mask = batch_func(
partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
)
dtype = self.x_embedder.proj.weight.dtype
B = x.size(0)
x = x.to(dtype)
timestep = timestep.to(dtype)
y = y.to(dtype)
# === get pos embed ===
_, _, Tx, Hx, Wx = x.size()
T, H, W = self.get_dynamic_size(x)
S = H * W
base_size = round(S**0.5)
resolution_sq = (height[0].item() * width[0].item()) ** 0.5
scale = resolution_sq / self.input_sq_size
pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)
# === get timestep embed ===
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
fps = self.fps_embedder(fps.unsqueeze(1), B)
t = t + fps
t_mlp = self.t_block(t)
t0 = t0_mlp = None
if x_mask is not None:
t0_timestep = torch.zeros_like(timestep)
t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
t0 = t0 + fps
t0_mlp = self.t_block(t0)
# === get y embed ===
if self.config.skip_y_embedder:
y_lens = mask
if isinstance(y_lens, torch.Tensor):
y_lens = y_lens.long().tolist()
else:
y, y_lens = self.encode_text(y, mask)
# === get x embed ===
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
x = x + pos_emb
# shard over the sequence dim if sp is enabled
if enable_sequence_parallel():
set_temporal_pad(T)
set_spatial_pad(S)
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
T = x.shape[1]
x_mask_org = x_mask
x_mask = split_sequence(
x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
)
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
# === blocks ===
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
x = auto_grad_checkpoint(
spatial_block,
x,
y,
t_mlp,
y_lens,
x_mask,
t0_mlp,
T,
S,
timestep,
all_timesteps=all_timesteps,
)
x = auto_grad_checkpoint(
temporal_block,
x,
y,
t_mlp,
y_lens,
x_mask,
t0_mlp,
T,
S,
timestep,
all_timesteps=all_timesteps,
)
if enable_sequence_parallel():
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
T, S = x.shape[1], x.shape[2]
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
x_mask = x_mask_org
# === final layer ===
x = self.final_layer(x, t, x_mask, t0, T, S)
x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
# cast to float32 for better accuracy
x = x.to(torch.float32)
# === Gather Output ===
if get_cfg_parallel_size() > 1:
x = gather_sequence(x, get_data_parallel_group(), dim=0)
return x
def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
# N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
# unpad
x = x[:, :, :R_t, :R_h, :R_w]
return x

View File

View File

@ -0,0 +1,3 @@
from .pipeline_cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
__all__ = ["CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"]

View File

@ -0,0 +1,813 @@
# Adapted from CogVideo
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# CogVideo: https://github.com/THUDM/CogVideo
# diffusers: https://github.com/huggingface/diffusers
# --------------------------------------------------------
import inspect
import math
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from transformers import T5EncoderModel, T5Tokenizer
from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.autoencoders.autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from videosys.models.modules.embeddings import get_3d_rotary_pos_embed
from videosys.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel
from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
class CogVideoXPABConfig(PABConfig):
def __init__(
self,
steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
)
class CogVideoXConfig:
"""
This config is to instantiate a `CogVideoXPipeline` class for video generation.
To be specific, this config will be passed to engine by `VideoSysEngine(config)`.
In the engine, it will be used to instantiate the corresponding pipeline class.
And the engine will call the `generate` function of the pipeline to generate the video.
If you want to explore the detail of generation, please refer to the pipeline class below.
Args:
model_path (str):
A path to the pretrained pipeline. Defaults to "THUDM/CogVideoX-2b".
num_gpus (int):
The number of GPUs to use. Defaults to 1.
cpu_offload (bool):
Whether to enable CPU offload. Defaults to False.
vae_tiling (bool):
Whether to enable tiling for the VAE. Defaults to True.
enable_pab (bool):
Whether to enable Pyramid Attention Broadcast. Defaults to False.
pab_config (CogVideoXPABConfig):
The configuration for Pyramid Attention Broadcast. Defaults to `CogVideoXPABConfig()`.
Examples:
```python
from videosys import CogVideoXConfig, VideoSysEngine
# models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
# change num_gpus for multi-gpu inference
config = CogVideoXConfig("THUDM/CogVideoX-2b", num_gpus=1)
engine = VideoSysEngine(config)
prompt = "Sunset over the sea."
# num frames should be <= 49. resolution is fixed to 720p.
video = engine.generate(
prompt=prompt,
guidance_scale=6,
num_inference_steps=50,
num_frames=49,
).video[0]
engine.save_video(video, f"./outputs/{prompt}.mp4")
```
"""
def __init__(
self,
model_path: str = "THUDM/CogVideoX-2b",
# ======= distributed ========
num_gpus: int = 1,
# ======= memory =======
cpu_offload: bool = False,
vae_tiling: bool = True,
# ======= pab ========
enable_pab: bool = False,
pab_config=CogVideoXPABConfig(),
):
self.model_path = model_path
self.pipeline_cls = CogVideoXPipeline
# ======= distributed ========
self.num_gpus = num_gpus
# ======= memory ========
self.cpu_offload = cpu_offload
self.vae_tiling = vae_tiling
# ======= pab ========
self.enable_pab = enable_pab
self.pab_config = pab_config
class CogVideoXPipeline(VideoSysPipeline):
_optional_components = ["text_encoder", "tokenizer"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
self,
config: CogVideoXConfig,
tokenizer: Optional[T5Tokenizer] = None,
text_encoder: Optional[T5EncoderModel] = None,
vae: Optional[AutoencoderKLCogVideoX] = None,
transformer: Optional[CogVideoXTransformer3DModel] = None,
scheduler: Optional[CogVideoXDDIMScheduler] = None,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self._config = config
self._device = device
if config.model_path == "THUDM/CogVideoX-2b":
dtype = torch.float16
self._dtype = dtype
if transformer is None:
transformer = CogVideoXTransformer3DModel.from_pretrained(
config.model_path, subfolder="transformer", torch_dtype=self._dtype
)
if vae is None:
vae = AutoencoderKLCogVideoX.from_pretrained(config.model_path, subfolder="vae", torch_dtype=self._dtype)
if tokenizer is None:
tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
if text_encoder is None:
text_encoder = T5EncoderModel.from_pretrained(
config.model_path, subfolder="text_encoder", torch_dtype=self._dtype
)
if scheduler is None:
scheduler = CogVideoXDDIMScheduler.from_pretrained(
config.model_path,
subfolder="scheduler",
)
# set eval and device
self.set_eval_and_device(self._device, vae, transformer)
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
# cpu offload
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(self._device, text_encoder)
# vae tiling
if config.vae_tiling:
vae.enable_tiling()
# pab
if config.enable_pab:
set_pab_manager(config.pab_config)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
shape = (
batch_size,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
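
A worked example of the latent shape computed above, assuming the usual CogVideoX scale factors (temporal 4, spatial 8) and the default 49 frames at 480x720:

```python
# illustrative values matching the defaults used elsewhere in this file
num_frames, height, width = 49, 480, 720
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8
num_channels_latents, batch_size = 16, 1

shape = (
    batch_size,
    (num_frames - 1) // vae_scale_factor_temporal + 1,  # 13 latent frames
    num_channels_latents,
    height // vae_scale_factor_spatial,                 # 60
    width // vae_scale_factor_spatial,                  # 90
)
print(shape)  # (1, 13, 16, 60, 90)
```
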
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
frames = self.vae.decode(latents).sample
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
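
For reference, the grid sizes computed above with the default 480x720 resolution, a spatial VAE scale factor of 8, and patch size 2 (values are illustrative defaults, not read from a config):

```python
vae_scale_factor_spatial, patch_size = 8, 2
height, width = 480, 720

grid_height = height // (vae_scale_factor_spatial * patch_size)    # 30
grid_width = width // (vae_scale_factor_spatial * patch_size)      # 45
base_size_height = 480 // (vae_scale_factor_spatial * patch_size)  # 30
base_size_width = 720 // (vae_scale_factor_spatial * patch_size)   # 45
print(grid_height, grid_width, base_size_height, base_size_width)
```
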
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def generate(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
verbose=True,
) -> Union[VideoSysPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to `480`):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to `720`):
The width in pixels of the generated video.
num_frames (`int`, defaults to `49`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
needs to be satisfied is that of divisibility mentioned above.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 6):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`VideoSysPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `226`):
Maximum sequence length in encoded prompt. Must be consistent with
`self.transformer.config.max_text_seq_length`, otherwise generation quality may degrade.
Examples:
Returns:
[`VideoSysPipelineOutput`] or `tuple`:
[`VideoSysPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple,
the first element is a list with the generated videos.
"""
if num_frames > 49:
raise ValueError(
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
update_steps(num_inference_steps)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
# with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in progress_wrap(list(enumerate(timesteps))):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
all_timesteps=timesteps,
)[0]
noise_pred = noise_pred.float()
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
# progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return VideoSysPipelineOutput(video=video)
def save_video(self, video, output_path):
save_video(video, output_path, fps=8)
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
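# Illustrative sketch (not used by the pipeline): `get_resize_crop_region_for_grid`
# letterboxes a source grid of shape (height, width) into a target box while
# preserving its aspect ratio, and returns the top-left and bottom-right corners
# of the resized region inside that box. The numbers below are made up purely
# for demonstration.
def _example_resize_crop_region():
    # a 30x60 grid placed into a 40x40 box keeps the full width and is
    # vertically centered, giving ((10, 0), (30, 40))
    return get_resize_crop_region_for_grid((30, 60), 40, 40)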
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
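# Illustrative sketch, assuming a plain diffusers `DDIMScheduler` (any scheduler
# with a compatible `set_timesteps` behaves the same way); not used by this file.
def _example_retrieve_timesteps():
    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler()
    # default path: let the scheduler build its own schedule
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=10, device="cpu")
    # `timesteps` is a descending 1-D tensor with `num_inference_steps` entries;
    # passing `timesteps=[...]` or `sigmas=[...]` instead overrides the spacing,
    # provided the scheduler's `set_timesteps` accepts that argument.
    return timesteps, num_inference_steps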

View File

@ -0,0 +1,3 @@
from .pipeline_latte import LatteConfig, LattePABConfig, LattePipeline
__all__ = ["LatteConfig", "LattePipeline", "LattePABConfig"]

View File

@ -0,0 +1,929 @@
# Adapted from Latte
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Latte: https://github.com/Vchitect/Latte
# --------------------------------------------------------
import html
import inspect
import re
import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union
import einops
import ftfy
import torch
import torch.distributed as dist
import tqdm
from bs4 import BeautifulSoup
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from diffusers.schedulers import DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor
from transformers import T5EncoderModel, T5Tokenizer
from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.transformers.latte_transformer_3d import LatteT2V
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
class LattePABConfig(PABConfig):
def __init__(
self,
steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 800],
spatial_range: int = 2,
temporal_broadcast: bool = True,
temporal_threshold: list = [100, 800],
temporal_range: int = 3,
cross_broadcast: bool = True,
cross_threshold: list = [100, 800],
cross_range: int = 6,
mlp_broadcast: bool = True,
mlp_spatial_broadcast_config: dict = {
720: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
640: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
560: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
480: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
400: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
},
mlp_temporal_broadcast_config: dict = {
720: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
640: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
560: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
480: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
400: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
},
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=temporal_threshold,
temporal_range=temporal_range,
cross_broadcast=cross_broadcast,
cross_threshold=cross_threshold,
cross_range=cross_range,
mlp_broadcast=mlp_broadcast,
mlp_spatial_broadcast_config=mlp_spatial_broadcast_config,
mlp_temporal_broadcast_config=mlp_temporal_broadcast_config,
)
class LatteConfig:
"""
This config is to instantiate a `LattePipeline` class for video generation.
To be specific, this config will be passed to engine by `VideoSysEngine(config)`.
In the engine, it will be used to instantiate the corresponding pipeline class.
And the engine will call the `generate` function of the pipeline to generate the video.
If you want to explore the detail of generation, please refer to the pipeline class below.
Args:
model_path (str):
A path to the pretrained pipeline. Defaults to "maxin-cn/Latte-1".
num_gpus (int):
The number of GPUs to use. Defaults to 1.
enable_vae_temporal_decoder (bool):
Whether to enable VAE Temporal Decoder. Defaults to True.
beta_start (float):
The initial value of beta for DDIM. Defaults to 0.0001.
beta_end (float):
The final value of beta for DDIM. Defaults to 0.02.
beta_schedule (str):
The schedule of beta for DDIM. Defaults to "linear".
variance_type (str):
The type of variance for DDIM. Defaults to "learned_range".
enable_pab (bool):
Whether to enable Pyramid Attention Broadcast. Defaults to False.
pab_config (LattePABConfig):
The configuration for Pyramid Attention Broadcast. Defaults to `LattePABConfig()`.
Examples:
```python
from videosys import LatteConfig, VideoSysEngine
# change num_gpus for multi-gpu inference
config = LatteConfig("maxin-cn/Latte-1", num_gpus=1)
engine = VideoSysEngine(config)
prompt = "Sunset over the sea."
# video size is fixed to 16 frames, 512x512.
video = engine.generate(
prompt=prompt,
guidance_scale=7.5,
num_inference_steps=50,
).video[0]
engine.save_video(video, f"./outputs/{prompt}.mp4")
```
"""
def __init__(
self,
model_path: str = "maxin-cn/Latte-1",
# ======= distributed =======
num_gpus: int = 1,
# ======= vae ========
enable_vae_temporal_decoder: bool = True,
# ======= scheduler ========
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
variance_type: str = "learned_range",
# ======= memory =======
cpu_offload: bool = False,
# ======= pab ========
enable_pab: bool = False,
pab_config: PABConfig = LattePABConfig(),
):
self.model_path = model_path
self.pipeline_cls = LattePipeline
# ======= distributed =======
self.num_gpus = num_gpus
# ======= vae ========
self.enable_vae_temporal_decoder = enable_vae_temporal_decoder
# ======= memory ========
self.cpu_offload = cpu_offload
# ======= scheduler ========
self.beta_start = beta_start
self.beta_end = beta_end
self.beta_schedule = beta_schedule
self.variance_type = variance_type
# ======= pab ========
self.enable_pab = enable_pab
self.pab_config = pab_config
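# Illustrative sketch: enabling Pyramid Attention Broadcast by pairing
# `LatteConfig` with a customized `LattePABConfig`. The range values below are
# arbitrary placeholders for demonstration, not recommended settings.
def _example_latte_config_with_pab():
    pab_config = LattePABConfig(spatial_range=3, temporal_range=4, cross_range=6)
    return LatteConfig("maxin-cn/Latte-1", enable_pab=True, pab_config=pab_config)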
class LattePipeline(VideoSysPipeline):
r"""
Pipeline for text-to-video generation using Latte.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. PixArt-Alpha uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`LatteT2V`]):
A text-conditioned transformer to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
bad_punct_regex = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
self,
config: LatteConfig,
tokenizer: Optional[T5Tokenizer] = None,
text_encoder: Optional[T5EncoderModel] = None,
vae: Optional[AutoencoderKL] = None,
transformer: Optional[LatteT2V] = None,
scheduler: Optional[DDIMScheduler] = None,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.float16,
):
super().__init__()
self._config = config
# initialize the model if not provided
if transformer is None:
transformer = LatteT2V.from_pretrained(config.model_path, subfolder="transformer", video_length=16).to(
dtype=dtype
)
if vae is None:
if config.enable_vae_temporal_decoder:
vae = AutoencoderKLTemporalDecoder.from_pretrained(
config.model_path, subfolder="vae_temporal_decoder", torch_dtype=dtype
)
else:
vae = AutoencoderKL.from_pretrained(config.model_path, subfolder="vae", torch_dtype=dtype)
if tokenizer is None:
tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
if text_encoder is None:
text_encoder = T5EncoderModel.from_pretrained(
config.model_path, subfolder="text_encoder", torch_dtype=dtype
)
if scheduler is None:
scheduler = DDIMScheduler.from_pretrained(
config.model_path,
subfolder="scheduler",
beta_start=config.beta_start,
beta_end=config.beta_end,
beta_schedule=config.beta_schedule,
variance_type=config.variance_type,
clip_sample=False,
)
# pab
if config.enable_pab:
set_pab_manager(config.pab_config)
# set eval and device
self.set_eval_and_device(device, vae, transformer)
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
# cpu offload
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(device, text_encoder)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096
else:
masked_feature = emb * mask[:, None, :, None] # 1 120 4096
return masked_feature, emb.shape[2]
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
mask_feature: bool = True,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
PixArt-Alpha, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha, this should be the embeddings of the ""
string.
clean_caption (bool, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
"""
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# See Section 3.1. of the paper.
max_length = 120
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {max_length} tokens: {removed_text}"
)
attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds_attention_mask = attention_mask
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
else:
negative_prompt_embeds = None
# Perform additional masking.
if mask_feature and not embeds_initially_provided:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
masked_negative_prompt_embeds = (
negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
)
# import torch.nn.functional as F
# padding = (0, 0, 0, 113)  # (left, right, bottom, top)
# masked_prompt_embeds_ = F.pad(masked_prompt_embeds, padding, "constant", 0)
# masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0)
# print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...])
return masked_prompt_embeds, masked_negative_prompt_embeds
# return masked_prompt_embeds_, masked_negative_prompt_embeds_
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# normalize quotation marks to a single standard
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# IP addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None
):
shape = (
batch_size,
num_channels_latents,
video_length,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def generate(
self,
prompt: str = None,
negative_prompt: str = "",
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
mask_feature: bool = True,
enable_temporal_attentions: bool = True,
verbose: bool = True,
) -> Union[VideoSysPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Latte can only generate videos of 16 frames at 512x512.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`VideoSysPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
mask_feature (`bool`, defaults to `True`): If set to `True`, the text embeddings will be masked.
enable_temporal_attentions (`bool`, defaults to `True`):
If `True`, the model will use temporal attentions to generate the video.
verbose (`bool`, *optional*, defaults to `True`):
Whether to print progress bars and other information during inference.
Returns:
[`VideoSysPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`VideoSysPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated videos.
"""
# 1. Check inputs. Raise error if not correct
video_length = 16
height = 512
width = 512
update_steps(num_inference_steps)
self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clean_caption=clean_caption,
mask_feature=mask_feature,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
video_length,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if self.transformer.config.sample_size == 128:
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
for i, t in progress_wrap(list(enumerate(timesteps))):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
all_timesteps=timesteps,
encoder_hidden_states=prompt_embeds,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
enable_temporal_attentions=enable_temporal_attentions,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latents":
if latents.shape[2] == 1: # image
video = self.decode_latents_image(latents)
else: # video
if self._config.enable_vae_temporal_decoder:
video = self.decode_latents_with_temporal_decoder(latents)
else:
video = self.decode_latents(latents)
else:
video = latents
return VideoSysPipelineOutput(video=video)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return VideoSysPipelineOutput(video=video)
def decode_latents_image(self, latents):
video_length = latents.shape[2]
latents = 1 / self.vae.config.scaling_factor * latents
latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
video = []
for frame_idx in range(latents.shape[0]):
video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
video = torch.cat(video)
video = einops.rearrange(video, "(b f) c h w -> b f c h w", f=video_length)
video = (video / 2.0 + 0.5).clamp(0, 1)
return video
def decode_latents(self, latents):
video_length = latents.shape[2]
latents = 1 / self.vae.config.scaling_factor * latents
latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
video = []
for frame_idx in range(latents.shape[0]):
video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
video = torch.cat(video)
video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
# frames are converted to uint8 on CPU so they can be written out directly; this is also compatible with bfloat16 latents
return video
def decode_latents_with_temporal_decoder(self, latents):
video_length = latents.shape[2]
latents = 1 / self.vae.config.scaling_factor * latents
latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
video = []
decode_chunk_size = 14
for frame_idx in range(0, latents.shape[0], decode_chunk_size):
num_frames_in = latents[frame_idx : frame_idx + decode_chunk_size].shape[0]
decode_kwargs = {}
decode_kwargs["num_frames"] = num_frames_in
video.append(self.vae.decode(latents[frame_idx : frame_idx + decode_chunk_size], **decode_kwargs).sample)
video = torch.cat(video)
video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
# frames are converted to uint8 on CPU so they can be written out directly; this is also compatible with bfloat16 latents
return video
def save_video(self, video, output_path):
save_video(video, output_path, fps=8)
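# Illustrative sketch of driving the pipeline directly on a single GPU (the
# `VideoSysEngine` route shown in the `LatteConfig` docstring is the intended
# entry point). Requires a CUDA device and downloads the pretrained weights on
# first use; the prompt and output path are placeholders.
def _example_latte_pipeline_direct():
    config = LatteConfig("maxin-cn/Latte-1")
    pipeline = LattePipeline(config)
    video = pipeline.generate(
        prompt="Sunset over the sea.",
        guidance_scale=7.5,
        num_inference_steps=50,
        verbose=False,  # the progress bar queries torch.distributed, which is not initialized here
    ).video[0]
    pipeline.save_video(video, "./outputs/sunset_over_the_sea.mp4")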

View File

@ -0,0 +1,3 @@
from .pipeline_open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
__all__ = ["OpenSoraConfig", "OpenSoraPipeline", "OpenSoraPABConfig"]

View File

@ -0,0 +1,807 @@
# Adapted from OpenSora
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# OpenSora: https://github.com/hpcaitech/Open-Sora
# --------------------------------------------------------
import numbers
import os
import re
import numpy as np
import requests
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torchvision.io import write_video
from torchvision.utils import save_image
IMG_FPS = 120
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
regex = re.compile(
r"^(?:http|ftp)s?://" # http:// or https://
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
r"(?::\d+)?" # optional port
r"(?:/?|[/?]\S+)$",
re.IGNORECASE,
)
# H:W
ASPECT_RATIO_MAP = {
"3:8": "0.38",
"9:21": "0.43",
"12:25": "0.48",
"1:2": "0.50",
"9:17": "0.53",
"27:50": "0.54",
"9:16": "0.56",
"5:8": "0.62",
"2:3": "0.67",
"3:4": "0.75",
"1:1": "1.00",
"4:3": "1.33",
"3:2": "1.50",
"16:9": "1.78",
"17:9": "1.89",
"2:1": "2.00",
"50:27": "2.08",
}
# computed from above code
# S = 8294400
ASPECT_RATIO_4K = {
"0.38": (1764, 4704),
"0.43": (1886, 4400),
"0.48": (1996, 4158),
"0.50": (2036, 4072),
"0.53": (2096, 3960),
"0.54": (2118, 3918),
"0.62": (2276, 3642),
"0.56": (2160, 3840), # base
"0.67": (2352, 3528),
"0.75": (2494, 3326),
"1.00": (2880, 2880),
"1.33": (3326, 2494),
"1.50": (3528, 2352),
"1.78": (3840, 2160),
"1.89": (3958, 2096),
"2.00": (4072, 2036),
"2.08": (4156, 1994),
}
# S = 3686400
ASPECT_RATIO_2K = {
"0.38": (1176, 3136),
"0.43": (1256, 2930),
"0.48": (1330, 2770),
"0.50": (1358, 2716),
"0.53": (1398, 2640),
"0.54": (1412, 2612),
"0.56": (1440, 2560), # base
"0.62": (1518, 2428),
"0.67": (1568, 2352),
"0.75": (1662, 2216),
"1.00": (1920, 1920),
"1.33": (2218, 1664),
"1.50": (2352, 1568),
"1.78": (2560, 1440),
"1.89": (2638, 1396),
"2.00": (2716, 1358),
"2.08": (2772, 1330),
}
# S = 2073600
ASPECT_RATIO_1080P = {
"0.38": (882, 2352),
"0.43": (942, 2198),
"0.48": (998, 2080),
"0.50": (1018, 2036),
"0.53": (1048, 1980),
"0.54": (1058, 1958),
"0.56": (1080, 1920), # base
"0.62": (1138, 1820),
"0.67": (1176, 1764),
"0.75": (1248, 1664),
"1.00": (1440, 1440),
"1.33": (1662, 1246),
"1.50": (1764, 1176),
"1.78": (1920, 1080),
"1.89": (1980, 1048),
"2.00": (2036, 1018),
"2.08": (2078, 998),
}
# S = 921600
ASPECT_RATIO_720P = {
"0.38": (588, 1568),
"0.43": (628, 1466),
"0.48": (666, 1388),
"0.50": (678, 1356),
"0.53": (698, 1318),
"0.54": (706, 1306),
"0.56": (720, 1280), # base
"0.62": (758, 1212),
"0.67": (784, 1176),
"0.75": (832, 1110),
"1.00": (960, 960),
"1.33": (1108, 832),
"1.50": (1176, 784),
"1.78": (1280, 720),
"1.89": (1320, 698),
"2.00": (1358, 680),
"2.08": (1386, 666),
}
# S = 409920
ASPECT_RATIO_480P = {
"0.38": (392, 1046),
"0.43": (420, 980),
"0.48": (444, 925),
"0.50": (452, 904),
"0.53": (466, 880),
"0.54": (470, 870),
"0.56": (480, 854), # base
"0.62": (506, 810),
"0.67": (522, 784),
"0.75": (554, 738),
"1.00": (640, 640),
"1.33": (740, 555),
"1.50": (784, 522),
"1.78": (854, 480),
"1.89": (880, 466),
"2.00": (906, 454),
"2.08": (924, 444),
}
# S = 230400
ASPECT_RATIO_360P = {
"0.38": (294, 784),
"0.43": (314, 732),
"0.48": (332, 692),
"0.50": (340, 680),
"0.53": (350, 662),
"0.54": (352, 652),
"0.56": (360, 640), # base
"0.62": (380, 608),
"0.67": (392, 588),
"0.75": (416, 554),
"1.00": (480, 480),
"1.33": (554, 416),
"1.50": (588, 392),
"1.78": (640, 360),
"1.89": (660, 350),
"2.00": (678, 340),
"2.08": (692, 332),
}
# S = 102240
ASPECT_RATIO_240P = {
"0.38": (196, 522),
"0.43": (210, 490),
"0.48": (222, 462),
"0.50": (226, 452),
"0.53": (232, 438),
"0.54": (236, 436),
"0.56": (240, 426), # base
"0.62": (252, 404),
"0.67": (262, 393),
"0.75": (276, 368),
"1.00": (320, 320),
"1.33": (370, 278),
"1.50": (392, 262),
"1.78": (426, 240),
"1.89": (440, 232),
"2.00": (452, 226),
"2.08": (462, 222),
}
# S = 36864
ASPECT_RATIO_144P = {
"0.38": (117, 312),
"0.43": (125, 291),
"0.48": (133, 277),
"0.50": (135, 270),
"0.53": (139, 262),
"0.54": (141, 260),
"0.56": (144, 256), # base
"0.62": (151, 241),
"0.67": (156, 234),
"0.75": (166, 221),
"1.00": (192, 192),
"1.33": (221, 165),
"1.50": (235, 156),
"1.78": (256, 144),
"1.89": (263, 139),
"2.00": (271, 135),
"2.08": (277, 132),
}
# from PixArt
# S = 8294400
ASPECT_RATIO_2880 = {
"0.25": (1408, 5760),
"0.26": (1408, 5568),
"0.27": (1408, 5376),
"0.28": (1408, 5184),
"0.32": (1600, 4992),
"0.33": (1600, 4800),
"0.34": (1600, 4672),
"0.4": (1792, 4480),
"0.42": (1792, 4288),
"0.47": (1920, 4096),
"0.49": (1920, 3904),
"0.51": (1920, 3776),
"0.55": (2112, 3840),
"0.59": (2112, 3584),
"0.68": (2304, 3392),
"0.72": (2304, 3200),
"0.78": (2496, 3200),
"0.83": (2496, 3008),
"0.89": (2688, 3008),
"0.93": (2688, 2880),
"1.0": (2880, 2880),
"1.07": (2880, 2688),
"1.12": (3008, 2688),
"1.21": (3008, 2496),
"1.28": (3200, 2496),
"1.39": (3200, 2304),
"1.47": (3392, 2304),
"1.7": (3584, 2112),
"1.82": (3840, 2112),
"2.03": (3904, 1920),
"2.13": (4096, 1920),
"2.39": (4288, 1792),
"2.5": (4480, 1792),
"2.92": (4672, 1600),
"3.0": (4800, 1600),
"3.12": (4992, 1600),
"3.68": (5184, 1408),
"3.82": (5376, 1408),
"3.95": (5568, 1408),
"4.0": (5760, 1408),
}
# S = 4194304
ASPECT_RATIO_2048 = {
"0.25": (1024, 4096),
"0.26": (1024, 3968),
"0.27": (1024, 3840),
"0.28": (1024, 3712),
"0.32": (1152, 3584),
"0.33": (1152, 3456),
"0.35": (1152, 3328),
"0.4": (1280, 3200),
"0.42": (1280, 3072),
"0.48": (1408, 2944),
"0.5": (1408, 2816),
"0.52": (1408, 2688),
"0.57": (1536, 2688),
"0.6": (1536, 2560),
"0.68": (1664, 2432),
"0.72": (1664, 2304),
"0.78": (1792, 2304),
"0.82": (1792, 2176),
"0.88": (1920, 2176),
"0.94": (1920, 2048),
"1.0": (2048, 2048),
"1.07": (2048, 1920),
"1.13": (2176, 1920),
"1.21": (2176, 1792),
"1.29": (2304, 1792),
"1.38": (2304, 1664),
"1.46": (2432, 1664),
"1.67": (2560, 1536),
"1.75": (2688, 1536),
"2.0": (2816, 1408),
"2.09": (2944, 1408),
"2.4": (3072, 1280),
"2.5": (3200, 1280),
"2.89": (3328, 1152),
"3.0": (3456, 1152),
"3.11": (3584, 1152),
"3.62": (3712, 1024),
"3.75": (3840, 1024),
"3.88": (3968, 1024),
"4.0": (4096, 1024),
}
# S = 1048576
ASPECT_RATIO_1024 = {
"0.25": (512, 2048),
"0.26": (512, 1984),
"0.27": (512, 1920),
"0.28": (512, 1856),
"0.32": (576, 1792),
"0.33": (576, 1728),
"0.35": (576, 1664),
"0.4": (640, 1600),
"0.42": (640, 1536),
"0.48": (704, 1472),
"0.5": (704, 1408),
"0.52": (704, 1344),
"0.57": (768, 1344),
"0.6": (768, 1280),
"0.68": (832, 1216),
"0.72": (832, 1152),
"0.78": (896, 1152),
"0.82": (896, 1088),
"0.88": (960, 1088),
"0.94": (960, 1024),
"1.0": (1024, 1024),
"1.07": (1024, 960),
"1.13": (1088, 960),
"1.21": (1088, 896),
"1.29": (1152, 896),
"1.38": (1152, 832),
"1.46": (1216, 832),
"1.67": (1280, 768),
"1.75": (1344, 768),
"2.0": (1408, 704),
"2.09": (1472, 704),
"2.4": (1536, 640),
"2.5": (1600, 640),
"2.89": (1664, 576),
"3.0": (1728, 576),
"3.11": (1792, 576),
"3.62": (1856, 512),
"3.75": (1920, 512),
"3.88": (1984, 512),
"4.0": (2048, 512),
}
# S = 262144
ASPECT_RATIO_512 = {
"0.25": (256, 1024),
"0.26": (256, 992),
"0.27": (256, 960),
"0.28": (256, 928),
"0.32": (288, 896),
"0.33": (288, 864),
"0.35": (288, 832),
"0.4": (320, 800),
"0.42": (320, 768),
"0.48": (352, 736),
"0.5": (352, 704),
"0.52": (352, 672),
"0.57": (384, 672),
"0.6": (384, 640),
"0.68": (416, 608),
"0.72": (416, 576),
"0.78": (448, 576),
"0.82": (448, 544),
"0.88": (480, 544),
"0.94": (480, 512),
"1.0": (512, 512),
"1.07": (512, 480),
"1.13": (544, 480),
"1.21": (544, 448),
"1.29": (576, 448),
"1.38": (576, 416),
"1.46": (608, 416),
"1.67": (640, 384),
"1.75": (672, 384),
"2.0": (704, 352),
"2.09": (736, 352),
"2.4": (768, 320),
"2.5": (800, 320),
"2.89": (832, 288),
"3.0": (864, 288),
"3.11": (896, 288),
"3.62": (928, 256),
"3.75": (960, 256),
"3.88": (992, 256),
"4.0": (1024, 256),
}
# S = 65536
ASPECT_RATIO_256 = {
"0.25": (128, 512),
"0.26": (128, 496),
"0.27": (128, 480),
"0.28": (128, 464),
"0.32": (144, 448),
"0.33": (144, 432),
"0.35": (144, 416),
"0.4": (160, 400),
"0.42": (160, 384),
"0.48": (176, 368),
"0.5": (176, 352),
"0.52": (176, 336),
"0.57": (192, 336),
"0.6": (192, 320),
"0.68": (208, 304),
"0.72": (208, 288),
"0.78": (224, 288),
"0.82": (224, 272),
"0.88": (240, 272),
"0.94": (240, 256),
"1.0": (256, 256),
"1.07": (256, 240),
"1.13": (272, 240),
"1.21": (272, 224),
"1.29": (288, 224),
"1.38": (288, 208),
"1.46": (304, 208),
"1.67": (320, 192),
"1.75": (336, 192),
"2.0": (352, 176),
"2.09": (368, 176),
"2.4": (384, 160),
"2.5": (400, 160),
"2.89": (416, 144),
"3.0": (432, 144),
"3.11": (448, 144),
"3.62": (464, 128),
"3.75": (480, 128),
"3.88": (496, 128),
"4.0": (512, 128),
}
def get_closest_ratio(height: float, width: float, ratios: dict):
aspect_ratio = height / width
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
return closest_ratio
ASPECT_RATIOS = {
"144p": (36864, ASPECT_RATIO_144P),
"256": (65536, ASPECT_RATIO_256),
"240p": (102240, ASPECT_RATIO_240P),
"360p": (230400, ASPECT_RATIO_360P),
"512": (262144, ASPECT_RATIO_512),
"480p": (409920, ASPECT_RATIO_480P),
"720p": (921600, ASPECT_RATIO_720P),
"1024": (1048576, ASPECT_RATIO_1024),
"1080p": (2073600, ASPECT_RATIO_1080P),
"2k": (3686400, ASPECT_RATIO_2K),
"2048": (4194304, ASPECT_RATIO_2048),
"2880": (8294400, ASPECT_RATIO_2880),
"4k": (8294400, ASPECT_RATIO_4K),
}
def get_image_size(resolution, ar_ratio):
ar_key = ASPECT_RATIO_MAP[ar_ratio]
rs_dict = ASPECT_RATIOS[resolution][1]
assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}"
return rs_dict[ar_key]
NUM_FRAMES_MAP = {
"1x": 51,
"2x": 102,
"4x": 204,
"8x": 408,
"16x": 816,
"2s": 51,
"4s": 102,
"8s": 204,
"16s": 408,
"32s": 816,
}
def get_num_frames(num_frames):
if num_frames in NUM_FRAMES_MAP:
return NUM_FRAMES_MAP[num_frames]
else:
return int(num_frames)
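# Illustrative sketch of the resolution/frame helpers above; the expected values
# can be read directly from the tables in this file.
def _example_resolution_helpers():
    assert get_closest_ratio(720, 1280, ASPECT_RATIO_720P) == "0.56"  # 720/1280 = 0.5625
    assert get_image_size("720p", "9:16") == (720, 1280)  # H:W key "9:16" maps to ratio "0.56"
    assert get_num_frames("4s") == 102  # named duration from NUM_FRAMES_MAP
    assert get_num_frames("65") == 65  # plain numbers pass through as ints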
def save_sample(x, save_path=None, fps=8, normalize=True, value_range=(-1, 1), force_video=False, verbose=True):
"""
Args:
x (Tensor): shape [C, T, H, W]
"""
assert x.ndim == 4
if not force_video and x.shape[1] == 1: # T = 1: save as image
save_path += ".png"
x = x.squeeze(1)
save_image([x], save_path, normalize=normalize, value_range=value_range)
else:
save_path += ".mp4"
if normalize:
low, high = value_range
x.clamp_(min=low, max=high)
x.sub_(low).div_(max(high - low, 1e-5))
x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
write_video(save_path, x, fps=fps, video_codec="h264")
if verbose:
print(f"Saved to {save_path}")
return save_path
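# Illustrative sketch: save a random clip with `save_sample`. Writing .mp4 goes
# through torchvision's `write_video`, so a working video backend (PyAV/ffmpeg)
# is assumed; the tensor below is random noise, purely for demonstration.
def _example_save_sample():
    os.makedirs("outputs", exist_ok=True)
    x = torch.rand(3, 16, 64, 64) * 2 - 1  # [C, T, H, W] in (-1, 1)
    return save_sample(x, save_path="outputs/noise", fps=8)  # -> "outputs/noise.mp4"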
def is_url(url):
return re.match(regex, url) is not None
def download_url(input_path):
output_dir = "cache"
os.makedirs(output_dir, exist_ok=True)
base_name = os.path.basename(input_path)
output_path = os.path.join(output_dir, base_name)
img_data = requests.get(input_path).content
with open(output_path, "wb") as handler:
handler.write(img_data)
print(f"URL {input_path} downloaded to {output_path}")
return output_path
def get_transforms_video(name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "image_size must be square for center crop"
transform_video = transforms.Compose(
[
ToTensorVideo(), # TCHW
# video_transforms.RandomHorizontalFlipVideo(),
UCFCenterCropVideo(image_size[0]),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "resize_crop":
transform_video = transforms.Compose(
[
ToTensorVideo(), # TCHW
ResizeCrop(image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform_video
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
if h < th or w < tw:
raise ValueError("height and width must be no smaller than crop_size")
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
class UCFCenterCropVideo:
"""
    First scale the clip so that its short edge matches the specified size (keeping the aspect
    ratio), then apply a center crop.
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
class ToTensorVideo:
"""
    Convert tensor data type from uint8 to float and divide values by 255.0.
    The (T, C, H, W) layout is left unchanged.
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class ResizeCrop:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
clip = resize_crop_to_fill(clip, self.size)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
def get_transforms_image(name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "Image size must be square for center crop"
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "resize_crop":
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
def resize_crop_to_fill(pil_image, image_size):
w, h = pil_image.size # PIL is (W, H)
th, tw = image_size
rh, rw = th / h, tw / w
if rh > rw:
sh, sw = th, round(w * rh)
image = pil_image.resize((sw, sh), Image.BICUBIC)
i = 0
j = int(round((sw - tw) / 2.0))
else:
sh, sw = round(h * rw), tw
image = pil_image.resize((sw, sh), Image.BICUBIC)
i = int(round((sh - th) / 2.0))
j = 0
arr = np.array(image)
assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
return Image.fromarray(arr[i : i + th, j : j + tw])
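# Worked example (hypothetical sizes): filling a 1600x1200 PIL image into a (720, 1280)
# target. rh = 720/1200 = 0.6 and rw = 1280/1600 = 0.8, so the width ratio wins: the image
# is resized to 1280x960, then 120 rows are trimmed from top and bottom, yielding 1280x720.
#
#   img = Image.new("RGB", (1600, 1200))
#   out = resize_crop_to_fill(img, (720, 1280))
#   out.size  # (1280, 720) in PIL's (W, H) convention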
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
if transform is None:
transform = get_transforms_video(image_size=image_size, name=transform_name)
video = transform(vframes) # T C H W
video = video.permute(1, 0, 2, 3)
return video
def read_from_path(path, image_size, transform_name="center"):
if is_url(path):
path = download_url(path)
ext = os.path.splitext(path)[-1].lower()
if ext.lower() in VID_EXTENSIONS:
return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
else:
assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
image = pil_loader(path)
if transform is None:
transform = get_transforms_image(image_size=image_size, name=transform_name)
image = transform(image)
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
video = video.permute(1, 0, 2, 3)
return video
def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
if info_type is None:
return dict()
elif info_type == "PixArtMS":
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
return dict(ar=ar, hw=hw)
elif info_type in ["STDiT2", "OpenSora"]:
fps = fps if num_frames > 1 else IMG_FPS
fps = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size)
height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(batch_size)
width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(batch_size)
num_frames = torch.tensor([num_frames], device=device, dtype=dtype).repeat(batch_size)
ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(batch_size)
return dict(height=height, width=width, num_frames=num_frames, ar=ar, fps=fps)
else:
raise NotImplementedError

View File

@ -0,0 +1,958 @@
import html
import json
import os
import re
from typing import Optional, Tuple, Union
import ftfy
import torch
from diffusers.models import AutoencoderKL
from transformers import AutoTokenizer, T5EncoderModel
from videosys.core.pab_mgr import PABConfig, set_pab_manager
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.models.autoencoders.autoencoder_kl_open_sora import OpenSoraVAE_V1_2
from videosys.models.transformers.open_sora_transformer_3d import STDiT3
from videosys.schedulers.scheduling_rflow_open_sora import RFLOW
from videosys.utils.utils import save_video
from .data_process import get_image_size, get_num_frames, prepare_multi_resolution_info, read_from_path
os.environ["TOKENIZERS_PARALLELISM"] = "true"
BAD_PUNCT_REGEX = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
class OpenSoraPABConfig(PABConfig):
def __init__(
self,
steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [450, 930],
spatial_range: int = 2,
temporal_broadcast: bool = True,
temporal_threshold: list = [450, 930],
temporal_range: int = 4,
cross_broadcast: bool = True,
cross_threshold: list = [450, 930],
cross_range: int = 6,
mlp_broadcast: bool = True,
mlp_spatial_broadcast_config: dict = {
676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
},
mlp_temporal_broadcast_config: dict = {
676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
},
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=temporal_threshold,
temporal_range=temporal_range,
cross_broadcast=cross_broadcast,
cross_threshold=cross_threshold,
cross_range=cross_range,
mlp_broadcast=mlp_broadcast,
mlp_spatial_broadcast_config=mlp_spatial_broadcast_config,
mlp_temporal_broadcast_config=mlp_temporal_broadcast_config,
)
class OpenSoraConfig:
"""
    This config instantiates an `OpenSoraPipeline` for video generation.
    It is passed to the engine via `VideoSysEngine(config)`, which uses it to build the
    corresponding pipeline class and then calls the pipeline's `generate` function to
    produce the video. For the details of generation, please refer to the pipeline class below.
Args:
transformer (str):
The transformer model to use. Defaults to "hpcai-tech/OpenSora-STDiT-v3".
vae (str):
The VAE model to use. Defaults to "hpcai-tech/OpenSora-VAE-v1.2".
text_encoder (str):
The text encoder model to use. Defaults to "DeepFloyd/t5-v1_1-xxl".
num_gpus (int):
The number of GPUs to use. Defaults to 1.
num_sampling_steps (int):
The number of sampling steps. Defaults to 30.
cfg_scale (float):
The configuration scale. Defaults to 7.0.
tiling_size (int):
The tiling size. Defaults to 4.
enable_flash_attn (bool):
Whether to enable Flash Attention. Defaults to False.
enable_pab (bool):
Whether to enable Pyramid Attention Broadcast. Defaults to False.
        pab_config (OpenSoraPABConfig):
            The configuration for Pyramid Attention Broadcast. Defaults to `OpenSoraPABConfig()`.
Examples:
```python
from videosys import OpenSoraConfig, VideoSysEngine
# change num_gpus for multi-gpu inference
# sampling parameters are defined in the config
config = OpenSoraConfig(num_sampling_steps=30, cfg_scale=7.0, num_gpus=1)
engine = VideoSysEngine(config)
prompt = "Sunset over the sea."
# num frames: 2s, 4s, 8s, 16s
# resolution: 144p, 240p, 360p, 480p, 720p
# aspect ratio: 9:16, 16:9, 3:4, 4:3, 1:1
video = engine.generate(
prompt=prompt,
resolution="480p",
aspect_ratio="9:16",
num_frames="2s",
).video[0]
engine.save_video(video, f"./outputs/{prompt}.mp4")
```
"""
def __init__(
self,
transformer: str = "hpcai-tech/OpenSora-STDiT-v3",
vae: str = "hpcai-tech/OpenSora-VAE-v1.2",
text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
# ======== distributed ========
num_gpus: int = 1,
# ======== scheduler ========
num_sampling_steps: int = 30,
cfg_scale: float = 7.0,
# ======= memory =======
cpu_offload: bool = False,
# ======== vae ========
tiling_size: int = 4,
# ======== speedup ========
enable_flash_attn: bool = False,
# ======== pab ========
enable_pab: bool = False,
pab_config: PABConfig = OpenSoraPABConfig(),
):
self.pipeline_cls = OpenSoraPipeline
self.transformer = transformer
self.vae = vae
self.text_encoder = text_encoder
# ======== distributed ========
self.num_gpus = num_gpus
# ======== scheduler ========
self.num_sampling_steps = num_sampling_steps
self.cfg_scale = cfg_scale
# ======== vae ========
self.tiling_size = tiling_size
# ======= memory ========
self.cpu_offload = cpu_offload
# ======== speedup ========
self.enable_flash_attn = enable_flash_attn
# ======== pab ========
self.enable_pab = enable_pab
self.pab_config = pab_config
class OpenSoraPipeline(VideoSysPipeline):
r"""
    Pipeline for text-to-video generation using Open-Sora.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. PixArt-Alpha uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`STDiT3`]):
A text conditioned `STDiT3` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
bad_punct_regex = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = [
"text_encoder",
"tokenizer",
]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
self,
config: OpenSoraConfig,
text_encoder: Optional[T5EncoderModel] = None,
tokenizer: Optional[AutoTokenizer] = None,
vae: Optional[AutoencoderKL] = None,
transformer: Optional[STDiT3] = None,
scheduler: Optional[RFLOW] = None,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self._config = config
self._device = device
self._dtype = dtype
# initialize the model if not provided
if text_encoder is None:
text_encoder = T5EncoderModel.from_pretrained(config.text_encoder).to(dtype)
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(config.text_encoder)
if vae is None:
vae = OpenSoraVAE_V1_2(
from_pretrained=config.vae,
micro_frame_size=17,
micro_batch_size=config.tiling_size,
).to(dtype)
if transformer is None:
transformer = STDiT3.from_pretrained(config.transformer, enable_flash_attn=config.enable_flash_attn).to(
dtype
)
if scheduler is None:
scheduler = RFLOW(
use_timestep_transform=True, num_sampling_steps=config.num_sampling_steps, cfg_scale=config.cfg_scale
)
# pab
if config.enable_pab:
set_pab_manager(config.pab_config)
# set eval and device
self.set_eval_and_device(device, vae, transformer)
self.register_modules(
text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, tokenizer=tokenizer
)
# cpu offload
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(self._device, text_encoder)
def get_text_embeddings(self, texts):
text_tokens_and_mask = self.tokenizer(
texts,
max_length=300,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
device = self._execution_device
input_ids = text_tokens_and_mask["input_ids"].to(device)
attention_mask = text_tokens_and_mask["attention_mask"].to(device)
with torch.no_grad():
text_encoder_embs = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
)["last_hidden_state"].detach()
return text_encoder_embs, attention_mask
def encode_prompt(self, text):
caption_embs, emb_masks = self.get_text_embeddings(text)
caption_embs = caption_embs[:, None]
return dict(y=caption_embs, mask=emb_masks)
def null_embed(self, n):
null_y = self.transformer.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None].to(self._execution_device)
return null_y
@staticmethod
def _basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def _clean_caption(self, caption):
import urllib.parse as ul
from bs4 import BeautifulSoup
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
        # normalize quotation marks to a single standard
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
        # IP addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = self._basic_clean(caption)
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
        caption = caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
def text_preprocessing(self, text, use_text_preprocessing: bool = True):
if use_text_preprocessing:
# The exact text cleaning as was in the training stage:
text = self._clean_caption(text)
text = self._clean_caption(text)
return text
else:
return text.lower().strip()
@torch.no_grad()
def generate(
self,
prompt: str,
resolution="480p",
aspect_ratio="9:16",
num_frames: int = 51,
loop: int = 1,
llm_refine: bool = False,
negative_prompt: str = "",
ms: Optional[str] = "",
refs: Optional[str] = "",
aes: float = 6.5,
flow: Optional[float] = None,
camera_motion: Optional[float] = None,
condition_frame_length: int = 5,
align: int = 5,
condition_frame_edit: float = 0.0,
return_dict: bool = True,
verbose: bool = True,
) -> Union[VideoSysPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
                instead.
resolution (`str`, *optional*, defaults to `"480p"`):
The resolution of the generated video.
aspect_ratio (`str`, *optional*, defaults to `"9:16"`):
The aspect ratio of the generated video.
num_frames (`int`, *optional*, defaults to 51):
The number of frames to generate.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
Examples:
        Returns:
            [`VideoSysPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, a [`VideoSysPipelineOutput`] is returned, otherwise a `tuple` is
                returned whose first element is the generated video tensor.
"""
# == basic ==
fps = 24
image_size = get_image_size(resolution, aspect_ratio)
num_frames = get_num_frames(num_frames)
# == prepare batch prompts ==
batch_prompts = [prompt]
ms = [ms]
refs = [refs]
# == get json from prompts ==
batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)
# == get reference for condition ==
refs = collect_references_batch(refs, self.vae, image_size)
# == multi-resolution info ==
model_args = prepare_multi_resolution_info(
"OpenSora", len(batch_prompts), image_size, num_frames, fps, self._device, self._dtype
)
# == process prompts step by step ==
# 0. split prompt
# each element in the list is [prompt_segment_list, loop_idx_list]
batched_prompt_segment_list = []
batched_loop_idx_list = []
for prompt in batch_prompts:
prompt_segment_list, loop_idx_list = split_prompt(prompt)
batched_prompt_segment_list.append(prompt_segment_list)
batched_loop_idx_list.append(loop_idx_list)
# 1. refine prompt by openai
# if llm_refine:
# only call openai API when
# 1. seq parallel is not enabled
# 2. seq parallel is enabled and the process is rank 0
# if not enable_sequence_parallelism or (enable_sequence_parallelism and coordinator.is_master()):
# for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
# batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list)
# # sync the prompt if using seq parallel
# if enable_sequence_parallelism:
# coordinator.block_all()
# prompt_segment_length = [
# len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list
# ]
# # flatten the prompt segment list
# batched_prompt_segment_list = [
# prompt_segment
# for prompt_segment_list in batched_prompt_segment_list
# for prompt_segment in prompt_segment_list
# ]
# # create a list of size equal to world size
# broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size
# dist.broadcast_object_list(broadcast_obj_list, 0)
# # recover the prompt list
# batched_prompt_segment_list = []
# segment_start_idx = 0
# all_prompts = broadcast_obj_list[0]
# for num_segment in prompt_segment_length:
# batched_prompt_segment_list.append(
# all_prompts[segment_start_idx : segment_start_idx + num_segment]
# )
# segment_start_idx += num_segment
# 2. append score
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
batched_prompt_segment_list[idx] = append_score_to_prompts(
prompt_segment_list,
aes=aes,
flow=flow,
camera_motion=camera_motion,
)
# 3. clean prompt with T5
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
batched_prompt_segment_list[idx] = [self.text_preprocessing(prompt) for prompt in prompt_segment_list]
# 4. merge to obtain the final prompt
batch_prompts = []
for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list):
batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list))
# == Iter over loop generation ==
video_clips = []
for loop_i in range(loop):
# == get prompt for loop i ==
batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)
# == add condition frames for loop ==
if loop_i > 0:
refs, ms = append_generated(
self.vae, video_clips[-1], refs, ms, loop_i, condition_frame_length, condition_frame_edit
)
# == sampling ==
input_size = (num_frames, *image_size)
latent_size = self.vae.get_latent_size(input_size)
z = torch.randn(
len(batch_prompts), self.vae.out_channels, *latent_size, device=self._device, dtype=self._dtype
)
model_args.update(self.encode_prompt(batch_prompts_loop))
y_null = self.null_embed(len(batch_prompts_loop))
masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)
samples = self.scheduler.sample(
self.transformer,
z=z,
model_args=model_args,
y_null=y_null,
device=self._device,
progress=verbose,
mask=masks,
)
samples = self.vae.decode(samples.to(self._dtype), num_frames=num_frames)
video_clips.append(samples)
for i in range(1, loop):
video_clips[i] = video_clips[i][:, dframe_to_frame(condition_frame_length) :]
video = torch.cat(video_clips, dim=1)
low, high = -1, 1
video.clamp_(min=low, max=high)
video.sub_(low).div_(max(high - low, 1e-5))
video = video.mul(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 4, 1).to("cpu", torch.uint8)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return VideoSysPipelineOutput(video=video)
def save_video(self, video, output_path):
save_video(video, output_path, fps=24)
def load_prompts(prompt_path, start_idx=None, end_idx=None):
with open(prompt_path, "r") as f:
prompts = [line.strip() for line in f.readlines()]
prompts = prompts[start_idx:end_idx]
return prompts
def get_save_path_name(
save_dir,
sample_name=None, # prefix
sample_idx=None, # sample index
prompt=None, # used prompt
prompt_as_path=False, # use prompt as path
num_sample=1, # number of samples to generate for one prompt
k=None, # kth sample
):
if sample_name is None:
sample_name = "" if prompt_as_path else "sample"
sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx:04d}"
save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix[:50]}")
if num_sample != 1:
save_path = f"{save_path}-{k}"
return save_path
def get_eval_save_path_name(
save_dir,
id, # add id parameter
sample_name=None, # prefix
sample_idx=None, # sample index
prompt=None, # used prompt
prompt_as_path=False, # use prompt as path
num_sample=1, # number of samples to generate for one prompt
k=None, # kth sample
):
if sample_name is None:
sample_name = "" if prompt_as_path else "sample"
save_path = os.path.join(save_dir, f"{id}")
if num_sample != 1:
save_path = f"{save_path}-{k}"
return save_path
def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None):
new_prompts = []
for prompt in prompts:
new_prompt = prompt
if aes is not None and "aesthetic score:" not in prompt:
new_prompt = f"{new_prompt} aesthetic score: {aes:.1f}."
if flow is not None and "motion score:" not in prompt:
new_prompt = f"{new_prompt} motion score: {flow:.1f}."
if camera_motion is not None and "camera motion:" not in prompt:
new_prompt = f"{new_prompt} camera motion: {camera_motion}."
new_prompts.append(new_prompt)
return new_prompts
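# Example (illustrative values): scores are appended only when provided and not already
# present in the prompt text.
#
#   >>> append_score_to_prompts(["a cat"], aes=6.5, flow=0.8, camera_motion=None)
#   ['a cat aesthetic score: 6.5. motion score: 0.8.']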
def extract_json_from_prompts(prompts, reference, mask_strategy):
ret_prompts = []
for i, prompt in enumerate(prompts):
parts = re.split(r"(?=[{])", prompt)
assert len(parts) <= 2, f"Invalid prompt: {prompt}"
ret_prompts.append(parts[0])
if len(parts) > 1:
additional_info = json.loads(parts[1])
for key in additional_info:
assert key in ["reference_path", "mask_strategy"], f"Invalid key: {key}"
if key == "reference_path":
reference[i] = additional_info[key]
elif key == "mask_strategy":
mask_strategy[i] = additional_info[key]
return ret_prompts, reference, mask_strategy
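# Example (hypothetical prompt and reference path): per-prompt overrides can be embedded as a
# JSON object at the end of the prompt; it is split off and routed into the reference and
# mask-strategy lists.
#
#   >>> prompts = ['a cat on a bench {"reference_path": "cat.png", "mask_strategy": "0"}']
#   >>> extract_json_from_prompts(prompts, [""], [""])
#   (['a cat on a bench '], ['cat.png'], ['0'])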
def collect_references_batch(reference_paths, vae, image_size):
refs_x = [] # refs_x: [batch, ref_num, C, T, H, W]
for reference_path in reference_paths:
if reference_path == "":
refs_x.append([])
continue
ref_path = reference_path.split(";")
ref = []
for r_path in ref_path:
r = read_from_path(r_path, image_size, transform_name="resize_crop")
r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
r_x = r_x.squeeze(0)
ref.append(r_x)
refs_x.append(ref)
return refs_x
def extract_prompts_loop(prompts, num_loop):
ret_prompts = []
for prompt in prompts:
if prompt.startswith("|0|"):
prompt_list = prompt.split("|")[1:]
text_list = []
for i in range(0, len(prompt_list), 2):
start_loop = int(prompt_list[i])
text = prompt_list[i + 1]
end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop + 1
text_list.extend([text] * (end_loop - start_loop))
prompt = text_list[num_loop]
ret_prompts.append(prompt)
return ret_prompts
def split_prompt(prompt_text):
if prompt_text.startswith("|0|"):
# this is for prompts which look like
# |0| a beautiful day |1| a sunny day |2| a rainy day
# we want to parse it into a list of prompts with the loop index
prompt_list = prompt_text.split("|")[1:]
text_list = []
loop_idx = []
for i in range(0, len(prompt_list), 2):
start_loop = int(prompt_list[i])
text = prompt_list[i + 1].strip()
text_list.append(text)
loop_idx.append(start_loop)
return text_list, loop_idx
else:
return [prompt_text], None
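# Worked example of the loop-indexed prompt format described above (illustrative text):
#
#   >>> split_prompt("|0| a sunny day |2| a rainy day")
#   (['a sunny day', 'a rainy day'], [0, 2])
#
# merge_prompt below performs the inverse, re-joining the segments as
# "|0|a sunny day|2|a rainy day".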
def merge_prompt(text_list, loop_idx_list=None):
if loop_idx_list is None:
return text_list[0]
else:
prompt = ""
for i, text in enumerate(text_list):
prompt += f"|{loop_idx_list[i]}|{text}"
return prompt
MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]
def parse_mask_strategy(mask_strategy):
mask_batch = []
if mask_strategy == "" or mask_strategy is None:
return mask_batch
mask_strategy = mask_strategy.split(";")
for mask in mask_strategy:
mask_group = mask.split(",")
num_group = len(mask_group)
assert num_group >= 1 and num_group <= 6, f"Invalid mask strategy: {mask}"
mask_group.extend(MASK_DEFAULT[num_group:])
for i in range(5):
mask_group[i] = int(mask_group[i])
mask_group[5] = float(mask_group[5])
mask_batch.append(mask_group)
return mask_batch
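# Example (assumed values): each ";"-separated group is
# "loop_id,ref_id,ref_start,target_start,length,edit_ratio" (see apply_mask_strategy below);
# missing trailing fields are filled in from MASK_DEFAULT.
#
#   >>> parse_mask_strategy("0,0,0,0,25,0.5;1,0,-8,0,8")
#   [[0, 0, 0, 0, 25, 0.5], [1, 0, -8, 0, 8, 0.0]]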
def find_nearest_point(value, point, max_value):
t = value // point
if value % point > point / 2 and t < max_value // point - 1:
t += 1
return t * point
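# Illustrative behaviour: values are snapped to the nearest multiple of `point`, but never
# rounded up past the last full multiple below `max_value`.
#
#   >>> find_nearest_point(13, 5, 30)   # 13 is closer to 15
#   15
#   >>> find_nearest_point(11, 5, 30)   # 11 is closer to 10
#   10
#   >>> find_nearest_point(29, 5, 30)   # capped: not rounded up to 30
#   25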
def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
masks = []
no_mask = True
for i, mask_strategy in enumerate(mask_strategys):
no_mask = False
mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
mask_strategy = parse_mask_strategy(mask_strategy)
for mst in mask_strategy:
loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst
if loop_id != loop_i:
continue
ref = refs_x[i][m_id]
if m_ref_start < 0:
# ref: [C, T, H, W]
m_ref_start = ref.shape[1] + m_ref_start
if m_target_start < 0:
# z: [B, C, T, H, W]
m_target_start = z.shape[2] + m_target_start
if align is not None:
m_ref_start = find_nearest_point(m_ref_start, align, ref.shape[1])
m_target_start = find_nearest_point(m_target_start, align, z.shape[2])
m_length = min(m_length, z.shape[2] - m_target_start, ref.shape[1] - m_ref_start)
z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
mask[m_target_start : m_target_start + m_length] = edit_ratio
masks.append(mask)
if no_mask:
return None
masks = torch.stack(masks)
return masks
def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i, condition_frame_length, condition_frame_edit):
ref_x = vae.encode(generated_video)
for j, refs in enumerate(refs_x):
if refs is None:
refs_x[j] = [ref_x[j]]
else:
refs.append(ref_x[j])
if mask_strategy[j] is None or mask_strategy[j] == "":
mask_strategy[j] = ""
else:
mask_strategy[j] += ";"
        mask_strategy[j] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}"
return refs_x, mask_strategy
def dframe_to_frame(num):
assert num % 5 == 0, f"Invalid num: {num}"
return num // 5 * 17
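# Example (interpretation hedged): condition_frame_length is counted in blocks of 5, and each
# block expands to 17 frames, which appears to mirror the VAE's micro_frame_size=17 configured
# in OpenSoraPipeline above.
#
#   >>> dframe_to_frame(5)
#   17
#   >>> dframe_to_frame(10)
#   34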
OPENAI_CLIENT = None
REFINE_PROMPTS = None
REFINE_PROMPTS_PATH = "assets/texts/t2v_pllava.txt"
REFINE_PROMPTS_TEMPLATE = """
You need to refine user's input prompt. The user's input prompt is used for video generation task. You need to refine the user's prompt to make it more suitable for the task. Here are some examples of refined prompts:
{}
The refined prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The refined prompt should be in English.
"""
RANDOM_PROMPTS = None
RANDOM_PROMPTS_TEMPLATE = """
You need to generate one input prompt for video generation task. The prompt should be suitable for the task. Here are some examples of refined prompts:
{}
The prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The prompt should be in English.
"""
def get_openai_response(sys_prompt, usr_prompt, model="gpt-4o"):
global OPENAI_CLIENT
if OPENAI_CLIENT is None:
from openai import OpenAI
OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
completion = OPENAI_CLIENT.chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": sys_prompt,
}, # <-- This is the system message that provides context to the model
{
"role": "user",
"content": usr_prompt,
}, # <-- This is the user message for which the model will generate a response
],
)
return completion.choices[0].message.content
def get_random_prompt_by_openai():
global RANDOM_PROMPTS
if RANDOM_PROMPTS is None:
examples = load_prompts(REFINE_PROMPTS_PATH)
RANDOM_PROMPTS = RANDOM_PROMPTS_TEMPLATE.format("\n".join(examples))
response = get_openai_response(RANDOM_PROMPTS, "Generate one example.")
return response
def refine_prompt_by_openai(prompt):
global REFINE_PROMPTS
if REFINE_PROMPTS is None:
examples = load_prompts(REFINE_PROMPTS_PATH)
REFINE_PROMPTS = REFINE_PROMPTS_TEMPLATE.format("\n".join(examples))
response = get_openai_response(REFINE_PROMPTS, prompt)
return response
def has_openai_key():
return "OPENAI_API_KEY" in os.environ
def refine_prompts_by_openai(prompts):
new_prompts = []
for prompt in prompts:
try:
if prompt.strip() == "":
new_prompt = get_random_prompt_by_openai()
print(f"[Info] Empty prompt detected, generate random prompt: {new_prompt}")
else:
new_prompt = refine_prompt_by_openai(prompt)
print(f"[Info] Refine prompt: {prompt} -> {new_prompt}")
new_prompts.append(new_prompt)
except Exception as e:
print(f"[Warning] Failed to refine prompt: {prompt} due to {e}")
new_prompts.append(prompt)
return new_prompts
def add_watermark(
input_video_path, watermark_image_path="./assets/images/watermark/watermark.png", output_video_path=None
):
# execute this command in terminal with subprocess
# return if the process is successful
if output_video_path is None:
output_video_path = input_video_path.replace(".mp4", "_watermark.mp4")
cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path} -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}'
exit_code = os.system(cmd)
is_success = exit_code == 0
return is_success

View File

@ -0,0 +1,3 @@
from .pipeline_open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanPABConfig"]

View File

@ -0,0 +1,915 @@
# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import html
import inspect
import math
import re
import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union
import ftfy
import torch
import torch.distributed as dist
import tqdm
from bs4 import BeautifulSoup
from diffusers.models import AutoencoderKL, Transformer2DModel
from diffusers.schedulers import PNDMScheduler
from diffusers.utils.torch_utils import randn_tensor
from transformers import T5EncoderModel, T5Tokenizer
from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
from ...models.autoencoders.autoencoder_kl_open_sora_plan import ae_stride_config, getae_wrapper
from ...models.transformers.open_sora_plan_transformer_3d import LatteT2V
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import PixArtAlphaPipeline
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
>>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A small cactus with a happy face in the Sahara desert."
>>> image = pipe(prompt).images[0]
```
"""
class OpenSoraPlanPABConfig(PABConfig):
def __init__(
self,
steps: int = 150,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
temporal_broadcast: bool = True,
temporal_threshold: list = [100, 850],
temporal_range: int = 4,
cross_broadcast: bool = True,
cross_threshold: list = [100, 850],
cross_range: int = 6,
mlp_broadcast: bool = True,
mlp_spatial_broadcast_config: dict = {
738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
690: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
666: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
642: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
618: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
594: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
570: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
546: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
522: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
498: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
474: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
450: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
426: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
},
mlp_temporal_broadcast_config: dict = {
738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
690: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
666: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
642: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
618: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
594: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
570: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
546: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
522: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
498: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
474: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
450: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
426: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
},
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=temporal_threshold,
temporal_range=temporal_range,
cross_broadcast=cross_broadcast,
cross_threshold=cross_threshold,
cross_range=cross_range,
mlp_broadcast=mlp_broadcast,
mlp_spatial_broadcast_config=mlp_spatial_broadcast_config,
mlp_temporal_broadcast_config=mlp_temporal_broadcast_config,
)
class OpenSoraPlanConfig:
"""
    This config instantiates an `OpenSoraPlanPipeline` for video generation.
    It is passed to the engine via `VideoSysEngine(config)`, which uses it to build the
    corresponding pipeline class and then calls the pipeline's `generate` function to
    produce the video. For the details of generation, please refer to the pipeline class below.
Args:
transformer (str):
The transformer model to use. Defaults to "LanguageBind/Open-Sora-Plan-v1.1.0".
ae (str):
The Autoencoder model to use. Defaults to "CausalVAEModel_4x8x8".
text_encoder (str):
The text encoder model to use. Defaults to "DeepFloyd/t5-v1_1-xxl".
num_frames (int):
The number of frames to generate. Must be one of [65, 221].
num_gpus (int):
The number of GPUs to use. Defaults to 1.
enable_tiling (bool):
Whether to enable tiling. Defaults to True.
tile_overlap_factor (float):
The overlap factor for tiling. Defaults to 0.25.
enable_pab (bool):
Whether to enable Pyramid Attention Broadcast. Defaults to False.
        pab_config (OpenSoraPlanPABConfig):
            The configuration for Pyramid Attention Broadcast. Defaults to `OpenSoraPlanPABConfig()`.
Examples:
```python
from videosys import OpenSoraPlanConfig, VideoSysEngine
# num frames: 65 or 221
# change num_gpus for multi-gpu inference
config = OpenSoraPlanConfig(num_frames=65, num_gpus=1)
engine = VideoSysEngine(config)
prompt = "Sunset over the sea."
video = engine.generate(
prompt=prompt,
guidance_scale=7.5,
num_inference_steps=150,
).video[0]
engine.save_video(video, f"./outputs/{prompt}.mp4")
```
"""
def __init__(
self,
transformer: str = "LanguageBind/Open-Sora-Plan-v1.1.0",
ae: str = "CausalVAEModel_4x8x8",
text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
num_frames: int = 65,
# ======= distributed ========
num_gpus: int = 1,
# ======= memory =======
cpu_offload: bool = False,
enable_tiling: bool = True,
tile_overlap_factor: float = 0.25,
# ======= pab ========
enable_pab: bool = False,
pab_config: PABConfig = OpenSoraPlanPABConfig(),
):
self.pipeline_cls = OpenSoraPlanPipeline
self.ae = ae
self.text_encoder = text_encoder
self.transformer = transformer
assert num_frames in [65, 221], "num_frames must be one of [65, 221]"
self.num_frames = num_frames
self.version = f"{num_frames}x512x512"
# ======= distributed ========
self.num_gpus = num_gpus
# ======= memory ========
self.cpu_offload = cpu_offload
self.enable_tiling = enable_tiling
self.tile_overlap_factor = tile_overlap_factor
# ======= pab ========
self.enable_pab = enable_pab
self.pab_config = pab_config
class OpenSoraPlanPipeline(VideoSysPipeline):
r"""
    Pipeline for text-to-video generation using Open-Sora-Plan.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. PixArt-Alpha uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`Transformer2DModel`]):
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
bad_punct_regex = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
self,
config: OpenSoraPlanConfig,
tokenizer: Optional[T5Tokenizer] = None,
text_encoder: Optional[T5EncoderModel] = None,
vae: Optional[AutoencoderKL] = None,
transformer: Optional[Transformer2DModel] = None,
scheduler: Optional[PNDMScheduler] = None,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.float16,
):
super().__init__()
self._config = config
# init
if tokenizer is None:
tokenizer = T5Tokenizer.from_pretrained(config.text_encoder)
if text_encoder is None:
text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=torch.float16)
if vae is None:
vae = getae_wrapper(config.ae)(config.transformer, subfolder="vae").to(dtype=dtype)
if transformer is None:
transformer = LatteT2V.from_pretrained(config.transformer, subfolder=config.version, torch_dtype=dtype)
if scheduler is None:
scheduler = PNDMScheduler()
# setting
if config.enable_tiling:
vae.vae.enable_tiling()
vae.vae.tile_overlap_factor = config.tile_overlap_factor
vae.vae_scale_factor = ae_stride_config[config.ae]
transformer.force_images = False
# set eval and device
self.set_eval_and_device(device, vae, transformer)
# pab
if config.enable_pab:
set_pab_manager(config.pab_config)
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
# cpu offload
if config.cpu_offload:
self.enable_model_cpu_offload()
else:
self.set_eval_and_device(device, text_encoder)
# self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096
else:
masked_feature = emb * mask[:, None, :, None] # 1 120 4096
return masked_feature, emb.shape[2]
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
mask_feature: bool = True,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
PixArt-Alpha, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Alpha, it should be the embeddings of the ""
string.
clean_caption (bool, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
"""
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# See Section 3.1. of the paper.
max_length = 300
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because the model can only handle sequences up to"
f" {max_length} tokens: {removed_text}"
)
attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds_attention_mask = attention_mask
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
else:
negative_prompt_embeds = None
# print(prompt_embeds.shape) # 1 120 4096
# print(negative_prompt_embeds.shape) # 1 120 4096
# Perform additional masking.
if mask_feature and not embeds_initially_provided:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
masked_negative_prompt_embeds = (
negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
)
# import torch.nn.functional as F
            # padding = (0, 0, 0, 113)  # (left, right, bottom, top)
# masked_prompt_embeds_ = F.pad(masked_prompt_embeds, padding, "constant", 0)
# masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0)
# print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...])
return masked_prompt_embeds, masked_negative_prompt_embeds
# return masked_prompt_embeds_, masked_negative_prompt_embeds_
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
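                # _clean_caption is applied twice, matching the upstream deepfloyd-IF preprocessing
                # this method is copied from (see the `# Copied from` note above).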
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
# noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
# noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
        # normalize all kinds of dashes to "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
# noqa
"-",
caption,
)
        # normalize quotation marks to a single standard
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
        # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
        caption = caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
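        # Latent shape: (batch, channels, frames', height', width'), where the temporal size is
        # num_frames reduced by the VAE's temporal scale factor (with one extra frame kept when
        # num_frames is odd) and the spatial sizes are divided by the spatial scale factors.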
shape = (
batch_size,
num_channels_latents,
(math.ceil((int(num_frames) - 1) / self.vae.vae_scale_factor[0]) + 1)
if int(num_frames) % 2 == 1
else math.ceil(int(num_frames) / self.vae.vae_scale_factor[0]),
math.ceil(int(height) / self.vae.vae_scale_factor[1]),
math.ceil(int(width) / self.vae.vae_scale_factor[2]),
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def generate(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 150,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
mask_feature: bool = True,
enable_temporal_attentions: bool = True,
verbose: bool = True,
) -> Union[VideoSysPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass
                `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
            num_inference_steps (`int`, *optional*, defaults to 150):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`VideoSysPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
Examples:
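            A minimal usage sketch (assumptions: `pipe` is an already-constructed instance of this
            pipeline, and `VideoSysPipelineOutput` exposes its `video` field as an attribute):

            ```py
            >>> output = pipe.generate(prompt="a corgi surfing a wave at sunset")
            >>> video = output.video  # uint8 tensor of shape (b, t, h, w, c)
            >>> pipe.save_video(video, "./sample.mp4")
            ```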
Returns:
            [`VideoSysPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, a [`VideoSysPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the generated video.
"""
# 1. Check inputs. Raise error if not correct
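        # The spatial resolution is fixed to 512x512 here; the number of frames comes from the
        # pipeline config rather than from the call arguments.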
height = 512
width = 512
num_frames = self._config.num_frames
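        # NOTE (assumption): update_steps is imported elsewhere in this file; it appears to register
        # the total number of denoising steps (e.g. for step-dependent caching), but that is not shown here.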
update_steps(num_inference_steps)
self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clean_caption=clean_caption,
mask_feature=mask_feature,
)
if do_classifier_free_guidance:
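            # Order matters: unconditional embeddings first, text embeddings second, so that
            # noise_pred.chunk(2) in the denoising loop yields (uncond, text) in that order.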
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
# if self.transformer.config.sample_size == 128:
# resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
# aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
# resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
# aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
# added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
for i, t in progress_wrap(list(enumerate(timesteps))):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
# predict noise model_output
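            # Unlike the stock diffusers pipelines, the full timestep schedule (all_timesteps) is also
            # passed to the transformer; presumably for the step-aware caching used in this repo (assumption).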
noise_pred = self.transformer(
latent_model_input,
all_timesteps=timesteps,
encoder_hidden_states=prompt_embeds,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
enable_temporal_attentions=enable_temporal_attentions,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
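                # Classifier-free guidance: eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)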
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
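            # When the transformer also predicts a (learned) sigma, its output has twice the latent
            # channels; keep only the first half (the noise prediction) for the scheduler step.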
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latents":
video = self.decode_latents(latents)
video = video[:, :num_frames, :height, :width]
else:
video = latents
return VideoSysPipelineOutput(video=video)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return VideoSysPipelineOutput(video=video)
def decode_latents(self, latents):
video = self.vae.decode(latents) # b t c h w
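        # Map decoded frames from [-1, 1] to uint8 in [0, 255] and permute to (b, t, h, w, c).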
# b t c h w -> b t h w c
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
return video
def save_video(self, video, output_path):
save_video(video, output_path, fps=24)
