mirror of
https://git.datalinker.icu/ali-vilab/TeaCache
synced 2026-01-23 10:24:28 +08:00
update
This commit is contained in:
parent
7780b1aea3
commit
3d8b01f91b
168
.gitignore
vendored
168
.gitignore
vendored
@ -1,168 +0,0 @@
|
||||
outputs/
|
||||
processed/
|
||||
profile/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/.build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
# macos
|
||||
*.DS_Store
|
||||
#data/
|
||||
|
||||
docs/.build
|
||||
|
||||
# pytorch checkpoint
|
||||
*.pt
|
||||
|
||||
# ignore any kernel build files
|
||||
.o
|
||||
.so
|
||||
|
||||
# ignore python interface defition file
|
||||
.pyi
|
||||
|
||||
# ignore coverage test file
|
||||
coverage.lcov
|
||||
coverage.xml
|
||||
|
||||
# ignore testmon and coverage files
|
||||
.coverage
|
||||
.testmondata*
|
||||
|
||||
pretrained
|
||||
samples
|
||||
cache_dir
|
||||
test_outputs
|
||||
datasets
|
||||
@ -1,7 +0,0 @@
|
||||
[settings]
|
||||
line_length = 120
|
||||
multi_line_output=3
|
||||
include_trailing_comma = true
|
||||
ignore_comments = true
|
||||
profile = black
|
||||
honor_noqa = true
|
||||
@ -1,39 +0,0 @@
|
||||
repos:
|
||||
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: v2.2.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
name: autoflake (python)
|
||||
args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
name: sort all imports (python)
|
||||
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 23.9.1
|
||||
hooks:
|
||||
- id: black
|
||||
name: black formatter
|
||||
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v13.0.1
|
||||
hooks:
|
||||
- id: clang-format
|
||||
name: clang formatter
|
||||
types_or: [c++, c]
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: check-merge-conflict
|
||||
- id: check-case-conflict
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: mixed-line-ending
|
||||
args: ['--fix=lf']
|
||||
@ -1,37 +0,0 @@
|
||||
## Coding Standards
|
||||
|
||||
### Unit Tests
|
||||
We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests.
|
||||
|
||||
To set up the environment for unit testing, first change your current directory to the root directory of your local ColossalAI repository, then run
|
||||
```bash
|
||||
pip install -r requirements/requirements-test.txt
|
||||
```
|
||||
If you encounter an error telling "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your python version to 3.8 or 3.9 and try again.
|
||||
|
||||
If you only want to run CPU tests, you can run
|
||||
|
||||
```bash
|
||||
pytest -m cpu tests/
|
||||
```
|
||||
|
||||
If you have 8 GPUs on your machine, you can run the full test
|
||||
|
||||
```bash
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
If you do not have 8 GPUs on your machine, do not worry. Unit testing will be automatically conducted when you put up a pull request to the main branch.
|
||||
|
||||
|
||||
### Code Style
|
||||
|
||||
We have some static checks when you commit your code change, please make sure you can pass all the tests and make sure the coding style meets our requirements. We use pre-commit hook to make sure the code is aligned with the writing standard. To set up the code style checking, you need to follow the steps below.
|
||||
|
||||
```shell
|
||||
# these commands are executed under the Colossal-AI directory
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Code format checking will be automatically executed when you commit your changes.
|
||||
134
README.md
134
README.md
@ -1,134 +0,0 @@
|
||||
# Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model
|
||||
|
||||
<div class="is-size-5 publication-authors", align="center",>
|
||||
<span class="author-block">
|
||||
<a href="https://liewfeng.github.io" target="_blank">Feng Liu</a><sup>1</sup><sup>*</sup>,
|
||||
</span>
|
||||
<span class="author-block">
|
||||
<a href="https://scholar.google.com.hk/citations?user=ZO3OQ-8AAAAJ" target="_blank">Shiwei Zhang</a><sup>2</sup>,
|
||||
</span>
|
||||
<span class="author-block">
|
||||
<a href="https://jeffwang987.github.io" target="_blank">Xiaofeng Wang</a><sup>1,3</sup>,
|
||||
</span>
|
||||
<span class="author-block">
|
||||
<a href="https://weilllllls.github.io" target="_blank">Yujie Wei</a><sup>4</sup>,
|
||||
</span>
|
||||
<span class="author-block">
|
||||
<a href="http://haonanqiu.com" target="_blank">Haonan Qiu</a><sup>5</sup>
|
||||
</span>
|
||||
<br>
|
||||
<span class="author-block">
|
||||
<a href="https://callsys.github.io/zhaoyuzhong.github.io-main" target="_blank">Yuzhong Zhao</a><sup>1</sup>,
|
||||
</span>
|
||||
<span class="author-block">
|
||||
<a href="https://scholar.google.com.sg/citations?user=16RDSEUAAAAJ" target="_blank">Yingya Zhang</a><sup>2</sup>,
|
||||
</span>
|
||||
<span class="author-block">
|
||||
<a href="https://scholar.google.com/citations?user=tjEfgsEAAAAJ&hl=en&oi=ao" target="_blank">Qixiang Ye</a><sup>1</sup>,
|
||||
</span>
|
||||
<span class="author-block">
|
||||
<a href="https://scholar.google.com/citations?user=0IKavloAAAAJ&hl=en&oi=ao" target="_blank">Fang Wan</a><sup>1</sup><sup>†</sup>
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="is-size-5 publication-authors", align="center">
|
||||
<span class="author-block"><sup>1</sup>University of Chinese Academy of Sciences, </span>
|
||||
<span class="author-block"><sup>2</sup>Alibaba Group</span>
|
||||
<br>
|
||||
<span class="author-block"><sup>3</sup>Institute of Automation, Chinese Academy of Sciences</span>
|
||||
<br>
|
||||
<span class="author-block"><sup>4</sup>Fudan University, </span>
|
||||
<span class="author-block"><sup>5</sup>Nanyang Technological University</span>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="is-size-5 publication-authors", align="center">
|
||||
(* Work was done during internship at Alibaba Group. † Corresponding author.)
|
||||
</div>
|
||||
|
||||
<div class="is-size-5 publication-authors", align="center">
|
||||
<a href="https://arxiv.org/abs/2411.19108">Paper</a> |
|
||||
<a href="https://github.com/LiewFeng/TeaCache/">Project Page</a>
|
||||
</div>
|
||||
|
||||

|
||||
|
||||
## Introduction
|
||||
We introduce Timestep Embedding Aware Cache (TeaCache), a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps. For more details and visual results, please visit our [project page](https://github.com/LiewFeng/TeaCache).
|
||||
|
||||
## Installation
|
||||
|
||||
Prerequisites:
|
||||
|
||||
- Python >= 3.10
|
||||
- PyTorch >= 1.13 (We recommend to use a >2.0 version)
|
||||
- CUDA >= 11.6
|
||||
|
||||
We strongly recommend using Anaconda to create a new environment (Python >= 3.10) to run our examples:
|
||||
|
||||
```shell
|
||||
conda create -n teacache python=3.10 -y
|
||||
conda activate teacache
|
||||
```
|
||||
|
||||
Install VideoSys:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/LiewFeng/TeaCache
|
||||
cd TeaCache
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|
||||
## Evaluation of TeaCache
|
||||
|
||||
We first generate videos according to VBench's prompts.
|
||||
|
||||
And then calculate Vbench, PSNR, LPIPS and SSIM based on the video generated.
|
||||
|
||||
1. Generate video
|
||||
```
|
||||
cd eval/teacache
|
||||
python experiments/latte.py
|
||||
python experiments/opensora.py
|
||||
python experiments/open_sora_plan.py
|
||||
```
|
||||
|
||||
2. Calculate Vbench score
|
||||
```
|
||||
# vbench is calculated independently
|
||||
# get scores for all metrics
|
||||
python vbench/run_vbench.py --video_path aaa --save_path bbb
|
||||
# calculate final score
|
||||
python vbench/cal_vbench.py --score_dir bbb
|
||||
```
|
||||
|
||||
3. Calculate other metrics
|
||||
```
|
||||
# these metrics are calculated compared with original model
|
||||
# gt video is the video of original model
|
||||
# generated video is our methods's results
|
||||
python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
@misc{liu2024timestep,
|
||||
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
|
||||
author={Feng Liu and Shiwei Zhang and Xiaofeng Wang and Yujie Wei and Haonan Qiu and Yuzhong Zhao and Yingya Zhang and Qixiang Ye and Fang Wan},
|
||||
year={2024},
|
||||
eprint={2411.19108},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV},
|
||||
url={https://arxiv.org/abs/2411.19108}
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys). Thanks for their contributions!
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 6.4 MiB |
@ -1,6 +0,0 @@
|
||||
Common metrics
|
||||
|
||||
Include LPIPS, PSNR and SSIM.
|
||||
|
||||
The code is adapted from [common_metrics_on_video_quality
|
||||
](https://github.com/JunyaoHu/common_metrics_on_video_quality).
|
||||
@ -1,97 +0,0 @@
|
||||
import lpips
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
spatial = True # Return a spatial map of perceptual distance.
|
||||
|
||||
# Linearly calibrated models (LPIPS)
|
||||
loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
|
||||
# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
|
||||
|
||||
|
||||
def trans(x):
|
||||
# if greyscale images add channel
|
||||
if x.shape[-3] == 1:
|
||||
x = x.repeat(1, 1, 3, 1, 1)
|
||||
|
||||
# value range [0, 1] -> [-1, 1]
|
||||
x = x * 2 - 1
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def calculate_lpips(videos1, videos2, device):
|
||||
# image should be RGB, IMPORTANT: normalized to [-1,1]
|
||||
|
||||
assert videos1.shape == videos2.shape
|
||||
|
||||
# videos [batch_size, timestamps, channel, h, w]
|
||||
|
||||
# support grayscale input, if grayscale -> channel*3
|
||||
# value range [0, 1] -> [-1, 1]
|
||||
videos1 = trans(videos1)
|
||||
videos2 = trans(videos2)
|
||||
|
||||
lpips_results = []
|
||||
|
||||
for video_num in range(videos1.shape[0]):
|
||||
# get a video
|
||||
# video [timestamps, channel, h, w]
|
||||
video1 = videos1[video_num]
|
||||
video2 = videos2[video_num]
|
||||
|
||||
lpips_results_of_a_video = []
|
||||
for clip_timestamp in range(len(video1)):
|
||||
# get a img
|
||||
# img [timestamps[x], channel, h, w]
|
||||
# img [channel, h, w] tensor
|
||||
|
||||
img1 = video1[clip_timestamp].unsqueeze(0).to(device)
|
||||
img2 = video2[clip_timestamp].unsqueeze(0).to(device)
|
||||
|
||||
loss_fn.to(device)
|
||||
|
||||
# calculate lpips of a video
|
||||
lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
|
||||
lpips_results.append(lpips_results_of_a_video)
|
||||
|
||||
lpips_results = np.array(lpips_results)
|
||||
|
||||
lpips = {}
|
||||
lpips_std = {}
|
||||
|
||||
for clip_timestamp in range(len(video1)):
|
||||
lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
|
||||
lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])
|
||||
|
||||
result = {
|
||||
"value": lpips,
|
||||
"value_std": lpips_std,
|
||||
"video_setting": video1.shape,
|
||||
"video_setting_name": "time, channel, heigth, width",
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# test code / using example
|
||||
|
||||
|
||||
def main():
|
||||
NUMBER_OF_VIDEOS = 8
|
||||
VIDEO_LENGTH = 50
|
||||
CHANNEL = 3
|
||||
SIZE = 64
|
||||
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
device = torch.device("cuda")
|
||||
# device = torch.device("cpu")
|
||||
|
||||
import json
|
||||
|
||||
result = calculate_lpips(videos1, videos2, device)
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,90 +0,0 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def img_psnr(img1, img2):
|
||||
# [0,1]
|
||||
# compute mse
|
||||
# mse = np.mean((img1-img2)**2)
|
||||
mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
|
||||
# compute psnr
|
||||
if mse < 1e-10:
|
||||
return 100
|
||||
psnr = 20 * math.log10(1 / math.sqrt(mse))
|
||||
return psnr
|
||||
|
||||
|
||||
def trans(x):
|
||||
return x
|
||||
|
||||
|
||||
def calculate_psnr(videos1, videos2):
|
||||
# videos [batch_size, timestamps, channel, h, w]
|
||||
|
||||
assert videos1.shape == videos2.shape
|
||||
|
||||
videos1 = trans(videos1)
|
||||
videos2 = trans(videos2)
|
||||
|
||||
psnr_results = []
|
||||
|
||||
for video_num in range(videos1.shape[0]):
|
||||
# get a video
|
||||
# video [timestamps, channel, h, w]
|
||||
video1 = videos1[video_num]
|
||||
video2 = videos2[video_num]
|
||||
|
||||
psnr_results_of_a_video = []
|
||||
for clip_timestamp in range(len(video1)):
|
||||
# get a img
|
||||
# img [timestamps[x], channel, h, w]
|
||||
# img [channel, h, w] numpy
|
||||
|
||||
img1 = video1[clip_timestamp].numpy()
|
||||
img2 = video2[clip_timestamp].numpy()
|
||||
|
||||
# calculate psnr of a video
|
||||
psnr_results_of_a_video.append(img_psnr(img1, img2))
|
||||
|
||||
psnr_results.append(psnr_results_of_a_video)
|
||||
|
||||
psnr_results = np.array(psnr_results)
|
||||
|
||||
psnr = {}
|
||||
psnr_std = {}
|
||||
|
||||
for clip_timestamp in range(len(video1)):
|
||||
psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
|
||||
psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])
|
||||
|
||||
result = {
|
||||
"value": psnr,
|
||||
"value_std": psnr_std,
|
||||
"video_setting": video1.shape,
|
||||
"video_setting_name": "time, channel, heigth, width",
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# test code / using example
|
||||
|
||||
|
||||
def main():
|
||||
NUMBER_OF_VIDEOS = 8
|
||||
VIDEO_LENGTH = 50
|
||||
CHANNEL = 3
|
||||
SIZE = 64
|
||||
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
|
||||
import json
|
||||
|
||||
result = calculate_psnr(videos1, videos2)
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,116 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def ssim(img1, img2):
|
||||
C1 = 0.01**2
|
||||
C2 = 0.03**2
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1**2
|
||||
mu2_sq = mu2**2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
def calculate_ssim_function(img1, img2):
|
||||
# [0,1]
|
||||
# ssim is the only metric extremely sensitive to gray being compared to b/w
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError("Input images must have the same dimensions.")
|
||||
if img1.ndim == 2:
|
||||
return ssim(img1, img2)
|
||||
elif img1.ndim == 3:
|
||||
if img1.shape[0] == 3:
|
||||
ssims = []
|
||||
for i in range(3):
|
||||
ssims.append(ssim(img1[i], img2[i]))
|
||||
return np.array(ssims).mean()
|
||||
elif img1.shape[0] == 1:
|
||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||
else:
|
||||
raise ValueError("Wrong input image dimensions.")
|
||||
|
||||
|
||||
def trans(x):
|
||||
return x
|
||||
|
||||
|
||||
def calculate_ssim(videos1, videos2):
|
||||
# videos [batch_size, timestamps, channel, h, w]
|
||||
|
||||
assert videos1.shape == videos2.shape
|
||||
|
||||
videos1 = trans(videos1)
|
||||
videos2 = trans(videos2)
|
||||
|
||||
ssim_results = []
|
||||
|
||||
for video_num in range(videos1.shape[0]):
|
||||
# get a video
|
||||
# video [timestamps, channel, h, w]
|
||||
video1 = videos1[video_num]
|
||||
video2 = videos2[video_num]
|
||||
|
||||
ssim_results_of_a_video = []
|
||||
for clip_timestamp in range(len(video1)):
|
||||
# get a img
|
||||
# img [timestamps[x], channel, h, w]
|
||||
# img [channel, h, w] numpy
|
||||
|
||||
img1 = video1[clip_timestamp].numpy()
|
||||
img2 = video2[clip_timestamp].numpy()
|
||||
|
||||
# calculate ssim of a video
|
||||
ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
|
||||
|
||||
ssim_results.append(ssim_results_of_a_video)
|
||||
|
||||
ssim_results = np.array(ssim_results)
|
||||
|
||||
ssim = {}
|
||||
ssim_std = {}
|
||||
|
||||
for clip_timestamp in range(len(video1)):
|
||||
ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
|
||||
ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])
|
||||
|
||||
result = {
|
||||
"value": ssim,
|
||||
"value_std": ssim_std,
|
||||
"video_setting": video1.shape,
|
||||
"video_setting_name": "time, channel, heigth, width",
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# test code / using example
|
||||
|
||||
|
||||
def main():
|
||||
NUMBER_OF_VIDEOS = 8
|
||||
VIDEO_LENGTH = 50
|
||||
CHANNEL = 3
|
||||
SIZE = 64
|
||||
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
torch.device("cuda")
|
||||
|
||||
import json
|
||||
|
||||
result = calculate_ssim(videos1, videos2)
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,160 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import imageio
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
import tqdm
|
||||
from calculate_lpips import calculate_lpips
|
||||
from calculate_psnr import calculate_psnr
|
||||
from calculate_ssim import calculate_ssim
|
||||
|
||||
|
||||
def load_videos(directory, video_ids, file_extension):
|
||||
videos = []
|
||||
for video_id in video_ids:
|
||||
video_path = os.path.join(directory, f"{video_id}.{file_extension}")
|
||||
if os.path.exists(video_path):
|
||||
video = load_video(video_path) # Define load_video based on how videos are stored
|
||||
videos.append(video)
|
||||
else:
|
||||
raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
|
||||
return videos
|
||||
|
||||
|
||||
def load_video(video_path):
|
||||
"""
|
||||
Load a video from the given path and convert it to a PyTorch tensor.
|
||||
"""
|
||||
# Read the video using imageio
|
||||
reader = imageio.get_reader(video_path, "ffmpeg")
|
||||
|
||||
# Extract frames and convert to a list of tensors
|
||||
frames = []
|
||||
for frame in reader:
|
||||
# Convert the frame to a tensor and permute the dimensions to match (C, H, W)
|
||||
frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
|
||||
frames.append(frame_tensor)
|
||||
|
||||
# Stack the list of tensors into a single tensor with shape (T, C, H, W)
|
||||
video_tensor = torch.stack(frames)
|
||||
|
||||
return video_tensor
|
||||
|
||||
|
||||
def resize_video(video, target_height, target_width):
|
||||
resized_frames = []
|
||||
for frame in video:
|
||||
resized_frame = F.resize(frame, [target_height, target_width])
|
||||
resized_frames.append(resized_frame)
|
||||
return torch.stack(resized_frames)
|
||||
|
||||
|
||||
def preprocess_eval_video(eval_video, generated_video_shape):
|
||||
T_gen, _, H_gen, W_gen = generated_video_shape
|
||||
T_eval, _, H_eval, W_eval = eval_video.shape
|
||||
|
||||
if T_eval < T_gen:
|
||||
raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
|
||||
|
||||
if H_eval < H_gen or W_eval < W_gen:
|
||||
# Resize the video maintaining the aspect ratio
|
||||
resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
|
||||
resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
|
||||
eval_video = resize_video(eval_video, resize_height, resize_width)
|
||||
# Recalculate the dimensions
|
||||
T_eval, _, H_eval, W_eval = eval_video.shape
|
||||
|
||||
# Center crop
|
||||
start_h = (H_eval - H_gen) // 2
|
||||
start_w = (W_eval - W_gen) // 2
|
||||
cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
|
||||
|
||||
return cropped_video
|
||||
|
||||
|
||||
def main(args):
|
||||
device = "cuda"
|
||||
gt_video_dir = args.gt_video_dir
|
||||
generated_video_dir = args.generated_video_dir
|
||||
|
||||
video_ids = []
|
||||
file_extension = "mp4"
|
||||
for f in os.listdir(generated_video_dir):
|
||||
if f.endswith(f".{file_extension}"):
|
||||
video_ids.append(f.replace(f".{file_extension}", ""))
|
||||
if not video_ids:
|
||||
raise ValueError("No videos found in the generated video dataset. Exiting.")
|
||||
|
||||
print(f"Find {len(video_ids)} videos")
|
||||
prompt_interval = 1
|
||||
batch_size = 16
|
||||
calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
|
||||
|
||||
lpips_results = []
|
||||
psnr_results = []
|
||||
ssim_results = []
|
||||
|
||||
total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
|
||||
|
||||
for idx, video_id in enumerate(tqdm.tqdm(range(total_len))):
|
||||
gt_videos_tensor = []
|
||||
generated_videos_tensor = []
|
||||
for i in range(batch_size):
|
||||
video_idx = idx * batch_size + i
|
||||
if video_idx >= len(video_ids):
|
||||
break
|
||||
video_id = video_ids[video_idx]
|
||||
generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}"))
|
||||
generated_videos_tensor.append(generated_video)
|
||||
eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}"))
|
||||
gt_videos_tensor.append(eval_video)
|
||||
gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
|
||||
generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
|
||||
|
||||
if calculate_lpips_flag:
|
||||
result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
|
||||
result = result["value"].values()
|
||||
result = sum(result) / len(result)
|
||||
lpips_results.append(result)
|
||||
|
||||
if calculate_psnr_flag:
|
||||
result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
|
||||
result = result["value"].values()
|
||||
result = sum(result) / len(result)
|
||||
psnr_results.append(result)
|
||||
|
||||
if calculate_ssim_flag:
|
||||
result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
|
||||
result = result["value"].values()
|
||||
result = sum(result) / len(result)
|
||||
ssim_results.append(result)
|
||||
|
||||
if (idx + 1) % prompt_interval == 0:
|
||||
out_str = ""
|
||||
for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
|
||||
result = sum(results) / len(results)
|
||||
out_str += f"{name}: {result:.4f}, "
|
||||
print(f"Processed {idx + 1} videos. {out_str[:-2]}")
|
||||
|
||||
out_str = ""
|
||||
for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
|
||||
result = sum(results) / len(results)
|
||||
out_str += f"{name}: {result:.4f}, "
|
||||
out_str = out_str[:-2]
|
||||
|
||||
# save
|
||||
with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f:
|
||||
f.write(out_str)
|
||||
|
||||
print(f"Processed all videos. {out_str}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--gt_video_dir", type=str)
|
||||
parser.add_argument("--generated_video_dir", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
Binary file not shown.
@ -1,532 +0,0 @@
|
||||
from utils import generate_func, read_prompt_list
|
||||
from videosys import LatteConfig, VideoSysEngine
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from videosys.core.parallel_mgr import (
|
||||
enable_sequence_parallel,
|
||||
get_cfg_parallel_size,
|
||||
get_data_parallel_group,
|
||||
get_sequence_parallel_group,
|
||||
)
|
||||
|
||||
def teacache_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
all_timesteps=None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_image_num: int = 0,
|
||||
enable_temporal_attentions: bool = True,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
# 0. Split batch for data parallelism
|
||||
if get_cfg_parallel_size() > 1:
|
||||
(
|
||||
hidden_states,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
added_cond_kwargs,
|
||||
class_labels,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
) = batch_func(
|
||||
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
|
||||
hidden_states,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
added_cond_kwargs,
|
||||
class_labels,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
input_batch_size, c, frame, h, w = hidden_states.shape
|
||||
frame = frame - use_image_num
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
|
||||
org_timestep = timestep
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
|
||||
elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
|
||||
encoder_attention_mask_video = repeat(
|
||||
encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
|
||||
).contiguous()
|
||||
encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
|
||||
encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
|
||||
encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
|
||||
|
||||
# Retrieve lora scale.
|
||||
cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_patches: # here
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||
num_patches = height * width
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
# batch_size = hidden_states.shape[0]
|
||||
batch_size = input_batch_size
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = hidden_states.shape[0]
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
|
||||
|
||||
if use_image_num != 0 and self.training:
|
||||
encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
|
||||
encoder_hidden_states_video = repeat(
|
||||
encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
|
||||
).contiguous()
|
||||
encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
|
||||
encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
|
||||
else:
|
||||
encoder_hidden_states_spatial = repeat(
|
||||
encoder_hidden_states, "b t d -> (b f) t d", f=frame
|
||||
).contiguous()
|
||||
|
||||
# prepare timesteps for spatial and temporal block
|
||||
timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
|
||||
timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
|
||||
|
||||
if self.enable_teacache:
|
||||
inp = hidden_states.clone()
|
||||
batch_size = inp.shape[0]
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.transformer_blocks[0].scale_shift_table[None] + timestep_spatial.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
modulated_inp = self.transformer_blocks[0].norm1(inp) * (1 + scale_msa) + shift_msa
|
||||
if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = [-2.46434137e+03, 3.08044764e+02, 8.07447667e+01, -4.11385132e+00, 1.11001402e-01]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
|
||||
if self.enable_teacache:
|
||||
if not should_calc:
|
||||
hidden_states += self.previous_residual
|
||||
else:
|
||||
if enable_sequence_parallel():
|
||||
set_temporal_pad(frame + use_image_num)
|
||||
set_spatial_pad(num_patches)
|
||||
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
|
||||
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
|
||||
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
|
||||
temp_pos_embed = split_sequence(
|
||||
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
|
||||
)
|
||||
else:
|
||||
temp_pos_embed = self.temp_pos_embed
|
||||
|
||||
hidden_states_origin = hidden_states.clone().detach()
|
||||
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
spatial_block,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
if enable_temporal_attentions:
|
||||
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
|
||||
|
||||
if use_image_num != 0: # image-video joitn training
|
||||
hidden_states_video = hidden_states[:, :frame, ...]
|
||||
hidden_states_image = hidden_states[:, frame:, ...]
|
||||
|
||||
if i == 0:
|
||||
hidden_states_video = hidden_states_video + temp_pos_embed
|
||||
|
||||
hidden_states_video = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states_video,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
else:
|
||||
if i == 0:
|
||||
hidden_states = hidden_states + temp_pos_embed
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
else:
|
||||
hidden_states = spatial_block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
None,
|
||||
org_timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
if enable_temporal_attentions:
|
||||
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
|
||||
|
||||
if use_image_num != 0 and self.training:
|
||||
hidden_states_video = hidden_states[:, :frame, ...]
|
||||
hidden_states_image = hidden_states[:, frame:, ...]
|
||||
|
||||
hidden_states_video = temp_block(
|
||||
hidden_states_video,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
org_timestep,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
else:
|
||||
if i == 0 and frame > 1:
|
||||
hidden_states = hidden_states + temp_pos_embed
|
||||
hidden_states = temp_block(
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
org_timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
self.previous_residual = hidden_states - hidden_states_origin
|
||||
else:
|
||||
if enable_sequence_parallel():
|
||||
set_temporal_pad(frame + use_image_num)
|
||||
set_spatial_pad(num_patches)
|
||||
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
|
||||
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
|
||||
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
|
||||
temp_pos_embed = split_sequence(
|
||||
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
|
||||
)
|
||||
else:
|
||||
temp_pos_embed = self.temp_pos_embed
|
||||
|
||||
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
spatial_block,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
if enable_temporal_attentions:
|
||||
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
|
||||
|
||||
if use_image_num != 0: # image-video joitn training
|
||||
hidden_states_video = hidden_states[:, :frame, ...]
|
||||
hidden_states_image = hidden_states[:, frame:, ...]
|
||||
|
||||
if i == 0:
|
||||
hidden_states_video = hidden_states_video + temp_pos_embed
|
||||
|
||||
hidden_states_video = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states_video,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
else:
|
||||
if i == 0:
|
||||
hidden_states = hidden_states + temp_pos_embed
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
else:
|
||||
hidden_states = spatial_block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
None,
|
||||
org_timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
if enable_temporal_attentions:
|
||||
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
|
||||
|
||||
if use_image_num != 0 and self.training:
|
||||
hidden_states_video = hidden_states[:, :frame, ...]
|
||||
hidden_states_image = hidden_states[:, frame:, ...]
|
||||
|
||||
hidden_states_video = temp_block(
|
||||
hidden_states_video,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
org_timestep,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
else:
|
||||
if i == 0 and frame > 1:
|
||||
hidden_states = hidden_states + temp_pos_embed
|
||||
hidden_states = temp_block(
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
org_timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
if enable_sequence_parallel():
|
||||
if self.enable_teacache:
|
||||
if should_calc:
|
||||
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
|
||||
self.previous_residual = self.gather_from_second_dim(self.previous_residual, input_batch_size)
|
||||
else:
|
||||
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
|
||||
|
||||
if self.is_input_patches:
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
|
||||
|
||||
# 3. Gather batch for data parallelism
|
||||
if get_cfg_parallel_size() > 1:
|
||||
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer3DModelOutput(sample=output)
|
||||
|
||||
|
||||
def eval_base(prompt_list):
|
||||
config = LatteConfig()
|
||||
engine = VideoSysEngine(config)
|
||||
generate_func(engine, prompt_list, "./samples/latte_base", loop=5)
|
||||
|
||||
def eval_teacache_slow(prompt_list):
|
||||
config = LatteConfig()
|
||||
engine = VideoSysEngine(config)
|
||||
engine.driver_worker.transformer.enable_teacache = True
|
||||
engine.driver_worker.transformer.rel_l1_thresh = 0.1
|
||||
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
|
||||
engine.driver_worker.transformer.previous_modulated_input = None
|
||||
engine.driver_worker.transformer.previous_residual = None
|
||||
engine.driver_worker.transformer.__class__.forward = teacache_forward
|
||||
generate_func(engine, prompt_list, "./samples/latte_teacache_slow", loop=5)
|
||||
|
||||
def eval_teacache_fast(prompt_list):
|
||||
config = LatteConfig()
|
||||
engine = VideoSysEngine(config)
|
||||
engine.driver_worker.transformer.enable_teacache = True
|
||||
engine.driver_worker.transformer.rel_l1_thresh = 0.2
|
||||
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
|
||||
engine.driver_worker.transformer.previous_modulated_input = None
|
||||
engine.driver_worker.transformer.previous_residual = None
|
||||
engine.driver_worker.transformer.__class__.forward = teacache_forward
|
||||
generate_func(engine, prompt_list, "./samples/latte_teacache_fast", loop=5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
||||
# eval_base(prompt_list)
|
||||
eval_teacache_slow(prompt_list)
|
||||
# eval_teacache_fast(prompt_list)
|
||||
|
||||
@ -1,243 +0,0 @@
|
||||
from utils import generate_func, read_prompt_list
|
||||
from videosys import OpenSoraConfig, VideoSysEngine
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint
|
||||
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_temporal_pad, set_spatial_pad, set_temporal_pad, split_sequence
|
||||
import numpy as np
|
||||
from videosys.utils.utils import batch_func
|
||||
from videosys.core.parallel_mgr import (
|
||||
enable_sequence_parallel,
|
||||
get_cfg_parallel_size,
|
||||
get_data_parallel_group,
|
||||
get_sequence_parallel_group,
|
||||
)
|
||||
|
||||
def teacache_forward(
|
||||
self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
|
||||
):
|
||||
# === Split batch ===
|
||||
if get_cfg_parallel_size() > 1:
|
||||
x, timestep, y, x_mask, mask = batch_func(
|
||||
partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
|
||||
)
|
||||
|
||||
dtype = self.x_embedder.proj.weight.dtype
|
||||
B = x.size(0)
|
||||
x = x.to(dtype)
|
||||
timestep = timestep.to(dtype)
|
||||
y = y.to(dtype)
|
||||
|
||||
# === get pos embed ===
|
||||
_, _, Tx, Hx, Wx = x.size()
|
||||
T, H, W = self.get_dynamic_size(x)
|
||||
S = H * W
|
||||
base_size = round(S**0.5)
|
||||
resolution_sq = (height[0].item() * width[0].item()) ** 0.5
|
||||
scale = resolution_sq / self.input_sq_size
|
||||
pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)
|
||||
|
||||
# === get timestep embed ===
|
||||
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
|
||||
fps = self.fps_embedder(fps.unsqueeze(1), B)
|
||||
t = t + fps
|
||||
t_mlp = self.t_block(t)
|
||||
t0 = t0_mlp = None
|
||||
if x_mask is not None:
|
||||
t0_timestep = torch.zeros_like(timestep)
|
||||
t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
|
||||
t0 = t0 + fps
|
||||
t0_mlp = self.t_block(t0)
|
||||
|
||||
# === get y embed ===
|
||||
if self.config.skip_y_embedder:
|
||||
y_lens = mask
|
||||
if isinstance(y_lens, torch.Tensor):
|
||||
y_lens = y_lens.long().tolist()
|
||||
else:
|
||||
y, y_lens = self.encode_text(y, mask)
|
||||
|
||||
# === get x embed ===
|
||||
x = self.x_embedder(x) # [B, N, C]
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
x = x + pos_emb
|
||||
|
||||
if self.enable_teacache:
|
||||
inp = x.clone()
|
||||
inp = rearrange(inp, "B T S C -> B (T S) C", T=T, S=S)
|
||||
B, N, C = inp.shape
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.spatial_blocks[0].scale_shift_table[None] + t_mlp.reshape(B, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
modulated_inp = t2i_modulate(self.spatial_blocks[0].norm1(inp), shift_msa, scale_msa)
|
||||
if timestep[0] == all_timesteps[0] or timestep[0] == all_timesteps[-1]:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = [2.17546007e+02, -1.18329252e+02, 2.68662585e+01, -4.59364272e-02, 4.84426240e-02]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
|
||||
|
||||
# === blocks ===
|
||||
if self.enable_teacache:
|
||||
if not should_calc:
|
||||
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
|
||||
x += self.previous_residual
|
||||
else:
|
||||
# shard over the sequence dim if sp is enabled
|
||||
if enable_sequence_parallel():
|
||||
set_temporal_pad(T)
|
||||
set_spatial_pad(S)
|
||||
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
|
||||
T = x.shape[1]
|
||||
x_mask_org = x_mask
|
||||
x_mask = split_sequence(
|
||||
x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
|
||||
)
|
||||
|
||||
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
|
||||
origin_x = x.clone().detach()
|
||||
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
|
||||
x = auto_grad_checkpoint(
|
||||
spatial_block,
|
||||
x,
|
||||
y,
|
||||
t_mlp,
|
||||
y_lens,
|
||||
x_mask,
|
||||
t0_mlp,
|
||||
T,
|
||||
S,
|
||||
timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
x = auto_grad_checkpoint(
|
||||
temporal_block,
|
||||
x,
|
||||
y,
|
||||
t_mlp,
|
||||
y_lens,
|
||||
x_mask,
|
||||
t0_mlp,
|
||||
T,
|
||||
S,
|
||||
timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
self.previous_residual = x - origin_x
|
||||
else:
|
||||
# shard over the sequence dim if sp is enabled
|
||||
if enable_sequence_parallel():
|
||||
set_temporal_pad(T)
|
||||
set_spatial_pad(S)
|
||||
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
|
||||
T = x.shape[1]
|
||||
x_mask_org = x_mask
|
||||
x_mask = split_sequence(
|
||||
x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
|
||||
)
|
||||
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
|
||||
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
|
||||
x = auto_grad_checkpoint(
|
||||
spatial_block,
|
||||
x,
|
||||
y,
|
||||
t_mlp,
|
||||
y_lens,
|
||||
x_mask,
|
||||
t0_mlp,
|
||||
T,
|
||||
S,
|
||||
timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
x = auto_grad_checkpoint(
|
||||
temporal_block,
|
||||
x,
|
||||
y,
|
||||
t_mlp,
|
||||
y_lens,
|
||||
x_mask,
|
||||
t0_mlp,
|
||||
T,
|
||||
S,
|
||||
timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
if enable_sequence_parallel():
|
||||
if self.enable_teacache:
|
||||
if should_calc:
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S)
|
||||
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
|
||||
self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
|
||||
T, S = x.shape[1], x.shape[2]
|
||||
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
|
||||
self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S)
|
||||
x_mask = x_mask_org
|
||||
else:
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_temporal_pad("temporal"))
|
||||
T, S = x.shape[1], x.shape[2]
|
||||
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
|
||||
x_mask = x_mask_org
|
||||
|
||||
|
||||
# === final layer ===
|
||||
x = self.final_layer(x, t, x_mask, t0, T, S)
|
||||
x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
|
||||
|
||||
# cast to float32 for better accuracy
|
||||
x = x.to(torch.float32)
|
||||
|
||||
# === Gather Output ===
|
||||
if get_cfg_parallel_size() > 1:
|
||||
x = gather_sequence(x, get_data_parallel_group(), dim=0)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def eval_base(prompt_list):
|
||||
config = OpenSoraConfig()
|
||||
engine = VideoSysEngine(config)
|
||||
generate_func(engine, prompt_list, "./samples/opensora_base", loop=5)
|
||||
|
||||
def eval_teacache_slow(prompt_list):
|
||||
config = OpenSoraConfig()
|
||||
engine = VideoSysEngine(config)
|
||||
engine.driver_worker.transformer.enable_teacache = True
|
||||
engine.driver_worker.transformer.rel_l1_thresh = 0.1
|
||||
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
|
||||
engine.driver_worker.transformer.previous_modulated_input = None
|
||||
engine.driver_worker.transformer.previous_residual = None
|
||||
engine.driver_worker.transformer.__class__.forward = teacache_forward
|
||||
generate_func(engine, prompt_list, "./samples/opensora_teacache_slow", loop=5)
|
||||
|
||||
def eval_teacache_fast(prompt_list):
|
||||
config = OpenSoraConfig()
|
||||
engine = VideoSysEngine(config)
|
||||
engine.driver_worker.transformer.enable_teacache = True
|
||||
engine.driver_worker.transformer.rel_l1_thresh = 0.2
|
||||
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
|
||||
engine.driver_worker.transformer.previous_modulated_input = None
|
||||
engine.driver_worker.transformer.previous_residual = None
|
||||
engine.driver_worker.transformer.__class__.forward = teacache_forward
|
||||
generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
||||
# eval_base(prompt_list)
|
||||
eval_teacache_slow(prompt_list)
|
||||
# eval_teacache_fast(prompt_list)
|
||||
@ -1,594 +0,0 @@
|
||||
from utils import generate_func, read_prompt_list
|
||||
from videosys import OpenSoraPlanConfig, VideoSysEngine
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
import numpy as np
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from videosys.core.parallel_mgr import (
|
||||
enable_sequence_parallel,
|
||||
get_cfg_parallel_group,
|
||||
get_cfg_parallel_size,
|
||||
get_sequence_parallel_group,
|
||||
)
|
||||
|
||||
def teacache_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
all_timesteps=None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_image_num: int = 0,
|
||||
enable_temporal_attentions: bool = True,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. Split batch
|
||||
if get_cfg_parallel_size() > 1:
|
||||
(
|
||||
hidden_states,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
class_labels,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
) = batch_func(
|
||||
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
|
||||
hidden_states,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
class_labels,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
input_batch_size, c, frame, h, w = hidden_states.shape
|
||||
frame = frame - use_image_num # 20-4=16
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
|
||||
org_timestep = timestep
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(input_batch_size, frame + use_image_num, h, w), device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num)
|
||||
dtype = attention_mask.dtype
|
||||
attention_mask_compress = F.max_pool2d(
|
||||
attention_mask.float(), kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor
|
||||
)
|
||||
attention_mask_compress = attention_mask_compress.to(dtype)
|
||||
|
||||
attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype)
|
||||
attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype)
|
||||
|
||||
# 1 + 4, 1 -> video condition, 4 -> image condition
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
|
||||
encoder_attention_mask = encoder_attention_mask.to(self.dtype)
|
||||
elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
|
||||
encoder_attention_mask_video = repeat(
|
||||
encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
|
||||
).contiguous()
|
||||
encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
|
||||
encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
|
||||
encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
|
||||
encoder_attention_mask = encoder_attention_mask.to(self.dtype)
|
||||
|
||||
# Retrieve lora scale.
|
||||
cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_patches: # here
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||
hw = (height, width)
|
||||
num_patches = height * width
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # alrady add positional embeddings
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
# batch_size = hidden_states.shape[0]
|
||||
batch_size = input_batch_size
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = hidden_states.shape[0]
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype)) # 3 120 1152
|
||||
|
||||
if use_image_num != 0 and self.training:
|
||||
encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
|
||||
encoder_hidden_states_video = repeat(
|
||||
encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
|
||||
).contiguous()
|
||||
encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
|
||||
encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
|
||||
else:
|
||||
encoder_hidden_states_spatial = repeat(
|
||||
encoder_hidden_states, "b 1 t d -> (b f) t d", f=frame
|
||||
).contiguous()
|
||||
|
||||
# prepare timesteps for spatial and temporal block
|
||||
timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
|
||||
timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
|
||||
|
||||
pos_hw, pos_t = None, None
|
||||
if self.use_rope:
|
||||
pos_hw, pos_t = self.make_position(
|
||||
input_batch_size, frame, use_image_num, height, width, hidden_states.device
|
||||
)
|
||||
|
||||
if self.enable_teacache:
|
||||
inp = hidden_states.clone()
|
||||
batch_size = hidden_states.shape[0]
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.transformer_blocks[0].scale_shift_table[None] + timestep_spatial.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
modulated_inp = self.transformer_blocks[0].norm1(inp) * (1 + scale_msa) + shift_msa
|
||||
if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = [2.05943668e+05, -1.48759286e+04, 3.06085986e+02, 1.31418080e+00, 2.39658469e-03]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
|
||||
if self.enable_teacache:
|
||||
if not should_calc:
|
||||
hidden_states += self.previous_residual
|
||||
else:
|
||||
if enable_sequence_parallel():
|
||||
set_temporal_pad(frame + use_image_num)
|
||||
set_spatial_pad(num_patches)
|
||||
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
|
||||
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
|
||||
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
|
||||
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
|
||||
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
|
||||
temp_pos_embed = split_sequence(
|
||||
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
|
||||
)
|
||||
else:
|
||||
temp_pos_embed = self.temp_pos_embed
|
||||
ori_hidden_states = hidden_states.clone()
|
||||
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
spatial_block,
|
||||
hidden_states,
|
||||
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_hw,
|
||||
pos_hw,
|
||||
hw,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
if enable_temporal_attentions:
|
||||
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
|
||||
|
||||
if use_image_num != 0: # image-video joitn training
|
||||
hidden_states_video = hidden_states[:, :frame, ...]
|
||||
hidden_states_image = hidden_states[:, frame:, ...]
|
||||
|
||||
# if i == 0 and not self.use_rope:
|
||||
if i == 0:
|
||||
hidden_states_video = hidden_states_video + temp_pos_embed
|
||||
|
||||
hidden_states_video = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states_video,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_t,
|
||||
pos_t,
|
||||
(frame,),
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
else:
|
||||
# if i == 0 and not self.use_rope:
|
||||
if i == 0:
|
||||
hidden_states = hidden_states + temp_pos_embed
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_t,
|
||||
pos_t,
|
||||
(frame,),
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
else:
|
||||
hidden_states = spatial_block(
|
||||
hidden_states,
|
||||
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_hw,
|
||||
pos_hw,
|
||||
hw,
|
||||
org_timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
if enable_temporal_attentions:
|
||||
# b c f h w, f = 16 + 4
|
||||
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
|
||||
|
||||
if use_image_num != 0 and self.training:
|
||||
hidden_states_video = hidden_states[:, :frame, ...]
|
||||
hidden_states_image = hidden_states[:, frame:, ...]
|
||||
|
||||
# if i == 0 and not self.use_rope:
|
||||
# hidden_states_video = hidden_states_video + temp_pos_embed
|
||||
|
||||
hidden_states_video = temp_block(
|
||||
hidden_states_video,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_t,
|
||||
pos_t,
|
||||
(frame,),
|
||||
org_timestep,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
else:
|
||||
# if i == 0 and not self.use_rope:
|
||||
if i == 0:
|
||||
hidden_states = hidden_states + temp_pos_embed
|
||||
hidden_states = temp_block(
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_t,
|
||||
pos_t,
|
||||
(frame,),
|
||||
org_timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
self.previous_residual = hidden_states - ori_hidden_states
|
||||
else:
|
||||
if enable_sequence_parallel():
|
||||
set_temporal_pad(frame + use_image_num)
|
||||
set_spatial_pad(num_patches)
|
||||
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
|
||||
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
|
||||
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
|
||||
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
|
||||
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
|
||||
temp_pos_embed = split_sequence(
|
||||
self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
|
||||
)
|
||||
else:
|
||||
temp_pos_embed = self.temp_pos_embed
|
||||
|
||||
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
spatial_block,
|
||||
hidden_states,
|
||||
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_hw,
|
||||
pos_hw,
|
||||
hw,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
if enable_temporal_attentions:
|
||||
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
|
||||
|
||||
if use_image_num != 0: # image-video joitn training
|
||||
hidden_states_video = hidden_states[:, :frame, ...]
|
||||
hidden_states_image = hidden_states[:, frame:, ...]
|
||||
|
||||
# if i == 0 and not self.use_rope:
|
||||
if i == 0:
|
||||
hidden_states_video = hidden_states_video + temp_pos_embed
|
||||
|
||||
hidden_states_video = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states_video,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_t,
|
||||
pos_t,
|
||||
(frame,),
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
else:
|
||||
# if i == 0 and not self.use_rope:
|
||||
if i == 0:
|
||||
hidden_states = hidden_states + temp_pos_embed
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_t,
|
||||
pos_t,
|
||||
(frame,),
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
else:
|
||||
hidden_states = spatial_block(
|
||||
hidden_states,
|
||||
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_hw,
|
||||
pos_hw,
|
||||
hw,
|
||||
org_timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
if enable_temporal_attentions:
|
||||
# b c f h w, f = 16 + 4
|
||||
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
|
||||
|
||||
if use_image_num != 0 and self.training:
|
||||
hidden_states_video = hidden_states[:, :frame, ...]
|
||||
hidden_states_image = hidden_states[:, frame:, ...]
|
||||
|
||||
# if i == 0 and not self.use_rope:
|
||||
# hidden_states_video = hidden_states_video + temp_pos_embed
|
||||
|
||||
hidden_states_video = temp_block(
|
||||
hidden_states_video,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_t,
|
||||
pos_t,
|
||||
(frame,),
|
||||
org_timestep,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
else:
|
||||
# if i == 0 and not self.use_rope:
|
||||
if i == 0:
|
||||
hidden_states = hidden_states + temp_pos_embed
|
||||
hidden_states = temp_block(
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
pos_t,
|
||||
pos_t,
|
||||
(frame,),
|
||||
org_timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
|
||||
).contiguous()
|
||||
|
||||
if enable_sequence_parallel():
|
||||
if self.enable_teacache:
|
||||
if should_calc:
|
||||
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
|
||||
self.previous_residual = self.gather_from_second_dim(self.previous_residual, input_batch_size)
|
||||
else:
|
||||
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
|
||||
|
||||
if self.is_input_patches:
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
|
||||
|
||||
# 3. Gather batch for data parallelism
|
||||
if get_cfg_parallel_size() > 1:
|
||||
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer3DModelOutput(sample=output)
|
||||
|
||||
|
||||
|
||||
def eval_base(prompt_list):
|
||||
config = OpenSoraPlanConfig()
|
||||
engine = VideoSysEngine(config)
|
||||
generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
|
||||
|
||||
def eval_teacache_slow(prompt_list):
|
||||
config = OpenSoraPlanConfig()
|
||||
engine = VideoSysEngine(config)
|
||||
engine.driver_worker.transformer.enable_teacache = True
|
||||
engine.driver_worker.transformer.rel_l1_thresh = 0.1
|
||||
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
|
||||
engine.driver_worker.transformer.previous_modulated_input = None
|
||||
engine.driver_worker.transformer.previous_residual = None
|
||||
engine.driver_worker.transformer.__class__.forward = teacache_forward
|
||||
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5)
|
||||
|
||||
def eval_teacache_fast(prompt_list):
|
||||
config = OpenSoraPlanConfig()
|
||||
engine = VideoSysEngine(config)
|
||||
engine.driver_worker.transformer.enable_teacache = True
|
||||
engine.driver_worker.transformer.rel_l1_thresh = 0.2
|
||||
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
|
||||
engine.driver_worker.transformer.previous_modulated_input = None
|
||||
engine.driver_worker.transformer.previous_residual = None
|
||||
engine.driver_worker.transformer.__class__.forward = teacache_forward
|
||||
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
||||
# eval_base(prompt_list)
|
||||
eval_teacache_slow(prompt_list)
|
||||
# eval_teacache_fast(prompt_list)
|
||||
@ -1,22 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import tqdm
|
||||
|
||||
from videosys.utils.utils import set_seed
|
||||
|
||||
|
||||
def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = {}):
|
||||
kwargs["verbose"] = False
|
||||
for prompt in tqdm.tqdm(prompt_list):
|
||||
for l in range(loop):
|
||||
set_seed(l)
|
||||
video = pipeline.generate(prompt, **kwargs).video[0]
|
||||
pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))
|
||||
|
||||
|
||||
def read_prompt_list(prompt_list_path):
|
||||
with open(prompt_list_path, "r") as f:
|
||||
prompt_list = json.load(f)
|
||||
prompt_list = [prompt["prompt_en"] for prompt in prompt_list]
|
||||
return prompt_list
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,154 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
SEMANTIC_WEIGHT = 1
|
||||
QUALITY_WEIGHT = 4
|
||||
|
||||
QUALITY_LIST = [
|
||||
"subject consistency",
|
||||
"background consistency",
|
||||
"temporal flickering",
|
||||
"motion smoothness",
|
||||
"aesthetic quality",
|
||||
"imaging quality",
|
||||
"dynamic degree",
|
||||
]
|
||||
|
||||
SEMANTIC_LIST = [
|
||||
"object class",
|
||||
"multiple objects",
|
||||
"human action",
|
||||
"color",
|
||||
"spatial relationship",
|
||||
"scene",
|
||||
"appearance style",
|
||||
"temporal style",
|
||||
"overall consistency",
|
||||
]
|
||||
|
||||
NORMALIZE_DIC = {
|
||||
"subject consistency": {"Min": 0.1462, "Max": 1.0},
|
||||
"background consistency": {"Min": 0.2615, "Max": 1.0},
|
||||
"temporal flickering": {"Min": 0.6293, "Max": 1.0},
|
||||
"motion smoothness": {"Min": 0.706, "Max": 0.9975},
|
||||
"dynamic degree": {"Min": 0.0, "Max": 1.0},
|
||||
"aesthetic quality": {"Min": 0.0, "Max": 1.0},
|
||||
"imaging quality": {"Min": 0.0, "Max": 1.0},
|
||||
"object class": {"Min": 0.0, "Max": 1.0},
|
||||
"multiple objects": {"Min": 0.0, "Max": 1.0},
|
||||
"human action": {"Min": 0.0, "Max": 1.0},
|
||||
"color": {"Min": 0.0, "Max": 1.0},
|
||||
"spatial relationship": {"Min": 0.0, "Max": 1.0},
|
||||
"scene": {"Min": 0.0, "Max": 0.8222},
|
||||
"appearance style": {"Min": 0.0009, "Max": 0.2855},
|
||||
"temporal style": {"Min": 0.0, "Max": 0.364},
|
||||
"overall consistency": {"Min": 0.0, "Max": 0.364},
|
||||
}
|
||||
|
||||
DIM_WEIGHT = {
|
||||
"subject consistency": 1,
|
||||
"background consistency": 1,
|
||||
"temporal flickering": 1,
|
||||
"motion smoothness": 1,
|
||||
"aesthetic quality": 1,
|
||||
"imaging quality": 1,
|
||||
"dynamic degree": 0.5,
|
||||
"object class": 1,
|
||||
"multiple objects": 1,
|
||||
"human action": 1,
|
||||
"color": 1,
|
||||
"spatial relationship": 1,
|
||||
"scene": 1,
|
||||
"appearance style": 1,
|
||||
"temporal style": 1,
|
||||
"overall consistency": 1,
|
||||
}
|
||||
|
||||
ordered_scaled_res = [
|
||||
"total score",
|
||||
"quality score",
|
||||
"semantic score",
|
||||
"subject consistency",
|
||||
"background consistency",
|
||||
"temporal flickering",
|
||||
"motion smoothness",
|
||||
"dynamic degree",
|
||||
"aesthetic quality",
|
||||
"imaging quality",
|
||||
"object class",
|
||||
"multiple objects",
|
||||
"human action",
|
||||
"color",
|
||||
"spatial relationship",
|
||||
"scene",
|
||||
"appearance style",
|
||||
"temporal style",
|
||||
"overall consistency",
|
||||
]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--score_dir", required=True, type=str)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
res_postfix = "_eval_results.json"
|
||||
info_postfix = "_full_info.json"
|
||||
files = os.listdir(args.score_dir)
|
||||
res_files = [x for x in files if res_postfix in x]
|
||||
info_files = [x for x in files if info_postfix in x]
|
||||
assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
|
||||
|
||||
full_results = {}
|
||||
for res_file in res_files:
|
||||
# first check if results is normal
|
||||
info_file = res_file.split(res_postfix)[0] + info_postfix
|
||||
with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f:
|
||||
info = json.load(f)
|
||||
assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
|
||||
# read results
|
||||
with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
for key, val in data.items():
|
||||
full_results[key] = format(val[0], ".4f")
|
||||
|
||||
scaled_results = {}
|
||||
dims = set()
|
||||
for key, val in full_results.items():
|
||||
dim = key.replace("_", " ") if "_" in key else key
|
||||
scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / (
|
||||
NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"]
|
||||
)
|
||||
scaled_score *= DIM_WEIGHT[dim]
|
||||
scaled_results[dim] = scaled_score
|
||||
dims.add(dim)
|
||||
|
||||
assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet"
|
||||
|
||||
quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST])
|
||||
semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST])
|
||||
scaled_results["quality score"] = quality_score
|
||||
scaled_results["semantic score"] = semantic_score
|
||||
scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / (
|
||||
QUALITY_WEIGHT + SEMANTIC_WEIGHT
|
||||
)
|
||||
|
||||
formated_scaled_results = {"items": []}
|
||||
for key in ordered_scaled_res:
|
||||
formated_score = format(scaled_results[key] * 100, ".2f") + "%"
|
||||
formated_scaled_results["items"].append({key: formated_score})
|
||||
|
||||
output_file_path = os.path.join(args.score_dir, "all_results.json")
|
||||
with open(output_file_path, "w") as outfile:
|
||||
json.dump(full_results, outfile, indent=4, sort_keys=True)
|
||||
print(f"results saved to: {output_file_path}")
|
||||
|
||||
scaled_file_path = os.path.join(args.score_dir, "scaled_results.json")
|
||||
with open(scaled_file_path, "w") as outfile:
|
||||
json.dump(formated_scaled_results, outfile, indent=4, sort_keys=True)
|
||||
print(f"results saved to: {scaled_file_path}")
|
||||
@ -1,52 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from vbench import VBench
|
||||
|
||||
full_info_path = "./vbench/VBench_full_info.json"
|
||||
|
||||
dimensions = [
|
||||
"subject_consistency",
|
||||
"imaging_quality",
|
||||
"background_consistency",
|
||||
"motion_smoothness",
|
||||
"overall_consistency",
|
||||
"human_action",
|
||||
"multiple_objects",
|
||||
"spatial_relationship",
|
||||
"object_class",
|
||||
"color",
|
||||
"aesthetic_quality",
|
||||
"appearance_style",
|
||||
"temporal_flickering",
|
||||
"scene",
|
||||
"temporal_style",
|
||||
"dynamic_degree",
|
||||
]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--video_path", required=True, type=str)
|
||||
parser.add_argument("--save_path", required=True, type=str)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
kwargs = {}
|
||||
kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default
|
||||
|
||||
for dimension in dimensions:
|
||||
my_VBench = VBench(torch.device("cuda"), full_info_path, args.save_path)
|
||||
my_VBench.evaluate(
|
||||
videos_path=args.video_path,
|
||||
name=dimension,
|
||||
local=False,
|
||||
read_frame=False,
|
||||
dimension_list=[dimension],
|
||||
mode="vbench_standard",
|
||||
**kwargs,
|
||||
)
|
||||
@ -1,22 +0,0 @@
|
||||
click
|
||||
colossalai
|
||||
diffusers==0.30.0
|
||||
einops
|
||||
fabric
|
||||
ftfy
|
||||
imageio
|
||||
imageio-ffmpeg
|
||||
matplotlib
|
||||
ninja
|
||||
numpy<2.0.0
|
||||
omegaconf
|
||||
packaging
|
||||
psutil
|
||||
pydantic
|
||||
ray
|
||||
rich
|
||||
safetensors
|
||||
timm
|
||||
torch>=1.13
|
||||
tqdm
|
||||
transformers
|
||||
55
setup.py
55
setup.py
@ -1,55 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def fetch_requirements(path) -> List[str]:
|
||||
"""
|
||||
This function reads the requirements file.
|
||||
|
||||
Args:
|
||||
path (str): the path to the requirements file.
|
||||
|
||||
Returns:
|
||||
The lines in the requirements file.
|
||||
"""
|
||||
with open(path, "r") as fd:
|
||||
return [r.strip() for r in fd.readlines()]
|
||||
|
||||
|
||||
def fetch_readme() -> str:
|
||||
"""
|
||||
This function reads the README.md file in the current directory.
|
||||
|
||||
Returns:
|
||||
The lines in the README file.
|
||||
"""
|
||||
with open("README.md", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
setup(
|
||||
name="videosys",
|
||||
version="2.0.0",
|
||||
packages=find_packages(
|
||||
exclude=(
|
||||
"videos",
|
||||
"tests",
|
||||
"figure",
|
||||
"*.egg-info",
|
||||
)
|
||||
),
|
||||
description="VideoSys",
|
||||
long_description=fetch_readme(),
|
||||
long_description_content_type="text/markdown",
|
||||
license="Apache Software License 2.0",
|
||||
install_requires=fetch_requirements("requirements.txt"),
|
||||
python_requires=">=3.6",
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Environment :: GPU :: NVIDIA CUDA",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: System :: Distributed Computing",
|
||||
],
|
||||
)
|
||||
@ -1,223 +0,0 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: videosys
|
||||
Version: 2.0.0
|
||||
Summary: VideoSys
|
||||
License: Apache Software License 2.0
|
||||
Platform: UNKNOWN
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: License :: OSI Approved :: Apache Software License
|
||||
Classifier: Environment :: GPU :: NVIDIA CUDA
|
||||
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
||||
Classifier: Topic :: System :: Distributed Computing
|
||||
Requires-Python: >=3.6
|
||||
Description-Content-Type: text/markdown
|
||||
License-File: LICENSE
|
||||
Requires-Dist: click
|
||||
Requires-Dist: colossalai
|
||||
Requires-Dist: contexttimer
|
||||
Requires-Dist: diffusers==0.30.0
|
||||
Requires-Dist: einops
|
||||
Requires-Dist: fabric
|
||||
Requires-Dist: ftfy
|
||||
Requires-Dist: imageio
|
||||
Requires-Dist: imageio-ffmpeg
|
||||
Requires-Dist: matplotlib
|
||||
Requires-Dist: ninja
|
||||
Requires-Dist: numpy<2.0.0
|
||||
Requires-Dist: omegaconf
|
||||
Requires-Dist: packaging
|
||||
Requires-Dist: psutil
|
||||
Requires-Dist: pydantic
|
||||
Requires-Dist: ray
|
||||
Requires-Dist: rich
|
||||
Requires-Dist: safetensors
|
||||
Requires-Dist: timm
|
||||
Requires-Dist: torch>=1.13
|
||||
Requires-Dist: tqdm
|
||||
Requires-Dist: transformers
|
||||
|
||||
<p align="center">
|
||||
<img width="55%" alt="VideoSys" src="./assets/figures/logo.png?raw=true">
|
||||
</p>
|
||||
<h3 align="center">
|
||||
An easy and efficient system for video generation
|
||||
</h3>
|
||||
</p>
|
||||
|
||||
### Latest News 🔥
|
||||
- [2024/08] 🔥 Evole from [OpenDiT](https://github.com/NUS-HPC-AI-Lab/VideoSys/tree/v1.0.0) to <b>VideoSys: An easy and efficient system for video generation.</b>
|
||||
- [2024/08] 🔥 <b>Release PAB paper: [Real-Time Video Generation with Pyramid Attention Broadcast](https://arxiv.org/abs/2408.12588).</b>
|
||||
- [2024/06] Propose Pyramid Attention Broadcast (PAB)[[paper](https://arxiv.org/abs/2408.12588)][[blog](https://oahzxl.github.io/PAB/)][[doc](./docs/pab.md)], the first approach to achieve <b>real-time</b> DiT-based video generation, delivering <b>negligible quality loss</b> without <b>requiring any training</b>.
|
||||
- [2024/06] Support [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) and [Latte](https://github.com/Vchitect/Latte).
|
||||
- [2024/03] Propose Dynamic Sequence Parallel (DSP)[[paper](https://arxiv.org/abs/2403.10266)][[doc](./docs/dsp.md)], achieves **3x** speed for training and **2x** speed for inference in Open-Sora compared with sota sequence parallelism.
|
||||
- [2024/03] Support [Open-Sora: Democratizing Efficient Video Production for All](https://github.com/hpcaitech/Open-Sora).
|
||||
- [2024/02] 🎉 Release [OpenDiT](https://github.com/NUS-HPC-AI-Lab/VideoSys/tree/v1.0.0): An Easy, Fast and Memory-Efficent System for DiT Training and Inference.
|
||||
|
||||
# About
|
||||
|
||||
VideoSys is an open-source project that provides a user-friendly and high-performance infrastructure for video generation. This comprehensive toolkit will support the entire pipeline from training and inference to serving and compression.
|
||||
|
||||
We are committed to continually integrating cutting-edge open-source video models and techniques. Stay tuned for exciting enhancements and new features on the horizon!
|
||||
|
||||
## Installation
|
||||
|
||||
Prerequisites:
|
||||
|
||||
- Python >= 3.10
|
||||
- PyTorch >= 1.13 (We recommend to use a >2.0 version)
|
||||
- CUDA >= 11.6
|
||||
|
||||
We strongly recommend using Anaconda to create a new environment (Python >= 3.10) to run our examples:
|
||||
|
||||
```shell
|
||||
conda create -n videosys python=3.10 -y
|
||||
conda activate videosys
|
||||
```
|
||||
|
||||
Install VideoSys:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/NUS-HPC-AI-Lab/VideoSys
|
||||
cd VideoSys
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
VideoSys supports many diffusion models with our various acceleration techniques, enabling these models to run faster and consume less memory.
|
||||
|
||||
<b>You can find all available models and their supported acceleration techniques in the following table. Click `Doc` to see how to use them.</b>
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th rowspan="2">Model</th>
|
||||
<th rowspan="2">Train</th>
|
||||
<th rowspan="2">Infer</th>
|
||||
<th colspan="2">Acceleration Techniques</th>
|
||||
<th rowspan="2">Usage</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th><a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#dyanmic-sequence-parallelism-dsp-paperdoc">DSP</a></th>
|
||||
<th><a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#pyramid-attention-broadcast-pab-blogdoc">PAB</a></th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Open-Sora [<a href="https://github.com/hpcaitech/Open-Sora">source</a>]</td>
|
||||
<td align="center">🟡</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center"><a href="./examples/open_sora/sample.py">Code</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Open-Sora-Plan [<a href="https://github.com/PKU-YuanGroup/Open-Sora-Plan">source</a>]</td>
|
||||
<td align="center">/</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center"><a href="./examples/open_sora_plan/sample.py">Code</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Latte [<a href="https://github.com/Vchitect/Latte">source</a>]</td>
|
||||
<td align="center">/</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center"><a href="./examples/latte/sample.py">Code</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>CogVideoX [<a href="https://github.com/THUDM/CogVideo">source</a>]</td>
|
||||
<td align="center">/</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">/</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center"><a href="./examples/cogvideox/sample.py">Code</a></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Acceleration Techniques
|
||||
|
||||
### Pyramid Attention Broadcast (PAB) [[paper](https://arxiv.org/abs/2408.12588)][[blog](https://arxiv.org/abs/2403.10266)][[doc](./docs/pab.md)]
|
||||
|
||||
Real-Time Video Generation with Pyramid Attention Broadcast
|
||||
|
||||
Authors: [Xuanlei Zhao](https://oahzxl.github.io/)<sup>1*</sup>, [Xiaolong Jin]()<sup>2*</sup>, [Kai Wang](https://kaiwang960112.github.io/)<sup>1*</sup>, and [Yang You](https://www.comp.nus.edu.sg/~youy/)<sup>1</sup> (* indicates equal contribution)
|
||||
|
||||
<sup>1</sup>National University of Singapore, <sup>2</sup>Purdue University
|
||||
|
||||

|
||||
|
||||
PAB is the first approach to achieve <b>real-time</b> DiT-based video generation, delivering <b>lossless quality</b> without <b>requiring any training</b>. By mitigating redundant attention computation, PAB achieves up to 21.6 FPS with 10.6x acceleration, without sacrificing quality across popular DiT-based video generation models including [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte) and [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan).
|
||||
|
||||
See its details [here](./docs/pab.md).
|
||||
|
||||
----
|
||||
|
||||
### Dyanmic Sequence Parallelism (DSP) [[paper](https://arxiv.org/abs/2403.10266)][[doc](./docs/dsp.md)]
|
||||
|
||||

|
||||
|
||||
DSP is a novel, elegant and super efficient sequence parallelism for [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte) and other multi-dimensional transformer architecture.
|
||||
|
||||
It achieves **3x** speed for training and **2x** speed for inference in Open-Sora compared with sota sequence parallelism ([DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509)). For a 10s (80 frames) of 512x512 video, the inference latency of Open-Sora is:
|
||||
|
||||
| Method | 1xH800 | 8xH800 (DS Ulysses) | 8xH800 (DSP) |
|
||||
| ------ | ------ | ------ | ------ |
|
||||
| Latency(s) | 106 | 45 | 22 |
|
||||
|
||||
See its details [here](./docs/dsp.md).
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome and value any contributions and collaborations. Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
|
||||
|
||||
## Contributors
|
||||
|
||||
<a href="https://github.com/NUS-HPC-AI-Lab/VideoSys/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=NUS-HPC-AI-Lab/VideoSys"/>
|
||||
</a>
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#NUS-HPC-AI-Lab/VideoSys&Date)
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
@misc{videosys2024,
|
||||
author={VideoSys Team},
|
||||
title={VideoSys: An Easy and Efficient System for Video Generation},
|
||||
year={2024},
|
||||
publisher={GitHub},
|
||||
url = {https://github.com/NUS-HPC-AI-Lab/VideoSys},
|
||||
}
|
||||
|
||||
@misc{zhao2024pab,
|
||||
title={Real-Time Video Generation with Pyramid Attention Broadcast},
|
||||
author={Xuanlei Zhao and Xiaolong Jin and Kai Wang and Yang You},
|
||||
year={2024},
|
||||
eprint={2408.12588},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV},
|
||||
url={https://arxiv.org/abs/2408.12588},
|
||||
}
|
||||
|
||||
@misc{zhao2024dsp,
|
||||
title={DSP: Dynamic Sequence Parallelism for Multi-Dimensional Transformers},
|
||||
author={Xuanlei Zhao and Shenggan Cheng and Chang Chen and Zangwei Zheng and Ziming Liu and Zheming Yang and Yang You},
|
||||
year={2024},
|
||||
eprint={2403.10266},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.DC},
|
||||
url={https://arxiv.org/abs/2403.10266},
|
||||
}
|
||||
|
||||
@misc{zhao2024opendit,
|
||||
author={Xuanlei Zhao, Zhongkai Zhao, Ziming Liu, Haotian Zhou, Qianli Ma, and Yang You},
|
||||
title={OpenDiT: An Easy, Fast and Memory-Efficient System for DiT Training and Inference},
|
||||
year={2024},
|
||||
publisher={GitHub},
|
||||
url={https://github.com/NUS-HPC-AI-Lab/VideoSys/tree/v1.0.0},
|
||||
}
|
||||
```
|
||||
@ -1,60 +0,0 @@
|
||||
LICENSE
|
||||
README.md
|
||||
setup.py
|
||||
tests/pipelines/__init__.py
|
||||
tests/pipelines/cogvideox/__init__.py
|
||||
tests/pipelines/cogvideox/test_cogvideox.py
|
||||
tests/pipelines/latte/__init__.py
|
||||
tests/pipelines/latte/test_latte.py
|
||||
tests/pipelines/open_sora/__init__.py
|
||||
tests/pipelines/open_sora/test_open_sora.py
|
||||
tests/pipelines/open_sora_plan/__init__.py
|
||||
tests/pipelines/open_sora_plan/test_open_sora_plan.py
|
||||
videosys/__init__.py
|
||||
videosys.egg-info/PKG-INFO
|
||||
videosys.egg-info/SOURCES.txt
|
||||
videosys.egg-info/dependency_links.txt
|
||||
videosys.egg-info/requires.txt
|
||||
videosys.egg-info/top_level.txt
|
||||
videosys/core/__init__.py
|
||||
videosys/core/comm.py
|
||||
videosys/core/engine.py
|
||||
videosys/core/mp_utils.py
|
||||
videosys/core/pab_mgr.py
|
||||
videosys/core/parallel_mgr.py
|
||||
videosys/core/pipeline.py
|
||||
videosys/core/shardformer/__init__.py
|
||||
videosys/core/shardformer/t5/__init__.py
|
||||
videosys/core/shardformer/t5/modeling.py
|
||||
videosys/core/shardformer/t5/policy.py
|
||||
videosys/models/__init__.py
|
||||
videosys/models/autoencoders/__init__.py
|
||||
videosys/models/autoencoders/autoencoder_kl_cogvideox.py
|
||||
videosys/models/autoencoders/autoencoder_kl_open_sora.py
|
||||
videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py
|
||||
videosys/models/modules/__init__.py
|
||||
videosys/models/modules/activations.py
|
||||
videosys/models/modules/attentions.py
|
||||
videosys/models/modules/downsampling.py
|
||||
videosys/models/modules/embeddings.py
|
||||
videosys/models/modules/normalization.py
|
||||
videosys/models/modules/upsampling.py
|
||||
videosys/models/transformers/__init__.py
|
||||
videosys/models/transformers/cogvideox_transformer_3d.py
|
||||
videosys/models/transformers/latte_transformer_3d.py
|
||||
videosys/models/transformers/open_sora_plan_transformer_3d.py
|
||||
videosys/models/transformers/open_sora_transformer_3d.py
|
||||
videosys/pipelines/__init__.py
|
||||
videosys/pipelines/cogvideox/__init__.py
|
||||
videosys/pipelines/cogvideox/pipeline_cogvideox.py
|
||||
videosys/pipelines/latte/__init__.py
|
||||
videosys/pipelines/latte/pipeline_latte.py
|
||||
videosys/pipelines/open_sora/__init__.py
|
||||
videosys/pipelines/open_sora/data_process.py
|
||||
videosys/pipelines/open_sora/pipeline_open_sora.py
|
||||
videosys/pipelines/open_sora_plan/__init__.py
|
||||
videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
|
||||
videosys/schedulers/__init__.py
|
||||
videosys/schedulers/scheduling_ddim_cogvideox.py
|
||||
videosys/schedulers/scheduling_dpm_cogvideox.py
|
||||
videosys/schedulers/scheduling_rflow_open_sora.py
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
click
|
||||
colossalai
|
||||
contexttimer
|
||||
diffusers==0.30.0
|
||||
einops
|
||||
fabric
|
||||
ftfy
|
||||
imageio
|
||||
imageio-ffmpeg
|
||||
matplotlib
|
||||
ninja
|
||||
numpy<2.0.0
|
||||
omegaconf
|
||||
packaging
|
||||
psutil
|
||||
pydantic
|
||||
ray
|
||||
rich
|
||||
safetensors
|
||||
timm
|
||||
torch>=1.13
|
||||
tqdm
|
||||
transformers
|
||||
@ -1,2 +0,0 @@
|
||||
tests
|
||||
videosys
|
||||
@ -1,15 +0,0 @@
|
||||
from .core.engine import VideoSysEngine
|
||||
from .core.parallel_mgr import initialize
|
||||
from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
|
||||
from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
|
||||
from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
|
||||
from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
|
||||
|
||||
__all__ = [
|
||||
"initialize",
|
||||
"VideoSysEngine",
|
||||
"LattePipeline", "LatteConfig", "LattePABConfig",
|
||||
"OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
|
||||
"OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
|
||||
"CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"
|
||||
] # fmt: skip
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,420 +0,0 @@
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from videosys.core.parallel_mgr import get_sequence_parallel_size
|
||||
|
||||
# ======================================================
|
||||
# Model
|
||||
# ======================================================
|
||||
|
||||
|
||||
def model_sharding(model: torch.nn.Module):
|
||||
global_rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
for _, param in model.named_parameters():
|
||||
padding_size = (world_size - param.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
|
||||
else:
|
||||
padding_param = param.data.view(-1)
|
||||
splited_params = padding_param.split(padding_param.numel() // world_size)
|
||||
splited_params = splited_params[global_rank]
|
||||
param.data = splited_params
|
||||
|
||||
|
||||
# ======================================================
|
||||
# AllGather & ReduceScatter
|
||||
# ======================================================
|
||||
|
||||
|
||||
class AsyncAllGatherForTwo(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
inputs: Tensor,
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
sp_rank: int,
|
||||
sp_size: int,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
) -> Tuple[Tensor, Any]:
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
from torch.distributed._functional_collectives import all_gather_tensor
|
||||
|
||||
ctx.group = group
|
||||
ctx.sp_rank = sp_rank
|
||||
ctx.sp_size = sp_size
|
||||
|
||||
# all gather inputs
|
||||
all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
|
||||
# compute local qkv
|
||||
local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)
|
||||
|
||||
# remote compute
|
||||
remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
|
||||
# compute remote qkv
|
||||
remote_qkv = F.linear(remote_inputs, weight, bias)
|
||||
|
||||
# concat local and remote qkv
|
||||
if sp_rank == 0:
|
||||
qkv = torch.cat([local_qkv, remote_qkv], dim=0)
|
||||
else:
|
||||
qkv = torch.cat([remote_qkv, local_qkv], dim=0)
|
||||
qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
|
||||
|
||||
ctx.save_for_backward(inputs, weight, remote_inputs)
|
||||
return qkv
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
||||
from torch.distributed._functional_collectives import reduce_scatter_tensor
|
||||
|
||||
group = ctx.group
|
||||
sp_rank = ctx.sp_rank
|
||||
sp_size = ctx.sp_size
|
||||
inputs, weight, remote_inputs = ctx.saved_tensors
|
||||
|
||||
# split qkv_grad
|
||||
qkv_grad = grad_outputs[0]
|
||||
qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
|
||||
qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
|
||||
if sp_rank == 0:
|
||||
local_qkv_grad, remote_qkv_grad = qkv_grad
|
||||
else:
|
||||
remote_qkv_grad, local_qkv_grad = qkv_grad
|
||||
|
||||
# compute remote grad
|
||||
remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
|
||||
weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
|
||||
bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)
|
||||
|
||||
# launch async reduce scatter
|
||||
remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
|
||||
if sp_rank == 0:
|
||||
remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
|
||||
else:
|
||||
remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
|
||||
remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)
|
||||
|
||||
# compute local grad and wait for reduce scatter
|
||||
local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
|
||||
weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
|
||||
bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)
|
||||
|
||||
# sum remote and local grad
|
||||
inputs_grad = remote_inputs_grad + local_input_grad
|
||||
return inputs_grad, weight_grad, bias_grad, None, None, None
|
||||
|
||||
|
||||
class AllGather(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
inputs: Tensor,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
overlap: bool = False,
|
||||
) -> Tuple[Tensor, Any]:
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
assert ctx is not None or not overlap
|
||||
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
if comm_size == 1:
|
||||
return inputs.unsqueeze(0), None
|
||||
|
||||
buffer_shape = (comm_size,) + inputs.shape
|
||||
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
|
||||
if not overlap:
|
||||
dist.all_gather(buffer_list, inputs, group=group)
|
||||
return outputs, None
|
||||
else:
|
||||
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
|
||||
return outputs, handle
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
||||
return (
|
||||
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class ReduceScatter(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
inputs: Tensor,
|
||||
group: ProcessGroup,
|
||||
overlap: bool = False,
|
||||
) -> Tuple[Tensor, Any]:
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
assert ctx is not None or not overlap
|
||||
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
if comm_size == 1:
|
||||
return inputs.squeeze(0), None
|
||||
|
||||
if not inputs.is_contiguous():
|
||||
inputs = inputs.contiguous()
|
||||
|
||||
output_shape = inputs.shape[1:]
|
||||
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
|
||||
if not overlap:
|
||||
dist.reduce_scatter(outputs, buffer_list, group=group)
|
||||
return outputs, None
|
||||
else:
|
||||
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
|
||||
return outputs, handle
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
||||
# TODO: support async backward
|
||||
return (
|
||||
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================
|
||||
# AlltoAll
|
||||
# ======================================================
|
||||
|
||||
|
||||
def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
|
||||
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
|
||||
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
|
||||
dist.all_to_all(output_list, input_list, group=group)
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
|
||||
class _AllToAll(torch.autograd.Function):
|
||||
"""All-to-all communication.
|
||||
|
||||
Args:
|
||||
input_: input matrix
|
||||
process_group: communication group
|
||||
scatter_dim: scatter dimension
|
||||
gather_dim: gather dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
|
||||
ctx.process_group = process_group
|
||||
ctx.scatter_dim = scatter_dim
|
||||
ctx.gather_dim = gather_dim
|
||||
world_size = dist.get_world_size(process_group)
|
||||
|
||||
return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_output):
|
||||
process_group = ctx.process_group
|
||||
scatter_dim = ctx.gather_dim
|
||||
gather_dim = ctx.scatter_dim
|
||||
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
|
||||
return (return_grad, None, None, None)
|
||||
|
||||
|
||||
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
|
||||
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
|
||||
|
||||
|
||||
# ======================================================
|
||||
# Sequence Gather & Split
|
||||
# ======================================================
|
||||
|
||||
|
||||
def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
|
||||
# skip if only one rank involved
|
||||
world_size = dist.get_world_size(pg)
|
||||
rank = dist.get_rank(pg)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
if pad > 0:
|
||||
pad_size = list(input_.shape)
|
||||
pad_size[dim] = pad
|
||||
input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)
|
||||
|
||||
dim_size = input_.size(dim)
|
||||
assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"
|
||||
|
||||
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
|
||||
output = tensor_list[rank].contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
|
||||
# skip if only one rank involved
|
||||
input_ = input_.contiguous()
|
||||
world_size = dist.get_world_size(pg)
|
||||
dist.get_rank(pg)
|
||||
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# all gather
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
assert input_.device.type == "cuda"
|
||||
torch.distributed.all_gather(tensor_list, input_, group=pg)
|
||||
|
||||
# concat
|
||||
output = torch.cat(tensor_list, dim=dim)
|
||||
|
||||
if pad > 0:
|
||||
output = output.narrow(dim, 0, output.size(dim) - pad)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
"""
|
||||
Gather the input sequence.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
process_group: process group.
|
||||
dim: dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _gather_sequence_func(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim, grad_scale, pad):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.pad = pad
|
||||
return _gather_sequence_func(input_, process_group, dim, pad)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale == "up":
|
||||
grad_output = grad_output * dist.get_world_size(ctx.process_group)
|
||||
elif ctx.grad_scale == "down":
|
||||
grad_output = grad_output / dist.get_world_size(ctx.process_group)
|
||||
|
||||
return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
"""
|
||||
Split sequence.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
process_group: parallel mode.
|
||||
dim: dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _split_sequence_func(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim, grad_scale, pad):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.pad = pad
|
||||
return _split_sequence_func(input_, process_group, dim, pad)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale == "up":
|
||||
grad_output = grad_output * dist.get_world_size(ctx.process_group)
|
||||
elif ctx.grad_scale == "down":
|
||||
grad_output = grad_output / dist.get_world_size(ctx.process_group)
|
||||
return _gather_sequence_func(grad_output, ctx.process_group, ctx.pad), None, None, None, None
|
||||
|
||||
|
||||
def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
|
||||
return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)
|
||||
|
||||
|
||||
def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
|
||||
return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
|
||||
|
||||
|
||||
# ==============================
|
||||
# Pad
|
||||
# ==============================
|
||||
|
||||
SPTIAL_PAD = 0
|
||||
TEMPORAL_PAD = 0
|
||||
|
||||
|
||||
def set_spatial_pad(dim_size: int):
|
||||
sp_size = get_sequence_parallel_size()
|
||||
pad = (sp_size - (dim_size % sp_size)) % sp_size
|
||||
global SPTIAL_PAD
|
||||
SPTIAL_PAD = pad
|
||||
|
||||
|
||||
def get_spatial_pad() -> int:
|
||||
return SPTIAL_PAD
|
||||
|
||||
|
||||
def set_temporal_pad(dim_size: int):
|
||||
sp_size = get_sequence_parallel_size()
|
||||
pad = (sp_size - (dim_size % sp_size)) % sp_size
|
||||
global TEMPORAL_PAD
|
||||
TEMPORAL_PAD = pad
|
||||
|
||||
|
||||
def get_temporal_pad() -> int:
|
||||
return TEMPORAL_PAD
|
||||
|
||||
|
||||
def all_to_all_with_pad(
|
||||
input_: torch.Tensor,
|
||||
process_group: dist.ProcessGroup,
|
||||
scatter_dim: int = 2,
|
||||
gather_dim: int = 1,
|
||||
scatter_pad: int = 0,
|
||||
gather_pad: int = 0,
|
||||
):
|
||||
if scatter_pad > 0:
|
||||
pad_shape = list(input_.shape)
|
||||
pad_shape[scatter_dim] = scatter_pad
|
||||
pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
|
||||
input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)
|
||||
|
||||
assert (
|
||||
input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
|
||||
), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
|
||||
input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
|
||||
|
||||
if gather_pad > 0:
|
||||
input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
|
||||
|
||||
return input_
|
||||
@ -1,128 +0,0 @@
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import videosys
|
||||
|
||||
from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port
|
||||
|
||||
|
||||
class VideoSysEngine:
|
||||
"""
|
||||
this is partly inspired by vllm
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.parallel_worker_tasks = None
|
||||
self._init_worker(config.pipeline_cls)
|
||||
|
||||
def _init_worker(self, pipeline_cls):
|
||||
world_size = self.config.num_gpus
|
||||
|
||||
# Disable torch async compiling which won't work with daemonic processes
|
||||
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
|
||||
|
||||
# Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
|
||||
# contention amongst the shards
|
||||
if "OMP_NUM_THREADS" not in os.environ:
|
||||
os.environ["OMP_NUM_THREADS"] = "1"
|
||||
|
||||
# NOTE: The two following lines need adaption for multi-node
|
||||
assert world_size <= torch.cuda.device_count()
|
||||
|
||||
# change addr for multi-node
|
||||
distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())
|
||||
|
||||
if world_size == 1:
|
||||
self.workers = []
|
||||
self.worker_monitor = None
|
||||
else:
|
||||
result_handler = ResultHandler()
|
||||
self.workers = [
|
||||
ProcessWorkerWrapper(
|
||||
result_handler,
|
||||
partial(
|
||||
self._create_pipeline,
|
||||
pipeline_cls=pipeline_cls,
|
||||
rank=rank,
|
||||
local_rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
),
|
||||
)
|
||||
for rank in range(1, world_size)
|
||||
]
|
||||
|
||||
self.worker_monitor = WorkerMonitor(self.workers, result_handler)
|
||||
result_handler.start()
|
||||
self.worker_monitor.start()
|
||||
|
||||
self.driver_worker = self._create_pipeline(
|
||||
pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
|
||||
)
|
||||
|
||||
# TODO: add more options here for pipeline, or wrap all options into config
|
||||
def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
|
||||
videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method, seed=42)
|
||||
|
||||
pipeline = pipeline_cls(self.config)
|
||||
return pipeline
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
|
||||
# Start the workers first.
|
||||
worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]
|
||||
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
# Just return futures
|
||||
return worker_outputs
|
||||
|
||||
driver_worker_method = getattr(self.driver_worker, method)
|
||||
driver_worker_output = driver_worker_method(*args, **kwargs)
|
||||
|
||||
# Get the results of the workers.
|
||||
return [driver_worker_output] + [output.get() for output in worker_outputs]
|
||||
|
||||
def _driver_execute_model(self, *args, **kwargs):
|
||||
return self.driver_worker.generate(*args, **kwargs)
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
return self._run_workers("generate", *args, **kwargs)[0]
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
if self.parallel_worker_tasks is None:
|
||||
return
|
||||
|
||||
parallel_worker_tasks = self.parallel_worker_tasks
|
||||
self.parallel_worker_tasks = None
|
||||
# Ensure that workers exit model loop cleanly
|
||||
# (this will raise otherwise)
|
||||
self._wait_for_tasks_completion(parallel_worker_tasks)
|
||||
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
for result in parallel_worker_tasks:
|
||||
result.get()
|
||||
|
||||
def save_video(self, video, output_path):
|
||||
return self.driver_worker.save_video(video, output_path)
|
||||
|
||||
def shutdown(self):
|
||||
if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
|
||||
worker_monitor.close()
|
||||
dist.destroy_process_group()
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
@ -1,270 +0,0 @@
|
||||
# adapted from vllm
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/executor/multiproc_worker_utils.py
|
||||
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Queue
|
||||
from multiprocessing.connection import wait
|
||||
from typing import Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union
|
||||
|
||||
from videosys.utils.logging import create_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
_TERMINATE = "TERMINATE" # sentinel
|
||||
# ANSI color codes
|
||||
CYAN = "\033[1;36m"
|
||||
RESET = "\033[0;0m"
|
||||
JOIN_TIMEOUT_S = 2
|
||||
|
||||
mp_method = "spawn" # fork cann't work
|
||||
mp = multiprocessing.get_context(mp_method)
|
||||
|
||||
logger = create_logger()
|
||||
|
||||
|
||||
def get_distributed_init_method(ip: str, port: int) -> str:
|
||||
# Brackets are not permitted in ipv4 addresses,
|
||||
# see https://github.com/python/cpython/issues/103848
|
||||
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
# try ipv4
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
except OSError:
|
||||
# try ipv6
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result(Generic[T]):
|
||||
"""Result of task dispatched to worker"""
|
||||
|
||||
task_id: uuid.UUID
|
||||
value: Optional[T] = None
|
||||
exception: Optional[BaseException] = None
|
||||
|
||||
|
||||
class ResultFuture(threading.Event, Generic[T]):
|
||||
"""Synchronous future for non-async case"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.result: Optional[Result[T]] = None
|
||||
|
||||
def set_result(self, result: Result[T]):
|
||||
self.result = result
|
||||
self.set()
|
||||
|
||||
def get(self) -> T:
|
||||
self.wait()
|
||||
assert self.result is not None
|
||||
if self.result.exception is not None:
|
||||
raise self.result.exception
|
||||
return self.result.value # type: ignore[return-value]
|
||||
|
||||
|
||||
def _set_future_result(future: Union[ResultFuture, asyncio.Future], result: Result):
|
||||
if isinstance(future, ResultFuture):
|
||||
future.set_result(result)
|
||||
return
|
||||
loop = future.get_loop()
|
||||
if not loop.is_closed():
|
||||
if result.exception is not None:
|
||||
loop.call_soon_threadsafe(future.set_exception, result.exception)
|
||||
else:
|
||||
loop.call_soon_threadsafe(future.set_result, result.value)
|
||||
|
||||
|
||||
class ResultHandler(threading.Thread):
|
||||
"""Handle results from all workers (in background thread)"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(daemon=True)
|
||||
self.result_queue = mp.Queue()
|
||||
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
|
||||
|
||||
def run(self):
|
||||
for result in iter(self.result_queue.get, _TERMINATE):
|
||||
future = self.tasks.pop(result.task_id)
|
||||
_set_future_result(future, result)
|
||||
# Ensure that all waiters will receive an exception
|
||||
for task_id, future in self.tasks.items():
|
||||
_set_future_result(future, Result(task_id=task_id, exception=ChildProcessError("worker died")))
|
||||
|
||||
def close(self):
|
||||
self.result_queue.put(_TERMINATE)
|
||||
|
||||
|
||||
class WorkerMonitor(threading.Thread):
|
||||
"""Monitor worker status (in background thread)"""
|
||||
|
||||
def __init__(self, workers: List["ProcessWorkerWrapper"], result_handler: ResultHandler):
|
||||
super().__init__(daemon=True)
|
||||
self.workers = workers
|
||||
self.result_handler = result_handler
|
||||
self._close = False
|
||||
|
||||
def run(self) -> None:
|
||||
# Blocks until any worker exits
|
||||
dead_sentinels = wait([w.process.sentinel for w in self.workers])
|
||||
if not self._close:
|
||||
self._close = True
|
||||
|
||||
# Kill / cleanup all workers
|
||||
for worker in self.workers:
|
||||
process = worker.process
|
||||
if process.sentinel in dead_sentinels:
|
||||
process.join(JOIN_TIMEOUT_S)
|
||||
if process.exitcode is not None and process.exitcode != 0:
|
||||
logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode)
|
||||
# Cleanup any remaining workers
|
||||
logger.info("Killing local worker processes")
|
||||
for worker in self.workers:
|
||||
worker.kill_worker()
|
||||
# Must be done after worker task queues are all closed
|
||||
self.result_handler.close()
|
||||
|
||||
for worker in self.workers:
|
||||
worker.process.join(JOIN_TIMEOUT_S)
|
||||
|
||||
def close(self):
|
||||
if self._close:
|
||||
return
|
||||
self._close = True
|
||||
logger.info("Terminating local worker processes")
|
||||
for worker in self.workers:
|
||||
worker.terminate_worker()
|
||||
# Must be done after worker task queues are all closed
|
||||
self.result_handler.close()
|
||||
|
||||
|
||||
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
|
||||
"""Prepend each output line with process-specific prefix"""
|
||||
|
||||
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
|
||||
file_write = file.write
|
||||
|
||||
def write_with_prefix(s: str):
|
||||
if not s:
|
||||
return
|
||||
if file.start_new_line: # type: ignore[attr-defined]
|
||||
file_write(prefix)
|
||||
idx = 0
|
||||
while (next_idx := s.find("\n", idx)) != -1:
|
||||
next_idx += 1
|
||||
file_write(s[idx:next_idx])
|
||||
if next_idx == len(s):
|
||||
file.start_new_line = True # type: ignore[attr-defined]
|
||||
return
|
||||
file_write(prefix)
|
||||
idx = next_idx
|
||||
file_write(s[idx:])
|
||||
file.start_new_line = False # type: ignore[attr-defined]
|
||||
|
||||
file.start_new_line = True # type: ignore[attr-defined]
|
||||
file.write = write_with_prefix # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _run_worker_process(
|
||||
worker_factory: Callable[[], Any],
|
||||
task_queue: Queue,
|
||||
result_queue: Queue,
|
||||
) -> None:
|
||||
"""Worker process event loop"""
|
||||
|
||||
# Add process-specific prefix to stdout and stderr
|
||||
process_name = mp.current_process().name
|
||||
pid = os.getpid()
|
||||
_add_prefix(sys.stdout, process_name, pid)
|
||||
_add_prefix(sys.stderr, process_name, pid)
|
||||
|
||||
# Initialize worker
|
||||
worker = worker_factory()
|
||||
del worker_factory
|
||||
|
||||
# Accept tasks from the engine in task_queue
|
||||
# and return task output in result_queue
|
||||
logger.info("Worker ready; awaiting tasks")
|
||||
try:
|
||||
for items in iter(task_queue.get, _TERMINATE):
|
||||
output = None
|
||||
exception = None
|
||||
task_id, method, args, kwargs = items
|
||||
try:
|
||||
executor = getattr(worker, method)
|
||||
output = executor(*args, **kwargs)
|
||||
except BaseException as e:
|
||||
tb = traceback.format_exc()
|
||||
logger.error("Exception in worker %s while processing method %s: %s, %s", process_name, method, e, tb)
|
||||
exception = e
|
||||
result_queue.put(Result(task_id=task_id, value=output, exception=exception))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Worker failed")
|
||||
|
||||
logger.info("Worker exiting")
|
||||
|
||||
|
||||
class ProcessWorkerWrapper:
|
||||
"""Local process wrapper for handling single-node multi-GPU."""
|
||||
|
||||
def __init__(self, result_handler: ResultHandler, worker_factory: Callable[[], Any]) -> None:
|
||||
self._task_queue = mp.Queue()
|
||||
self.result_queue = result_handler.result_queue
|
||||
self.tasks = result_handler.tasks
|
||||
self.process = mp.Process( # type: ignore[attr-defined]
|
||||
target=_run_worker_process,
|
||||
name="VideoSysWorkerProcess",
|
||||
kwargs=dict(
|
||||
worker_factory=worker_factory,
|
||||
task_queue=self._task_queue,
|
||||
result_queue=self.result_queue,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
self.process.start()
|
||||
|
||||
def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], method: str, args, kwargs):
|
||||
task_id = uuid.uuid4()
|
||||
self.tasks[task_id] = future
|
||||
try:
|
||||
self._task_queue.put((task_id, method, args, kwargs))
|
||||
except BaseException as e:
|
||||
del self.tasks[task_id]
|
||||
raise ChildProcessError("worker died") from e
|
||||
|
||||
def execute_method(self, method: str, *args, **kwargs):
|
||||
future: ResultFuture = ResultFuture()
|
||||
self._enqueue_task(future, method, args, kwargs)
|
||||
return future
|
||||
|
||||
async def execute_method_async(self, method: str, *args, **kwargs):
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self._enqueue_task(future, method, args, kwargs)
|
||||
return await future
|
||||
|
||||
def terminate_worker(self):
|
||||
try:
|
||||
self._task_queue.put(_TERMINATE)
|
||||
except ValueError:
|
||||
self.process.kill()
|
||||
self._task_queue.close()
|
||||
|
||||
def kill_worker(self):
|
||||
self._task_queue.close()
|
||||
self.process.kill()
|
||||
@ -1,233 +0,0 @@
|
||||
from videosys.utils.logging import logger
|
||||
|
||||
PAB_MANAGER = None
|
||||
|
||||
|
||||
class PABConfig:
|
||||
def __init__(
|
||||
self,
|
||||
steps: int,
|
||||
cross_broadcast: bool = False,
|
||||
cross_threshold: list = None,
|
||||
cross_range: int = None,
|
||||
spatial_broadcast: bool = False,
|
||||
spatial_threshold: list = None,
|
||||
spatial_range: int = None,
|
||||
temporal_broadcast: bool = False,
|
||||
temporal_threshold: list = None,
|
||||
temporal_range: int = None,
|
||||
mlp_broadcast: bool = False,
|
||||
mlp_spatial_broadcast_config: dict = None,
|
||||
mlp_temporal_broadcast_config: dict = None,
|
||||
):
|
||||
self.steps = steps
|
||||
|
||||
self.cross_broadcast = cross_broadcast
|
||||
self.cross_threshold = cross_threshold
|
||||
self.cross_range = cross_range
|
||||
|
||||
self.spatial_broadcast = spatial_broadcast
|
||||
self.spatial_threshold = spatial_threshold
|
||||
self.spatial_range = spatial_range
|
||||
|
||||
self.temporal_broadcast = temporal_broadcast
|
||||
self.temporal_threshold = temporal_threshold
|
||||
self.temporal_range = temporal_range
|
||||
|
||||
self.mlp_broadcast = mlp_broadcast
|
||||
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
|
||||
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
|
||||
self.mlp_temporal_outputs = {}
|
||||
self.mlp_spatial_outputs = {}
|
||||
|
||||
|
||||
class PABManager:
|
||||
def __init__(self, config: PABConfig):
|
||||
self.config: PABConfig = config
|
||||
|
||||
init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
|
||||
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
|
||||
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
|
||||
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
|
||||
init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
|
||||
logger.info(init_prompt)
|
||||
|
||||
def if_broadcast_cross(self, timestep: int, count: int):
|
||||
if (
|
||||
self.config.cross_broadcast
|
||||
and (timestep is not None)
|
||||
and (count % self.config.cross_range != 0)
|
||||
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
|
||||
):
|
||||
flag = True
|
||||
else:
|
||||
flag = False
|
||||
count = (count + 1) % self.config.steps
|
||||
return flag, count
|
||||
|
||||
def if_broadcast_temporal(self, timestep: int, count: int):
|
||||
if (
|
||||
self.config.temporal_broadcast
|
||||
and (timestep is not None)
|
||||
and (count % self.config.temporal_range != 0)
|
||||
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
|
||||
):
|
||||
flag = True
|
||||
else:
|
||||
flag = False
|
||||
count = (count + 1) % self.config.steps
|
||||
return flag, count
|
||||
|
||||
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
|
||||
if (
|
||||
self.config.spatial_broadcast
|
||||
and (timestep is not None)
|
||||
and (count % self.config.spatial_range != 0)
|
||||
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
|
||||
):
|
||||
flag = True
|
||||
else:
|
||||
flag = False
|
||||
count = (count + 1) % self.config.steps
|
||||
return flag, count
|
||||
|
||||
@staticmethod
|
||||
def _is_t_in_skip_config(all_timesteps, timestep, config):
|
||||
is_t_in_skip_config = False
|
||||
skip_range = None
|
||||
for key in config:
|
||||
if key not in all_timesteps:
|
||||
continue
|
||||
index = all_timesteps.index(key)
|
||||
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
|
||||
if timestep in skip_range:
|
||||
is_t_in_skip_config = True
|
||||
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
|
||||
break
|
||||
return is_t_in_skip_config, skip_range
|
||||
|
||||
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
||||
if not self.config.mlp_broadcast:
|
||||
return False, None, False, None
|
||||
|
||||
if is_temporal:
|
||||
cur_config = self.config.mlp_temporal_broadcast_config
|
||||
else:
|
||||
cur_config = self.config.mlp_spatial_broadcast_config
|
||||
|
||||
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
|
||||
next_flag = False
|
||||
if (
|
||||
self.config.mlp_broadcast
|
||||
and (timestep is not None)
|
||||
and (timestep in cur_config)
|
||||
and (block_idx in cur_config[timestep]["block"])
|
||||
):
|
||||
flag = False
|
||||
next_flag = True
|
||||
count = count + 1
|
||||
elif (
|
||||
self.config.mlp_broadcast
|
||||
and (timestep is not None)
|
||||
and (is_t_in_skip_config)
|
||||
and (block_idx in cur_config[skip_range[0]]["block"])
|
||||
):
|
||||
flag = True
|
||||
count = 0
|
||||
else:
|
||||
flag = False
|
||||
|
||||
return flag, count, next_flag, skip_range
|
||||
|
||||
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
|
||||
if is_temporal:
|
||||
self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
|
||||
else:
|
||||
self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
|
||||
|
||||
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
|
||||
skip_start_t = skip_range[0]
|
||||
if is_temporal:
|
||||
skip_output = (
|
||||
self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
|
||||
if self.config.mlp_temporal_outputs is not None
|
||||
else None
|
||||
)
|
||||
else:
|
||||
skip_output = (
|
||||
self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
|
||||
if self.config.mlp_spatial_outputs is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if skip_output is not None:
|
||||
if timestep == skip_range[-1]:
|
||||
# TODO: save memory
|
||||
if is_temporal:
|
||||
del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
|
||||
else:
|
||||
del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
|
||||
)
|
||||
|
||||
return skip_output
|
||||
|
||||
def get_spatial_mlp_outputs(self):
|
||||
return self.config.mlp_spatial_outputs
|
||||
|
||||
def get_temporal_mlp_outputs(self):
|
||||
return self.config.mlp_temporal_outputs
|
||||
|
||||
|
||||
def set_pab_manager(config: PABConfig):
|
||||
global PAB_MANAGER
|
||||
PAB_MANAGER = PABManager(config)
|
||||
|
||||
|
||||
def enable_pab():
|
||||
if PAB_MANAGER is None:
|
||||
return False
|
||||
return (
|
||||
PAB_MANAGER.config.cross_broadcast
|
||||
or PAB_MANAGER.config.spatial_broadcast
|
||||
or PAB_MANAGER.config.temporal_broadcast
|
||||
)
|
||||
|
||||
|
||||
def update_steps(steps: int):
|
||||
if PAB_MANAGER is not None:
|
||||
PAB_MANAGER.config.steps = steps
|
||||
|
||||
|
||||
def if_broadcast_cross(timestep: int, count: int):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_broadcast_cross(timestep, count)
|
||||
|
||||
|
||||
def if_broadcast_temporal(timestep: int, count: int):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
|
||||
|
||||
|
||||
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
|
||||
|
||||
|
||||
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
|
||||
|
||||
|
||||
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
|
||||
return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
|
||||
|
||||
|
||||
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
|
||||
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
|
||||
@ -1,120 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from videosys.utils.logging import init_dist_logger, logger
|
||||
from videosys.utils.utils import set_seed
|
||||
|
||||
PARALLEL_MANAGER = None
|
||||
|
||||
|
||||
class ParallelManager(ProcessGroupMesh):
|
||||
def __init__(self, dp_size, cp_size, sp_size):
|
||||
super().__init__(dp_size, cp_size, sp_size)
|
||||
dp_axis, cp_axis, sp_axis = 0, 1, 2
|
||||
|
||||
self.dp_size = dp_size
|
||||
self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis)
|
||||
self.dp_rank = dist.get_rank(self.dp_group)
|
||||
|
||||
self.cp_size = cp_size
|
||||
self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
|
||||
self.cp_rank = dist.get_rank(self.cp_group)
|
||||
|
||||
self.sp_size = sp_size
|
||||
self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
|
||||
self.sp_rank = dist.get_rank(self.sp_group)
|
||||
self.enable_sp = sp_size > 1
|
||||
|
||||
logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")
|
||||
|
||||
|
||||
def set_parallel_manager(dp_size, cp_size, sp_size):
|
||||
global PARALLEL_MANAGER
|
||||
PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)
|
||||
|
||||
|
||||
def get_data_parallel_group():
|
||||
return PARALLEL_MANAGER.dp_group
|
||||
|
||||
|
||||
def get_data_parallel_size():
|
||||
return PARALLEL_MANAGER.dp_size
|
||||
|
||||
|
||||
def get_data_parallel_rank():
|
||||
return PARALLEL_MANAGER.dp_rank
|
||||
|
||||
|
||||
def get_sequence_parallel_group():
|
||||
return PARALLEL_MANAGER.sp_group
|
||||
|
||||
|
||||
def get_sequence_parallel_size():
|
||||
return PARALLEL_MANAGER.sp_size
|
||||
|
||||
|
||||
def get_sequence_parallel_rank():
|
||||
return PARALLEL_MANAGER.sp_rank
|
||||
|
||||
|
||||
def get_cfg_parallel_group():
|
||||
return PARALLEL_MANAGER.cp_group
|
||||
|
||||
|
||||
def get_cfg_parallel_size():
|
||||
return PARALLEL_MANAGER.cp_size
|
||||
|
||||
|
||||
def enable_sequence_parallel():
|
||||
if PARALLEL_MANAGER is None:
|
||||
return False
|
||||
return PARALLEL_MANAGER.enable_sp
|
||||
|
||||
|
||||
def get_parallel_manager():
|
||||
return PARALLEL_MANAGER
|
||||
|
||||
|
||||
def initialize(
|
||||
rank=0,
|
||||
world_size=1,
|
||||
init_method=None,
|
||||
seed: Optional[int] = None,
|
||||
sp_size: Optional[int] = None,
|
||||
enable_cp: bool = False,
|
||||
):
|
||||
if not dist.is_initialized():
|
||||
try:
|
||||
dist.destroy_process_group()
|
||||
except Exception:
|
||||
pass
|
||||
dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank)
|
||||
torch.cuda.set_device(rank)
|
||||
init_dist_logger()
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# init sequence parallel
|
||||
if sp_size is None:
|
||||
sp_size = dist.get_world_size()
|
||||
dp_size = 1
|
||||
else:
|
||||
assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size"
|
||||
dp_size = dist.get_world_size() // sp_size
|
||||
|
||||
# update cfg parallel
|
||||
# NOTE: enable cp parallel will be slower. disable it for now.
|
||||
if False and enable_cp and sp_size % 2 == 0:
|
||||
sp_size = sp_size // 2
|
||||
cp_size = 2
|
||||
else:
|
||||
cp_size = 1
|
||||
|
||||
set_parallel_manager(dp_size, cp_size, sp_size)
|
||||
|
||||
if seed is not None:
|
||||
set_seed(seed + get_data_parallel_rank())
|
||||
@ -1,52 +0,0 @@
|
||||
import inspect
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
class VideoSysPipeline(DiffusionPipeline):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def set_eval_and_device(device: torch.device, *modules):
|
||||
for module in modules:
|
||||
module.eval()
|
||||
module.to(device)
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
In diffusers, it is a convention to call the pipeline object.
|
||||
But in VideoSys, we will use the generate method for better prompt.
|
||||
This is a wrapper for the generate method to support the diffusers usage.
|
||||
"""
|
||||
return self.generate(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - {"self"}
|
||||
# modify: remove the config module from the expected modules
|
||||
expected_modules = expected_modules - {"config"}
|
||||
|
||||
optional_names = list(optional_parameters)
|
||||
for name in optional_names:
|
||||
if name in cls._optional_components:
|
||||
expected_modules.add(name)
|
||||
optional_parameters.remove(name)
|
||||
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoSysPipelineOutput(BaseOutput):
|
||||
video: torch.Tensor
|
||||
@ -1,39 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||
# half-precision inputs is done in fp32
|
||||
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module, *args, **kwargs):
|
||||
assert module.__class__.__name__ == "FusedRMSNorm", (
|
||||
"Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm."
|
||||
"Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
|
||||
)
|
||||
|
||||
layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
|
||||
layer_norm.weight.data.copy_(module.weight.data)
|
||||
layer_norm = layer_norm.to(module.weight.device)
|
||||
return layer_norm
|
||||
@ -1,68 +0,0 @@
|
||||
from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
|
||||
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
|
||||
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class T5EncoderPolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
assert not self.shard_config.enable_tensor_parallelism
|
||||
assert not self.shard_config.enable_flash_attention
|
||||
|
||||
def preprocess(self):
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack
|
||||
|
||||
policy = {}
|
||||
|
||||
# check whether apex is installed
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm # noqa
|
||||
from videosys.core.shardformer.t5.modeling import T5LayerNorm
|
||||
|
||||
# recover hf from fused rms norm to T5 norm which is faster
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="layer_norm",
|
||||
target_module=T5LayerNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=T5LayerFF,
|
||||
)
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm),
|
||||
policy=policy,
|
||||
target_key=T5LayerSelfAttention,
|
||||
)
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm),
|
||||
policy=policy,
|
||||
target_key=T5Stack,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pass
|
||||
|
||||
# use jit operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_T5_layer_ff_forward(),
|
||||
"dropout_add": get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=T5LayerFF,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_T5_layer_self_attention_forward(),
|
||||
"dropout_add": get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=T5LayerSelfAttention,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@ -1,758 +0,0 @@
|
||||
# Adapted from OpenSora
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# OpenSora: https://github.com/hpcaitech/Open-Sora
|
||||
# --------------------------------------------------------
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
|
||||
from einops import rearrange
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(
|
||||
self,
|
||||
parameters,
|
||||
deterministic=False,
|
||||
):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device, dtype=self.mean.dtype)
|
||||
|
||||
def sample(self):
|
||||
# torch.randn: standard normal distribution
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
else:
|
||||
if other is None: # SCH: assumes other is a standard normal distribution
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var
|
||||
- 1.0
|
||||
- self.logvar
|
||||
+ other.logvar,
|
||||
dim=[1, 2, 3, 4],
|
||||
)
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3, 4]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
|
||||
def is_odd(n):
|
||||
return not divisible_by(n, 2)
|
||||
|
||||
|
||||
def pad_at_dim(t, pad, dim=-1):
|
||||
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
||||
zeros = (0, 0) * dims_from_right
|
||||
return F.pad(t, (*zeros, *pad), mode="constant")
|
||||
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
chan_in,
|
||||
chan_out,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
pad_mode="constant",
|
||||
strides=None, # allow custom stride
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
|
||||
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
||||
|
||||
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
|
||||
|
||||
dilation = kwargs.pop("dilation", 1)
|
||||
stride = strides[0] if strides is not None else kwargs.pop("stride", 1)
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
|
||||
height_pad = height_kernel_size // 2
|
||||
width_pad = width_kernel_size // 2
|
||||
|
||||
self.time_pad = time_pad
|
||||
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
|
||||
|
||||
stride = strides if strides is not None else (stride, 1, 1)
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels, # SCH: added
|
||||
filters,
|
||||
conv_fn,
|
||||
activation_fn=nn.SiLU,
|
||||
use_conv_shortcut=False,
|
||||
num_groups=32,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.filters = filters
|
||||
self.activate = activation_fn()
|
||||
self.use_conv_shortcut = use_conv_shortcut
|
||||
|
||||
# SCH: MAGVIT uses GroupNorm by default
|
||||
self.norm1 = nn.GroupNorm(num_groups, in_channels)
|
||||
self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
|
||||
self.norm2 = nn.GroupNorm(num_groups, self.filters)
|
||||
self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
|
||||
if in_channels != filters:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
|
||||
else:
|
||||
self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.norm1(x)
|
||||
x = self.activate(x)
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = self.activate(x)
|
||||
x = self.conv2(x)
|
||||
if self.in_channels != self.filters: # SCH: ResBlock X->Y
|
||||
residual = self.conv3(residual)
|
||||
return x + residual
|
||||
|
||||
|
||||
def get_activation_fn(activation):
|
||||
if activation == "relu":
|
||||
activation_fn = nn.ReLU
|
||||
elif activation == "swish":
|
||||
activation_fn = nn.SiLU
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return activation_fn
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Encoder Blocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_out_channels=4,
|
||||
latent_embed_dim=512, # num channels for latent vector
|
||||
filters=128,
|
||||
num_res_blocks=4,
|
||||
channel_multipliers=(1, 2, 2, 4),
|
||||
temporal_downsample=(False, True, True),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
activation_fn="swish",
|
||||
):
|
||||
super().__init__()
|
||||
self.filters = filters
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.num_blocks = len(channel_multipliers)
|
||||
self.channel_multipliers = channel_multipliers
|
||||
self.temporal_downsample = temporal_downsample
|
||||
self.num_groups = num_groups
|
||||
self.embedding_dim = latent_embed_dim
|
||||
|
||||
self.activation_fn = get_activation_fn(activation_fn)
|
||||
self.activate = self.activation_fn()
|
||||
self.conv_fn = CausalConv3d
|
||||
self.block_args = dict(
|
||||
conv_fn=self.conv_fn,
|
||||
activation_fn=self.activation_fn,
|
||||
use_conv_shortcut=False,
|
||||
num_groups=self.num_groups,
|
||||
)
|
||||
|
||||
# first layer conv
|
||||
self.conv_in = self.conv_fn(
|
||||
in_out_channels,
|
||||
filters,
|
||||
kernel_size=(3, 3, 3),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# ResBlocks and conv downsample
|
||||
self.block_res_blocks = nn.ModuleList([])
|
||||
self.conv_blocks = nn.ModuleList([])
|
||||
|
||||
filters = self.filters
|
||||
prev_filters = filters # record for in_channels
|
||||
for i in range(self.num_blocks):
|
||||
filters = self.filters * self.channel_multipliers[i]
|
||||
block_items = nn.ModuleList([])
|
||||
for _ in range(self.num_res_blocks):
|
||||
block_items.append(ResBlock(prev_filters, filters, **self.block_args))
|
||||
prev_filters = filters # update in_channels
|
||||
self.block_res_blocks.append(block_items)
|
||||
|
||||
if i < self.num_blocks - 1:
|
||||
if self.temporal_downsample[i]:
|
||||
t_stride = 2 if self.temporal_downsample[i] else 1
|
||||
s_stride = 1
|
||||
self.conv_blocks.append(
|
||||
self.conv_fn(
|
||||
prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride)
|
||||
)
|
||||
)
|
||||
prev_filters = filters # update in_channels
|
||||
else:
|
||||
# if no t downsample, don't add since this does nothing for pipeline models
|
||||
self.conv_blocks.append(nn.Identity(prev_filters)) # Identity
|
||||
prev_filters = filters # update in_channels
|
||||
|
||||
# last layer res block
|
||||
self.res_blocks = nn.ModuleList([])
|
||||
for _ in range(self.num_res_blocks):
|
||||
self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args))
|
||||
prev_filters = filters # update in_channels
|
||||
|
||||
# MAGVIT uses Group Normalization
|
||||
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
|
||||
|
||||
self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for i in range(self.num_blocks):
|
||||
for j in range(self.num_res_blocks):
|
||||
x = self.block_res_blocks[i][j](x)
|
||||
if i < self.num_blocks - 1:
|
||||
x = self.conv_blocks[i](x)
|
||||
for i in range(self.num_res_blocks):
|
||||
x = self.res_blocks[i](x)
|
||||
|
||||
x = self.norm1(x)
|
||||
x = self.activate(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Decoder Blocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_out_channels=4,
|
||||
latent_embed_dim=512,
|
||||
filters=128,
|
||||
num_res_blocks=4,
|
||||
channel_multipliers=(1, 2, 2, 4),
|
||||
temporal_downsample=(False, True, True),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
activation_fn="swish",
|
||||
):
|
||||
super().__init__()
|
||||
self.filters = filters
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.num_blocks = len(channel_multipliers)
|
||||
self.channel_multipliers = channel_multipliers
|
||||
self.temporal_downsample = temporal_downsample
|
||||
self.num_groups = num_groups
|
||||
self.embedding_dim = latent_embed_dim
|
||||
self.s_stride = 1
|
||||
|
||||
self.activation_fn = get_activation_fn(activation_fn)
|
||||
self.activate = self.activation_fn()
|
||||
self.conv_fn = CausalConv3d
|
||||
self.block_args = dict(
|
||||
conv_fn=self.conv_fn,
|
||||
activation_fn=self.activation_fn,
|
||||
use_conv_shortcut=False,
|
||||
num_groups=self.num_groups,
|
||||
)
|
||||
|
||||
filters = self.filters * self.channel_multipliers[-1]
|
||||
prev_filters = filters
|
||||
|
||||
# last conv
|
||||
self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True)
|
||||
|
||||
# last layer res block
|
||||
self.res_blocks = nn.ModuleList([])
|
||||
for _ in range(self.num_res_blocks):
|
||||
self.res_blocks.append(ResBlock(filters, filters, **self.block_args))
|
||||
|
||||
# ResBlocks and conv upsample
|
||||
self.block_res_blocks = nn.ModuleList([])
|
||||
self.num_blocks = len(self.channel_multipliers)
|
||||
self.conv_blocks = nn.ModuleList([])
|
||||
# reverse to keep track of the in_channels, but append also in a reverse direction
|
||||
for i in reversed(range(self.num_blocks)):
|
||||
filters = self.filters * self.channel_multipliers[i]
|
||||
# resblock handling
|
||||
block_items = nn.ModuleList([])
|
||||
for _ in range(self.num_res_blocks):
|
||||
block_items.append(ResBlock(prev_filters, filters, **self.block_args))
|
||||
prev_filters = filters # SCH: update in_channels
|
||||
self.block_res_blocks.insert(0, block_items) # SCH: append in front
|
||||
|
||||
# conv blocks with upsampling
|
||||
if i > 0:
|
||||
if self.temporal_downsample[i - 1]:
|
||||
t_stride = 2 if self.temporal_downsample[i - 1] else 1
|
||||
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
|
||||
self.conv_blocks.insert(
|
||||
0,
|
||||
self.conv_fn(
|
||||
prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3)
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.conv_blocks.insert(
|
||||
0,
|
||||
nn.Identity(prev_filters),
|
||||
)
|
||||
|
||||
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
|
||||
|
||||
self.conv_out = self.conv_fn(filters, in_out_channels, 3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
for i in range(self.num_res_blocks):
|
||||
x = self.res_blocks[i](x)
|
||||
for i in reversed(range(self.num_blocks)):
|
||||
for j in range(self.num_res_blocks):
|
||||
x = self.block_res_blocks[i][j](x)
|
||||
if i > 0:
|
||||
t_stride = 2 if self.temporal_downsample[i - 1] else 1
|
||||
x = self.conv_blocks[i - 1](x)
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)",
|
||||
ts=t_stride,
|
||||
hs=self.s_stride,
|
||||
ws=self.s_stride,
|
||||
)
|
||||
|
||||
x = self.norm1(x)
|
||||
x = self.activate(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class VAE_Temporal(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_out_channels=4,
|
||||
latent_embed_dim=4,
|
||||
embed_dim=4,
|
||||
filters=128,
|
||||
num_res_blocks=4,
|
||||
channel_multipliers=(1, 2, 2, 4),
|
||||
temporal_downsample=(True, True, False),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
activation_fn="swish",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_downsample_factor = 2 ** sum(temporal_downsample)
|
||||
# self.time_padding = self.time_downsample_factor - 1
|
||||
self.patch_size = (self.time_downsample_factor, 1, 1)
|
||||
self.out_channels = in_out_channels
|
||||
|
||||
# NOTE: following MAGVIT, conv in bias=False in encoder first conv
|
||||
self.encoder = Encoder(
|
||||
in_out_channels=in_out_channels,
|
||||
latent_embed_dim=latent_embed_dim * 2,
|
||||
filters=filters,
|
||||
num_res_blocks=num_res_blocks,
|
||||
channel_multipliers=channel_multipliers,
|
||||
temporal_downsample=temporal_downsample,
|
||||
num_groups=num_groups, # for nn.GroupNorm
|
||||
activation_fn=activation_fn,
|
||||
)
|
||||
self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)
|
||||
|
||||
self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
|
||||
self.decoder = Decoder(
|
||||
in_out_channels=in_out_channels,
|
||||
latent_embed_dim=latent_embed_dim,
|
||||
filters=filters,
|
||||
num_res_blocks=num_res_blocks,
|
||||
channel_multipliers=channel_multipliers,
|
||||
temporal_downsample=temporal_downsample,
|
||||
num_groups=num_groups, # for nn.GroupNorm
|
||||
activation_fn=activation_fn,
|
||||
)
|
||||
|
||||
def get_latent_size(self, input_size):
|
||||
latent_size = []
|
||||
for i in range(3):
|
||||
if input_size[i] is None:
|
||||
lsize = None
|
||||
elif i == 0:
|
||||
time_padding = (
|
||||
0
|
||||
if (input_size[i] % self.time_downsample_factor == 0)
|
||||
else self.time_downsample_factor - input_size[i] % self.time_downsample_factor
|
||||
)
|
||||
lsize = (input_size[i] + time_padding) // self.patch_size[i]
|
||||
else:
|
||||
lsize = input_size[i] // self.patch_size[i]
|
||||
latent_size.append(lsize)
|
||||
return latent_size
|
||||
|
||||
def encode(self, x):
|
||||
time_padding = (
|
||||
0
|
||||
if (x.shape[2] % self.time_downsample_factor == 0)
|
||||
else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor
|
||||
)
|
||||
x = pad_at_dim(x, (time_padding, 0), dim=2)
|
||||
encoded_feature = self.encoder(x)
|
||||
moments = self.quant_conv(encoded_feature).to(x.dtype)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z, num_frames=None):
|
||||
time_padding = (
|
||||
0
|
||||
if (num_frames % self.time_downsample_factor == 0)
|
||||
else self.time_downsample_factor - num_frames % self.time_downsample_factor
|
||||
)
|
||||
z = self.post_quant_conv(z)
|
||||
x = self.decoder(z)
|
||||
x = x[:, :, time_padding:]
|
||||
return x
|
||||
|
||||
def forward(self, x, sample_posterior=True):
|
||||
posterior = self.encode(x)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
recon_video = self.decode(z, num_frames=x.shape[2])
|
||||
return recon_video, posterior, z
|
||||
|
||||
|
||||
def VAE_Temporal_SD(**kwargs):
|
||||
model = VAE_Temporal(
|
||||
in_out_channels=4,
|
||||
latent_embed_dim=4,
|
||||
embed_dim=4,
|
||||
filters=128,
|
||||
num_res_blocks=4,
|
||||
channel_multipliers=(1, 2, 2, 4),
|
||||
temporal_downsample=(False, True, True),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
class VideoAutoencoderKL(nn.Module):
|
||||
def __init__(
|
||||
self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None
|
||||
):
|
||||
super().__init__()
|
||||
self.module = AutoencoderKL.from_pretrained(
|
||||
from_pretrained,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
self.out_channels = self.module.config.latent_channels
|
||||
self.patch_size = (1, 8, 8)
|
||||
self.micro_batch_size = micro_batch_size
|
||||
|
||||
def encode(self, x):
|
||||
# x: (B, C, T, H, W)
|
||||
B = x.shape[0]
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||
|
||||
if self.micro_batch_size is None:
|
||||
x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
|
||||
else:
|
||||
# NOTE: cannot be used for training
|
||||
bs = self.micro_batch_size
|
||||
x_out = []
|
||||
for i in range(0, x.shape[0], bs):
|
||||
x_bs = x[i : i + bs]
|
||||
x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
|
||||
x_out.append(x_bs)
|
||||
x = torch.cat(x_out, dim=0)
|
||||
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
|
||||
return x
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
# x: (B, C, T, H, W)
|
||||
B = x.shape[0]
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||
if self.micro_batch_size is None:
|
||||
x = self.module.decode(x / 0.18215).sample
|
||||
else:
|
||||
# NOTE: cannot be used for training
|
||||
bs = self.micro_batch_size
|
||||
x_out = []
|
||||
for i in range(0, x.shape[0], bs):
|
||||
x_bs = x[i : i + bs]
|
||||
x_bs = self.module.decode(x_bs / 0.18215).sample
|
||||
x_out.append(x_bs)
|
||||
x = torch.cat(x_out, dim=0)
|
||||
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
|
||||
return x
|
||||
|
||||
def get_latent_size(self, input_size):
|
||||
latent_size = []
|
||||
for i in range(3):
|
||||
# assert (
|
||||
# input_size[i] is None or input_size[i] % self.patch_size[i] == 0
|
||||
# ), "Input size must be divisible by patch size"
|
||||
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
|
||||
return latent_size
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class VideoAutoencoderKLTemporalDecoder(nn.Module):
|
||||
def __init__(self, from_pretrained=None, cache_dir=None, local_files_only=False):
|
||||
super().__init__()
|
||||
self.module = AutoencoderKLTemporalDecoder.from_pretrained(
|
||||
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
|
||||
)
|
||||
self.out_channels = self.module.config.latent_channels
|
||||
self.patch_size = (1, 8, 8)
|
||||
|
||||
def encode(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
B, _, T = x.shape[:3]
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||
x = self.module.decode(x / 0.18215, num_frames=T).sample
|
||||
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
|
||||
return x
|
||||
|
||||
def get_latent_size(self, input_size):
|
||||
latent_size = []
|
||||
for i in range(3):
|
||||
# assert (
|
||||
# input_size[i] is None or input_size[i] % self.patch_size[i] == 0
|
||||
# ), "Input size must be divisible by patch size"
|
||||
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
|
||||
return latent_size
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class VideoAutoencoderPipelineConfig(PretrainedConfig):
|
||||
model_type = "VideoAutoencoderPipeline"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_2d=None,
|
||||
vae_temporal=None,
|
||||
from_pretrained=None,
|
||||
freeze_vae_2d=False,
|
||||
cal_loss=False,
|
||||
micro_frame_size=None,
|
||||
shift=0.0,
|
||||
scale=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vae_2d = vae_2d
|
||||
self.vae_temporal = vae_temporal
|
||||
self.from_pretrained = from_pretrained
|
||||
self.freeze_vae_2d = freeze_vae_2d
|
||||
self.cal_loss = cal_loss
|
||||
self.micro_frame_size = micro_frame_size
|
||||
self.shift = shift
|
||||
self.scale = scale
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class VideoAutoencoderPipeline(PreTrainedModel):
|
||||
config_class = VideoAutoencoderPipelineConfig
|
||||
|
||||
def __init__(self, config: VideoAutoencoderPipelineConfig):
|
||||
super().__init__(config=config)
|
||||
self.spatial_vae = VideoAutoencoderKL(
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
local_files_only=False,
|
||||
micro_batch_size=4,
|
||||
subfolder="vae",
|
||||
)
|
||||
self.temporal_vae = VAE_Temporal_SD()
|
||||
self.cal_loss = config.cal_loss
|
||||
self.micro_frame_size = config.micro_frame_size
|
||||
self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
|
||||
|
||||
if config.freeze_vae_2d:
|
||||
for param in self.spatial_vae.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.out_channels = self.temporal_vae.out_channels
|
||||
|
||||
# normalization parameters
|
||||
scale = torch.tensor(config.scale)
|
||||
shift = torch.tensor(config.shift)
|
||||
if len(scale.shape) > 0:
|
||||
scale = scale[None, :, None, None, None]
|
||||
if len(shift.shape) > 0:
|
||||
shift = shift[None, :, None, None, None]
|
||||
self.register_buffer("scale", scale)
|
||||
self.register_buffer("shift", shift)
|
||||
|
||||
def encode(self, x):
|
||||
x_z = self.spatial_vae.encode(x)
|
||||
|
||||
if self.micro_frame_size is None:
|
||||
posterior = self.temporal_vae.encode(x_z)
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z_list = []
|
||||
for i in range(0, x_z.shape[2], self.micro_frame_size):
|
||||
x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
|
||||
posterior = self.temporal_vae.encode(x_z_bs)
|
||||
z_list.append(posterior.sample())
|
||||
z = torch.cat(z_list, dim=2)
|
||||
|
||||
if self.cal_loss:
|
||||
return z, posterior, x_z
|
||||
else:
|
||||
return (z - self.shift) / self.scale
|
||||
|
||||
def decode(self, z, num_frames=None):
|
||||
device = z.device
|
||||
self.scale = self.scale.to(device)
|
||||
self.shift = self.shift.to(device)
|
||||
if not self.cal_loss:
|
||||
z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)
|
||||
|
||||
if self.micro_frame_size is None:
|
||||
x_z = self.temporal_vae.decode(z, num_frames=num_frames)
|
||||
x = self.spatial_vae.decode(x_z)
|
||||
else:
|
||||
x_z_list = []
|
||||
for i in range(0, z.size(2), self.micro_z_frame_size):
|
||||
z_bs = z[:, :, i : i + self.micro_z_frame_size]
|
||||
x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames))
|
||||
x_z_list.append(x_z_bs)
|
||||
num_frames -= self.micro_frame_size
|
||||
x_z = torch.cat(x_z_list, dim=2)
|
||||
x = self.spatial_vae.decode(x_z)
|
||||
|
||||
if self.cal_loss:
|
||||
return x, x_z
|
||||
else:
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
assert self.cal_loss, "This method is only available when cal_loss is True"
|
||||
z, posterior, x_z = self.encode(x)
|
||||
x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
|
||||
return x_rec, x_z_rec, z, posterior, x_z
|
||||
|
||||
def get_latent_size(self, input_size):
|
||||
if self.micro_frame_size is None or input_size[0] is None:
|
||||
return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
|
||||
else:
|
||||
sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]]
|
||||
sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size))
|
||||
sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size)
|
||||
remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None]
|
||||
if remain_temporal_size[0] > 0:
|
||||
remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
|
||||
sub_latent_size[0] += remain_size[0]
|
||||
return sub_latent_size
|
||||
|
||||
def get_temporal_last_layer(self):
|
||||
return self.temporal_vae.decoder.conv_out.conv.weight
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
def OpenSoraVAE_V1_2(
|
||||
micro_batch_size=4,
|
||||
micro_frame_size=17,
|
||||
from_pretrained=None,
|
||||
freeze_vae_2d=False,
|
||||
cal_loss=False,
|
||||
):
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=micro_batch_size,
|
||||
)
|
||||
vae_temporal = dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
)
|
||||
shift = (-0.10, 0.34, 0.27, 0.98)
|
||||
scale = (3.85, 2.32, 2.33, 3.06)
|
||||
kwargs = dict(
|
||||
vae_2d=vae_2d,
|
||||
vae_temporal=vae_temporal,
|
||||
freeze_vae_2d=freeze_vae_2d,
|
||||
cal_loss=cal_loss,
|
||||
micro_frame_size=micro_frame_size,
|
||||
shift=shift,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
|
||||
return model
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,3 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
@ -1,205 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from videosys.models.modules.normalization import LlamaRMSNorm
|
||||
|
||||
|
||||
class OpenSoraAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = LlamaRMSNorm,
|
||||
enable_flash_attn: bool = False,
|
||||
rope=None,
|
||||
qk_norm_legacy: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.qk_norm_legacy = qk_norm_legacy
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
self.rope = False
|
||||
if rope is not None:
|
||||
self.rope = True
|
||||
self.rotary_emb = rope
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
# flash attn is not memory efficient for small sequences, this is empirical
|
||||
enable_flash_attn = self.enable_flash_attn and (N > B)
|
||||
qkv = self.qkv(x)
|
||||
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||
|
||||
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
if self.qk_norm_legacy:
|
||||
# WARNING: this may be a bug
|
||||
if self.rope:
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
else:
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
if self.rope:
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
|
||||
if enable_flash_attn:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
x = flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
softmax_scale=self.scale,
|
||||
)
|
||||
else:
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
x_output_shape = (B, N, C)
|
||||
if not enable_flash_attn:
|
||||
x = x.transpose(1, 2)
|
||||
x = x.reshape(x_output_shape)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class OpenSoraMultiHeadCrossAttention(nn.Module):
|
||||
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flash_attn=False):
|
||||
super(OpenSoraMultiHeadCrossAttention, self).__init__()
|
||||
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
||||
|
||||
self.d_model = d_model
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = d_model // num_heads
|
||||
|
||||
self.q_linear = nn.Linear(d_model, d_model)
|
||||
self.kv_linear = nn.Linear(d_model, d_model * 2)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(d_model, d_model)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
|
||||
def forward(self, x, cond, mask=None):
|
||||
# query/value: img tokens; key: condition; mask: if padding tokens
|
||||
B, N, C = x.shape
|
||||
|
||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||
k, v = kv.unbind(2)
|
||||
|
||||
if self.enable_flash_attn:
|
||||
x = self.flash_attn_impl(q, k, v, mask, B, N, C)
|
||||
else:
|
||||
x = self.torch_impl(q, k, v, mask, B, N, C)
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
def flash_attn_impl(self, q, k, v, mask, B, N, C):
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
|
||||
k_seqinfo = _SeqLenInfo.from_seqlens(mask)
|
||||
|
||||
x = flash_attn_varlen_func(
|
||||
q.view(-1, self.num_heads, self.head_dim),
|
||||
k.view(-1, self.num_heads, self.head_dim),
|
||||
v.view(-1, self.num_heads, self.head_dim),
|
||||
cu_seqlens_q=q_seqinfo.seqstart.cuda(),
|
||||
cu_seqlens_k=k_seqinfo.seqstart.cuda(),
|
||||
max_seqlen_q=q_seqinfo.max_seqlen,
|
||||
max_seqlen_k=k_seqinfo.max_seqlen,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
)
|
||||
x = x.view(B, N, C)
|
||||
return x
|
||||
|
||||
def torch_impl(self, q, k, v, mask, B, N, C):
|
||||
q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
attn_mask = torch.zeros(B, 1, N, k.shape[2], dtype=torch.bool, device=q.device)
|
||||
for i, m in enumerate(mask):
|
||||
attn_mask[i, :, :, :m] = -1e9
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
x = out.transpose(1, 2).contiguous().view(B, N, C)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SeqLenInfo:
|
||||
"""
|
||||
from xformers
|
||||
|
||||
(Internal) Represents the division of a dimension into blocks.
|
||||
For example, to represents a dimension of length 7 divided into
|
||||
three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
|
||||
The members will be:
|
||||
max_seqlen: 3
|
||||
min_seqlen: 2
|
||||
seqstart_py: [0, 2, 5, 7]
|
||||
seqstart: torch.IntTensor([0, 2, 5, 7])
|
||||
"""
|
||||
|
||||
seqstart: torch.Tensor
|
||||
max_seqlen: int
|
||||
min_seqlen: int
|
||||
seqstart_py: List[int]
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
self.seqstart = self.seqstart.to(device, non_blocking=True)
|
||||
|
||||
def intervals(self) -> Iterable[Tuple[int, int]]:
|
||||
yield from zip(self.seqstart_py, self.seqstart_py[1:])
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
|
||||
"""
|
||||
Input tensors are assumed to be in shape [B, M, *]
|
||||
"""
|
||||
assert not isinstance(seqlens, torch.Tensor)
|
||||
seqstart_py = [0]
|
||||
max_seqlen = -1
|
||||
min_seqlen = -1
|
||||
for seqlen in seqlens:
|
||||
min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
|
||||
max_seqlen = max(max_seqlen, seqlen)
|
||||
seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
|
||||
seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
|
||||
return cls(
|
||||
max_seqlen=max_seqlen,
|
||||
min_seqlen=min_seqlen,
|
||||
seqstart=seqstart,
|
||||
seqstart_py=seqstart_py,
|
||||
)
|
||||
@ -1,71 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CogVideoXDownsample3D(nn.Module):
|
||||
# Todo: Wait for paper relase.
|
||||
r"""
|
||||
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of channels in the input image.
|
||||
out_channels (`int`):
|
||||
Number of channels produced by the convolution.
|
||||
kernel_size (`int`, defaults to `3`):
|
||||
Size of the convolving kernel.
|
||||
stride (`int`, defaults to `2`):
|
||||
Stride of the convolution.
|
||||
padding (`int`, defaults to `0`):
|
||||
Padding added to all four sides of the input.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to compress the time dimension.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 2,
|
||||
padding: int = 0,
|
||||
compress_time: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
||||
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
# Pad the tensor
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
||||
x = self.conv(x)
|
||||
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
||||
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
@ -1,412 +0,0 @@
|
||||
import functools
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from einops import rearrange
|
||||
from timm.models.vision_transformer import Mlp
|
||||
|
||||
|
||||
class CogVideoXPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
embed_dim: int = 1920,
|
||||
text_embed_dim: int = 4096,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
||||
|
||||
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
||||
r"""
|
||||
Args:
|
||||
text_embeds (`torch.Tensor`):
|
||||
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
||||
image_embeds (`torch.Tensor`):
|
||||
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
||||
"""
|
||||
text_embeds = self.text_proj(text_embeds)
|
||||
|
||||
batch, num_frames, channels, height, width = image_embeds.shape
|
||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
||||
image_embeds = self.proj(image_embeds)
|
||||
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
||||
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
||||
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
||||
|
||||
embeds = torch.cat(
|
||||
[text_embeds, image_embeds], dim=1
|
||||
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
||||
return embeds
|
||||
|
||||
|
||||
class OpenSoraPatchEmbed3D(nn.Module):
|
||||
"""Video to Patch Embedding.
|
||||
|
||||
Args:
|
||||
patch_size (int): Patch token size. Default: (2,4,4).
|
||||
in_chans (int): Number of input video channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=(2, 4, 4),
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.flatten = flatten
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
# padding
|
||||
_, _, D, H, W = x.size()
|
||||
if W % self.patch_size[2] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
|
||||
if H % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
|
||||
if D % self.patch_size[0] != 0:
|
||||
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
|
||||
|
||||
x = self.proj(x) # (B C T H W)
|
||||
if self.norm is not None:
|
||||
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
|
||||
freqs = freqs.to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
if t_freq.dtype != dtype:
|
||||
t_freq = t_freq.to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class SizeEmbedder(TimestepEmbedder):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.outdim = hidden_size
|
||||
|
||||
def forward(self, s, bs):
|
||||
if s.ndim == 1:
|
||||
s = s[:, None]
|
||||
assert s.ndim == 2
|
||||
if s.shape[0] != bs:
|
||||
s = s.repeat(bs // s.shape[0], 1)
|
||||
assert s.shape[0] == bs
|
||||
b, dims = s.shape[0], s.shape[1]
|
||||
s = rearrange(s, "b d -> (b d)")
|
||||
s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
|
||||
s_emb = self.mlp(s_freq)
|
||||
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
|
||||
return s_emb
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class OpenSoraCaptionEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
uncond_prob,
|
||||
act_layer=nn.GELU(approximate="tanh"),
|
||||
token_num=120,
|
||||
):
|
||||
super().__init__()
|
||||
self.y_proj = Mlp(
|
||||
in_features=in_channels,
|
||||
hidden_features=hidden_size,
|
||||
out_features=hidden_size,
|
||||
act_layer=act_layer,
|
||||
drop=0,
|
||||
)
|
||||
self.register_buffer(
|
||||
"y_embedding",
|
||||
torch.randn(token_num, in_channels) / in_channels**0.5,
|
||||
)
|
||||
self.uncond_prob = uncond_prob
|
||||
|
||||
def token_drop(self, caption, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
return caption
|
||||
|
||||
def forward(self, caption, train, force_drop_ids=None):
|
||||
if train:
|
||||
assert caption.shape[2:] == self.y_embedding.shape
|
||||
use_dropout = self.uncond_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
caption = self.token_drop(caption, force_drop_ids)
|
||||
caption = self.y_proj(caption)
|
||||
return caption
|
||||
|
||||
|
||||
class OpenSoraPositionEmbedding2D(nn.Module):
|
||||
def __init__(self, dim: int) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
assert dim % 4 == 0, "dim must be divisible by 4"
|
||||
half_dim = dim // 2
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def _get_sin_cos_emb(self, t: torch.Tensor):
|
||||
out = torch.einsum("i,d->id", t, self.inv_freq)
|
||||
emb_cos = torch.cos(out)
|
||||
emb_sin = torch.sin(out)
|
||||
return torch.cat((emb_sin, emb_cos), dim=-1)
|
||||
|
||||
@functools.lru_cache(maxsize=512)
|
||||
def _get_cached_emb(
|
||||
self,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
h: int,
|
||||
w: int,
|
||||
scale: float = 1.0,
|
||||
base_size: Optional[int] = None,
|
||||
):
|
||||
grid_h = torch.arange(h, device=device) / scale
|
||||
grid_w = torch.arange(w, device=device) / scale
|
||||
if base_size is not None:
|
||||
grid_h *= base_size / h
|
||||
grid_w *= base_size / w
|
||||
grid_h, grid_w = torch.meshgrid(
|
||||
grid_w,
|
||||
grid_h,
|
||||
indexing="ij",
|
||||
) # here w goes first
|
||||
grid_h = grid_h.t().reshape(-1)
|
||||
grid_w = grid_w.t().reshape(-1)
|
||||
emb_h = self._get_sin_cos_emb(grid_h)
|
||||
emb_w = self._get_sin_cos_emb(grid_w)
|
||||
return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
h: int,
|
||||
w: int,
|
||||
scale: Optional[float] = 1.0,
|
||||
base_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
RoPE for video tokens with 3D structure.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size, corresponding to hidden_size_head.
|
||||
crops_coords (`Tuple[int]`):
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
The grid size of the spatial positional embedding (height, width).
|
||||
temporal_size (`int`):
|
||||
The size of the temporal dimension.
|
||||
theta (`float`):
|
||||
Scaling factor for frequency computation.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
||||
"""
|
||||
start, stop = crops_coords
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 4
|
||||
dim_h = embed_dim // 8 * 3
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Broadcast and concatenate tensors along specified dimension
|
||||
def broadcast(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = {len(t.shape) for t in tensors}
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatenation"
|
||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
|
||||
t, h, w, d = freqs.shape
|
||||
freqs = freqs.view(t * h * w, d)
|
||||
|
||||
# Generate sine and cosine components
|
||||
sin = freqs.sin()
|
||||
cos = freqs.cos()
|
||||
|
||||
if use_real:
|
||||
return cos, sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
# Use for example in Lumina
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Use for example in Stable Audio
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
||||
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
else:
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.type_as(x)
|
||||
@ -1,102 +0,0 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
class CogVideoXLayerNormZero(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_dim: int,
|
||||
embedding_dim: int,
|
||||
elementwise_affine: bool = True,
|
||||
eps: float = 1e-5,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
||||
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
r"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
||||
output_dim (`int`, *optional*):
|
||||
norm_elementwise_affine (`bool`, defaults to `False):
|
||||
norm_eps (`bool`, defaults to `False`):
|
||||
chunk_dim (`int`, defaults to `0`):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_embeddings: Optional[int] = None,
|
||||
output_dim: Optional[int] = None,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
chunk_dim: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.chunk_dim = chunk_dim
|
||||
output_dim = output_dim or embedding_dim * 2
|
||||
|
||||
if num_embeddings is not None:
|
||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
if self.emb is not None:
|
||||
temb = self.emb(timestep)
|
||||
|
||||
temb = self.linear(self.silu(temb))
|
||||
|
||||
if self.chunk_dim == 1:
|
||||
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
||||
# other if-branch. This branch is specific to CogVideoX for now.
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
shift = shift[:, None, :]
|
||||
scale = scale[:, None, :]
|
||||
else:
|
||||
scale, shift = temb.chunk(2, dim=0)
|
||||
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
@ -1,67 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CogVideoXUpsample3D(nn.Module):
|
||||
r"""
|
||||
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of channels in the input image.
|
||||
out_channels (`int`):
|
||||
Number of channels produced by the convolution.
|
||||
kernel_size (`int`, defaults to `3`):
|
||||
Size of the convolving kernel.
|
||||
stride (`int`, defaults to `1`):
|
||||
Stride of the convolution.
|
||||
padding (`int`, defaults to `1`):
|
||||
Padding added to all four sides of the input.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to compress the time dimension.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
padding: int = 1,
|
||||
compress_time: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
||||
# split first frame
|
||||
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
||||
|
||||
x_first = F.interpolate(x_first, scale_factor=2.0)
|
||||
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
||||
x_first = x_first[:, :, None, :, :]
|
||||
inputs = torch.cat([x_first, x_rest], dim=2)
|
||||
elif inputs.shape[2] > 1:
|
||||
inputs = F.interpolate(inputs, scale_factor=2.0)
|
||||
else:
|
||||
inputs = inputs.squeeze(2)
|
||||
inputs = F.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs[:, :, None, :, :]
|
||||
else:
|
||||
# only interpolate 2D
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = F.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = self.conv(inputs)
|
||||
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
return inputs
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,591 +0,0 @@
|
||||
# Adapted from CogVideo
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# CogVideo: https://github.com/THUDM/CogVideo
|
||||
# diffusers: https://github.com/huggingface/diffusers
|
||||
# --------------------------------------------------------
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.utils import is_torch_version
|
||||
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
||||
from torch import nn
|
||||
|
||||
from videosys.core.comm import all_to_all_comm, gather_sequence, get_spatial_pad, set_spatial_pad, split_sequence
|
||||
from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
|
||||
from videosys.core.parallel_mgr import (
|
||||
enable_sequence_parallel,
|
||||
get_cfg_parallel_group,
|
||||
get_cfg_parallel_size,
|
||||
get_sequence_parallel_group,
|
||||
get_sequence_parallel_size,
|
||||
)
|
||||
from videosys.models.modules.embeddings import apply_rotary_emb
|
||||
from videosys.utils.utils import batch_func
|
||||
|
||||
from ..modules.embeddings import CogVideoXPatchEmbed
|
||||
from ..modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
||||
|
||||
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
if enable_sequence_parallel():
|
||||
assert (
|
||||
attn.heads % get_sequence_parallel_size() == 0
|
||||
), f"Number of heads {attn.heads} must be divisible by sequence parallel size {get_sequence_parallel_size()}"
|
||||
attn_heads = attn.heads // get_sequence_parallel_size()
|
||||
query, key, value = map(
|
||||
lambda x: all_to_all_comm(x, get_sequence_parallel_group(), scatter_dim=2, gather_dim=1),
|
||||
[query, key, value],
|
||||
)
|
||||
else:
|
||||
attn_heads = attn.heads
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn_heads
|
||||
|
||||
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
emb_len = image_rotary_emb[0].shape[0]
|
||||
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
||||
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
||||
)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
||||
)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
|
||||
|
||||
if enable_sequence_parallel():
|
||||
hidden_states = all_to_all_comm(hidden_states, get_sequence_parallel_group(), scatter_dim=1, gather_dim=2)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedCogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class CogVideoXBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
||||
|
||||
Parameters:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
time_embed_dim (`int`):
|
||||
The number of channels in timestep embedding.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to be used in feed-forward.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether or not to use bias in attention projection layers.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Whether or not to use normalization after query and key projections in Attention.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
Epsilon value for normalization layers.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
||||
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
||||
ff_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Feed-forward layer.
|
||||
attention_out_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Attention output projection layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
time_embed_dim: int,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
attention_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = True,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
block_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
qk_norm="layer_norm" if qk_norm else None,
|
||||
eps=1e-6,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
processor=CogVideoXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Feed Forward
|
||||
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# pab
|
||||
self.attn_count = 0
|
||||
self.last_attn = None
|
||||
self.block_idx = block_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
timestep=None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# attention
|
||||
if enable_pab():
|
||||
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
|
||||
if enable_pab() and broadcast_attn:
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
|
||||
else:
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
if enable_pab():
|
||||
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# feed-forward
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, defaults to `30`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
attention_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in the attention projection layers.
|
||||
sample_width (`int`, defaults to `90`):
|
||||
The width of the input latents.
|
||||
sample_height (`int`, defaults to `60`):
|
||||
The height of the input latents.
|
||||
sample_frames (`int`, defaults to `49`):
|
||||
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
||||
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
||||
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
||||
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
temporal_compression_ratio (`int`, defaults to `4`):
|
||||
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
||||
max_text_seq_length (`int`, defaults to `226`):
|
||||
The maximum sequence length of the input text embeddings.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
||||
Activation function to use when generating the timestep embeddings.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether or not to use elementwise affine in normalization layers.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
The epsilon value to use in normalization layers.
|
||||
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
||||
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
||||
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
||||
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 64,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = 16,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
time_embed_dim: int = 512,
|
||||
text_embed_dim: int = 4096,
|
||||
num_layers: int = 30,
|
||||
dropout: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
sample_width: int = 90,
|
||||
sample_height: int = 60,
|
||||
sample_frames: int = 49,
|
||||
patch_size: int = 2,
|
||||
temporal_compression_ratio: int = 4,
|
||||
max_text_seq_length: int = 226,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
timestep_activation_fn: str = "silu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
spatial_interpolation_scale: float = 1.875,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
use_rotary_positional_embeddings: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
post_patch_height = sample_height // patch_size
|
||||
post_patch_width = sample_width // patch_size
|
||||
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
|
||||
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
||||
|
||||
# 1. Patch embedding
|
||||
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
|
||||
self.embedding_dropout = nn.Dropout(dropout)
|
||||
|
||||
# 2. 3D positional embeddings
|
||||
spatial_pos_embedding = get_3d_sincos_pos_embed(
|
||||
inner_dim,
|
||||
(post_patch_width, post_patch_height),
|
||||
post_time_compression_frames,
|
||||
spatial_interpolation_scale,
|
||||
temporal_interpolation_scale,
|
||||
)
|
||||
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
|
||||
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
|
||||
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
|
||||
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
|
||||
|
||||
# 3. Time embeddings
|
||||
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
||||
|
||||
# 4. Define spatio-temporal transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
CogVideoXBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
# 5. Output blocks
|
||||
self.norm_out = AdaLayerNorm(
|
||||
embedding_dim=time_embed_dim,
|
||||
output_dim=2 * inner_dim,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
chunk_dim=1,
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: Union[int, float, torch.LongTensor],
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
all_timesteps=None,
|
||||
):
|
||||
if get_cfg_parallel_size() > 1:
|
||||
(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep,
|
||||
timestep_cond,
|
||||
image_rotary_emb,
|
||||
) = batch_func(
|
||||
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep,
|
||||
timestep_cond,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. Time embedding
|
||||
timesteps = timestep
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
|
||||
# 2. Patch embedding
|
||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
|
||||
# 3. Position embedding
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
||||
|
||||
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
if enable_sequence_parallel():
|
||||
set_spatial_pad(hidden_states.shape[1])
|
||||
hidden_states = split_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
|
||||
|
||||
# 4. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
timestep=timesteps if enable_pab() else None,
|
||||
)
|
||||
|
||||
if enable_sequence_parallel():
|
||||
hidden_states = gather_sequence(hidden_states, get_sequence_parallel_group(), dim=1, pad=get_spatial_pad())
|
||||
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
# CogVideoX-2B
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
else:
|
||||
# CogVideoX-5B
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# 5. Final block
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 6. Unpatchify
|
||||
p = self.config.patch_size
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
if get_cfg_parallel_size() > 1:
|
||||
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,636 +0,0 @@
|
||||
# Adapted from OpenSora
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# OpenSora: https://github.com/hpcaitech/Open-Sora
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from timm.models.layers import DropPath
|
||||
from timm.models.vision_transformer import Mlp
|
||||
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from videosys.core.comm import (
|
||||
all_to_all_with_pad,
|
||||
gather_sequence,
|
||||
get_spatial_pad,
|
||||
get_temporal_pad,
|
||||
set_spatial_pad,
|
||||
set_temporal_pad,
|
||||
split_sequence,
|
||||
)
|
||||
from videosys.core.pab_mgr import (
|
||||
enable_pab,
|
||||
get_mlp_output,
|
||||
if_broadcast_cross,
|
||||
if_broadcast_mlp,
|
||||
if_broadcast_spatial,
|
||||
if_broadcast_temporal,
|
||||
save_mlp_output,
|
||||
)
|
||||
from videosys.core.parallel_mgr import (
|
||||
enable_sequence_parallel,
|
||||
get_cfg_parallel_size,
|
||||
get_data_parallel_group,
|
||||
get_sequence_parallel_group,
|
||||
)
|
||||
from videosys.models.modules.activations import approx_gelu
|
||||
from videosys.models.modules.attentions import OpenSoraAttention, OpenSoraMultiHeadCrossAttention
|
||||
from videosys.models.modules.embeddings import (
|
||||
OpenSoraCaptionEmbedder,
|
||||
OpenSoraPatchEmbed3D,
|
||||
OpenSoraPositionEmbedding2D,
|
||||
SizeEmbedder,
|
||||
TimestepEmbedder,
|
||||
)
|
||||
from videosys.utils.utils import batch_func
|
||||
|
||||
|
||||
def t2i_modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
class T2IFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
|
||||
self.out_channels = out_channels
|
||||
self.d_t = d_t
|
||||
self.d_s = d_s
|
||||
|
||||
def t_mask_select(self, x_mask, x, masked_x, T, S):
|
||||
# x: [B, (T, S), C]
|
||||
# mased_x: [B, (T, S), C]
|
||||
# x_mask: [B, T]
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
x = torch.where(x_mask[:, :, None, None], x, masked_x)
|
||||
x = rearrange(x, "B T S C -> B (T S) C")
|
||||
return x
|
||||
|
||||
def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
|
||||
if T is None:
|
||||
T = self.d_t
|
||||
if S is None:
|
||||
S = self.d_s
|
||||
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
|
||||
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||
if x_mask is not None:
|
||||
shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
|
||||
x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
|
||||
x = self.t_mask_select(x_mask, x, x_zero, T, S)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def auto_grad_checkpoint(module, *args, **kwargs):
|
||||
if getattr(module, "grad_checkpointing", False):
|
||||
if not isinstance(module, Iterable):
|
||||
return checkpoint(module, *args, use_reentrant=False, **kwargs)
|
||||
gc_step = module[0].grad_checkpointing_step
|
||||
return checkpoint_sequential(module, gc_step, *args, use_reentrant=False, **kwargs)
|
||||
return module(*args, **kwargs)
|
||||
|
||||
|
||||
class STDiT3Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
drop_path=0.0,
|
||||
rope=None,
|
||||
qk_norm=False,
|
||||
temporal=False,
|
||||
enable_flash_attn=False,
|
||||
block_idx=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.temporal = temporal
|
||||
self.hidden_size = hidden_size
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
|
||||
self.attn = OpenSoraAttention(
|
||||
hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=True,
|
||||
qk_norm=qk_norm,
|
||||
rope=rope,
|
||||
enable_flash_attn=enable_flash_attn,
|
||||
)
|
||||
self.cross_attn = OpenSoraMultiHeadCrossAttention(hidden_size, num_heads, enable_flash_attn=enable_flash_attn)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
|
||||
|
||||
# pab
|
||||
self.block_idx = block_idx
|
||||
self.attn_count = 0
|
||||
self.last_attn = None
|
||||
self.cross_count = 0
|
||||
self.last_cross = None
|
||||
self.mlp_count = 0
|
||||
|
||||
def t_mask_select(self, x_mask, x, masked_x, T, S):
|
||||
# x: [B, (T, S), C]
|
||||
# mased_x: [B, (T, S), C]
|
||||
# x_mask: [B, T]
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
x = torch.where(x_mask[:, :, None, None], x, masked_x)
|
||||
x = rearrange(x, "B T S C -> B (T S) C")
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
y,
|
||||
t,
|
||||
mask=None, # text mask
|
||||
x_mask=None, # temporal mask
|
||||
t0=None, # t with timestamp=0
|
||||
T=None, # number of frames
|
||||
S=None, # number of pixel patches
|
||||
timestep=None,
|
||||
all_timesteps=None,
|
||||
):
|
||||
# prepare modulate parameters
|
||||
B, N, C = x.shape
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + t.reshape(B, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
if x_mask is not None:
|
||||
shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = (
|
||||
self.scale_shift_table[None] + t0.reshape(B, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
|
||||
if enable_pab():
|
||||
if self.temporal:
|
||||
broadcast_attn, self.attn_count = if_broadcast_temporal(int(timestep[0]), self.attn_count)
|
||||
else:
|
||||
broadcast_attn, self.attn_count = if_broadcast_spatial(
|
||||
int(timestep[0]), self.attn_count, self.block_idx
|
||||
)
|
||||
|
||||
if enable_pab() and broadcast_attn:
|
||||
x_m_s = self.last_attn
|
||||
else:
|
||||
# modulate (attention)
|
||||
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
if x_mask is not None:
|
||||
x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero)
|
||||
x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)
|
||||
|
||||
# attention
|
||||
if self.temporal:
|
||||
if enable_sequence_parallel():
|
||||
x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=True)
|
||||
x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
|
||||
x_m = self.attn(x_m)
|
||||
x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S)
|
||||
if enable_sequence_parallel():
|
||||
x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=False)
|
||||
else:
|
||||
x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
|
||||
x_m = self.attn(x_m)
|
||||
x_m = rearrange(x_m, "(B T) S C -> B (T S) C", T=T, S=S)
|
||||
|
||||
# modulate (attention)
|
||||
x_m_s = gate_msa * x_m
|
||||
if x_mask is not None:
|
||||
x_m_s_zero = gate_msa_zero * x_m
|
||||
x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)
|
||||
|
||||
if enable_pab():
|
||||
self.last_attn = x_m_s
|
||||
|
||||
# residual
|
||||
x = x + self.drop_path(x_m_s)
|
||||
|
||||
# cross attention
|
||||
if enable_pab():
|
||||
broadcast_cross, self.cross_count = if_broadcast_cross(int(timestep[0]), self.cross_count)
|
||||
|
||||
if enable_pab() and broadcast_cross:
|
||||
x = x + self.last_cross
|
||||
else:
|
||||
x_cross = self.cross_attn(x, y, mask)
|
||||
if enable_pab():
|
||||
self.last_cross = x_cross
|
||||
x = x + x_cross
|
||||
|
||||
if enable_pab():
|
||||
broadcast_mlp, self.mlp_count, broadcast_next, skip_range = if_broadcast_mlp(
|
||||
int(timestep[0]),
|
||||
self.mlp_count,
|
||||
self.block_idx,
|
||||
all_timesteps,
|
||||
is_temporal=self.temporal,
|
||||
)
|
||||
|
||||
if enable_pab() and broadcast_mlp:
|
||||
x_m_s = get_mlp_output(
|
||||
skip_range,
|
||||
timestep=int(timestep[0]),
|
||||
block_idx=self.block_idx,
|
||||
is_temporal=self.temporal,
|
||||
)
|
||||
else:
|
||||
# modulate (MLP)
|
||||
x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
if x_mask is not None:
|
||||
x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero)
|
||||
x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)
|
||||
|
||||
# MLP
|
||||
x_m = self.mlp(x_m)
|
||||
|
||||
# modulate (MLP)
|
||||
x_m_s = gate_mlp * x_m
|
||||
if x_mask is not None:
|
||||
x_m_s_zero = gate_mlp_zero * x_m
|
||||
x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)
|
||||
|
||||
if enable_pab() and broadcast_next:
|
||||
save_mlp_output(
|
||||
timestep=int(timestep[0]),
|
||||
block_idx=self.block_idx,
|
||||
ff_output=x_m_s,
|
||||
is_temporal=self.temporal,
|
||||
)
|
||||
|
||||
# residual
|
||||
x = x + self.drop_path(x_m_s)
|
||||
|
||||
return x
|
||||
|
||||
def dynamic_switch(self, x, s, t, to_spatial_shard: bool):
|
||||
if to_spatial_shard:
|
||||
scatter_dim, gather_dim = 2, 1
|
||||
scatter_pad = get_spatial_pad()
|
||||
gather_pad = get_temporal_pad()
|
||||
else:
|
||||
scatter_dim, gather_dim = 1, 2
|
||||
scatter_pad = get_temporal_pad()
|
||||
gather_pad = get_spatial_pad()
|
||||
|
||||
x = rearrange(x, "b (t s) d -> b t s d", t=t, s=s)
|
||||
x = all_to_all_with_pad(
|
||||
x,
|
||||
get_sequence_parallel_group(),
|
||||
scatter_dim=scatter_dim,
|
||||
gather_dim=gather_dim,
|
||||
scatter_pad=scatter_pad,
|
||||
gather_pad=gather_pad,
|
||||
)
|
||||
new_s, new_t = x.shape[2], x.shape[1]
|
||||
x = rearrange(x, "b t s d -> b (t s) d")
|
||||
return x, new_s, new_t
|
||||
|
||||
|
||||
class STDiT3Config(PretrainedConfig):
|
||||
model_type = "STDiT3"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=(None, None, None),
|
||||
input_sq_size=512,
|
||||
in_channels=4,
|
||||
patch_size=(1, 2, 2),
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
pred_sigma=True,
|
||||
drop_path=0.0,
|
||||
caption_channels=4096,
|
||||
model_max_length=300,
|
||||
qk_norm=True,
|
||||
enable_flash_attn=False,
|
||||
only_train_temporal=False,
|
||||
freeze_y_embedder=False,
|
||||
skip_y_embedder=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.input_size = input_size
|
||||
self.input_sq_size = input_sq_size
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.class_dropout_prob = class_dropout_prob
|
||||
self.pred_sigma = pred_sigma
|
||||
self.drop_path = drop_path
|
||||
self.caption_channels = caption_channels
|
||||
self.model_max_length = model_max_length
|
||||
self.qk_norm = qk_norm
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
self.only_train_temporal = only_train_temporal
|
||||
self.freeze_y_embedder = freeze_y_embedder
|
||||
self.skip_y_embedder = skip_y_embedder
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class STDiT3(PreTrainedModel):
|
||||
config_class = STDiT3Config
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.pred_sigma = config.pred_sigma
|
||||
self.in_channels = config.in_channels
|
||||
self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels
|
||||
|
||||
# model size related
|
||||
self.depth = config.depth
|
||||
self.mlp_ratio = config.mlp_ratio
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
|
||||
# computation related
|
||||
self.drop_path = config.drop_path
|
||||
self.enable_flash_attn = config.enable_flash_attn
|
||||
|
||||
# input size related
|
||||
self.patch_size = config.patch_size
|
||||
self.input_sq_size = config.input_sq_size
|
||||
self.pos_embed = OpenSoraPositionEmbedding2D(config.hidden_size)
|
||||
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
|
||||
self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads)
|
||||
|
||||
# embedding
|
||||
self.x_embedder = OpenSoraPatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(config.hidden_size)
|
||||
self.fps_embedder = SizeEmbedder(self.hidden_size)
|
||||
self.t_block = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True),
|
||||
)
|
||||
self.y_embedder = OpenSoraCaptionEmbedder(
|
||||
in_channels=config.caption_channels,
|
||||
hidden_size=config.hidden_size,
|
||||
uncond_prob=config.class_dropout_prob,
|
||||
act_layer=approx_gelu,
|
||||
token_num=config.model_max_length,
|
||||
)
|
||||
|
||||
# spatial blocks
|
||||
drop_path = [x.item() for x in torch.linspace(0, self.drop_path, config.depth)]
|
||||
self.spatial_blocks = nn.ModuleList(
|
||||
[
|
||||
STDiT3Block(
|
||||
hidden_size=config.hidden_size,
|
||||
num_heads=config.num_heads,
|
||||
mlp_ratio=config.mlp_ratio,
|
||||
drop_path=drop_path[i],
|
||||
qk_norm=config.qk_norm,
|
||||
enable_flash_attn=config.enable_flash_attn,
|
||||
block_idx=i,
|
||||
)
|
||||
for i in range(config.depth)
|
||||
]
|
||||
)
|
||||
|
||||
# temporal blocks
|
||||
drop_path = [x.item() for x in torch.linspace(0, self.drop_path, config.depth)]
|
||||
self.temporal_blocks = nn.ModuleList(
|
||||
[
|
||||
STDiT3Block(
|
||||
hidden_size=config.hidden_size,
|
||||
num_heads=config.num_heads,
|
||||
mlp_ratio=config.mlp_ratio,
|
||||
drop_path=drop_path[i],
|
||||
qk_norm=config.qk_norm,
|
||||
enable_flash_attn=config.enable_flash_attn,
|
||||
# temporal
|
||||
temporal=True,
|
||||
rope=self.rope.rotate_queries_or_keys,
|
||||
block_idx=i,
|
||||
)
|
||||
for i in range(config.depth)
|
||||
]
|
||||
)
|
||||
# final layer
|
||||
self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels)
|
||||
|
||||
self.initialize_weights()
|
||||
if config.only_train_temporal:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
for block in self.temporal_blocks:
|
||||
for param in block.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
if config.freeze_y_embedder:
|
||||
for param in self.y_embedder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
# Initialize fps_embedder
|
||||
nn.init.normal_(self.fps_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.constant_(self.fps_embedder.mlp[0].bias, 0)
|
||||
nn.init.constant_(self.fps_embedder.mlp[2].weight, 0)
|
||||
nn.init.constant_(self.fps_embedder.mlp[2].bias, 0)
|
||||
|
||||
# Initialize timporal blocks
|
||||
for block in self.temporal_blocks:
|
||||
nn.init.constant_(block.attn.proj.weight, 0)
|
||||
nn.init.constant_(block.cross_attn.proj.weight, 0)
|
||||
nn.init.constant_(block.mlp.fc2.weight, 0)
|
||||
|
||||
def get_dynamic_size(self, x):
|
||||
_, _, T, H, W = x.size()
|
||||
if T % self.patch_size[0] != 0:
|
||||
T += self.patch_size[0] - T % self.patch_size[0]
|
||||
if H % self.patch_size[1] != 0:
|
||||
H += self.patch_size[1] - H % self.patch_size[1]
|
||||
if W % self.patch_size[2] != 0:
|
||||
W += self.patch_size[2] - W % self.patch_size[2]
|
||||
T = T // self.patch_size[0]
|
||||
H = H // self.patch_size[1]
|
||||
W = W // self.patch_size[2]
|
||||
return (T, H, W)
|
||||
|
||||
def encode_text(self, y, mask=None):
|
||||
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = [y.shape[2]] * y.shape[0]
|
||||
y = y.squeeze(1).view(1, -1, self.hidden_size)
|
||||
return y, y_lens
|
||||
|
||||
def forward(
|
||||
self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
|
||||
):
|
||||
# === Split batch ===
|
||||
if get_cfg_parallel_size() > 1:
|
||||
x, timestep, y, x_mask, mask = batch_func(
|
||||
partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
|
||||
)
|
||||
|
||||
dtype = self.x_embedder.proj.weight.dtype
|
||||
B = x.size(0)
|
||||
x = x.to(dtype)
|
||||
timestep = timestep.to(dtype)
|
||||
y = y.to(dtype)
|
||||
|
||||
# === get pos embed ===
|
||||
_, _, Tx, Hx, Wx = x.size()
|
||||
T, H, W = self.get_dynamic_size(x)
|
||||
S = H * W
|
||||
base_size = round(S**0.5)
|
||||
resolution_sq = (height[0].item() * width[0].item()) ** 0.5
|
||||
scale = resolution_sq / self.input_sq_size
|
||||
pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)
|
||||
|
||||
# === get timestep embed ===
|
||||
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
|
||||
fps = self.fps_embedder(fps.unsqueeze(1), B)
|
||||
t = t + fps
|
||||
t_mlp = self.t_block(t)
|
||||
t0 = t0_mlp = None
|
||||
if x_mask is not None:
|
||||
t0_timestep = torch.zeros_like(timestep)
|
||||
t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
|
||||
t0 = t0 + fps
|
||||
t0_mlp = self.t_block(t0)
|
||||
|
||||
# === get y embed ===
|
||||
if self.config.skip_y_embedder:
|
||||
y_lens = mask
|
||||
if isinstance(y_lens, torch.Tensor):
|
||||
y_lens = y_lens.long().tolist()
|
||||
else:
|
||||
y, y_lens = self.encode_text(y, mask)
|
||||
|
||||
# === get x embed ===
|
||||
x = self.x_embedder(x) # [B, N, C]
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
x = x + pos_emb
|
||||
|
||||
# shard over the sequence dim if sp is enabled
|
||||
if enable_sequence_parallel():
|
||||
set_temporal_pad(T)
|
||||
set_spatial_pad(S)
|
||||
x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
|
||||
T = x.shape[1]
|
||||
x_mask_org = x_mask
|
||||
x_mask = split_sequence(
|
||||
x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
|
||||
)
|
||||
|
||||
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
|
||||
|
||||
# === blocks ===
|
||||
for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
|
||||
x = auto_grad_checkpoint(
|
||||
spatial_block,
|
||||
x,
|
||||
y,
|
||||
t_mlp,
|
||||
y_lens,
|
||||
x_mask,
|
||||
t0_mlp,
|
||||
T,
|
||||
S,
|
||||
timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
x = auto_grad_checkpoint(
|
||||
temporal_block,
|
||||
x,
|
||||
y,
|
||||
t_mlp,
|
||||
y_lens,
|
||||
x_mask,
|
||||
t0_mlp,
|
||||
T,
|
||||
S,
|
||||
timestep,
|
||||
all_timesteps=all_timesteps,
|
||||
)
|
||||
|
||||
if enable_sequence_parallel():
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
|
||||
T, S = x.shape[1], x.shape[2]
|
||||
x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
|
||||
x_mask = x_mask_org
|
||||
|
||||
# === final layer ===
|
||||
x = self.final_layer(x, t, x_mask, t0, T, S)
|
||||
x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
|
||||
|
||||
# cast to float32 for better accuracy
|
||||
x = x.to(torch.float32)
|
||||
|
||||
# === Gather Output ===
|
||||
if get_cfg_parallel_size() > 1:
|
||||
x = gather_sequence(x, get_data_parallel_group(), dim=0)
|
||||
|
||||
return x
|
||||
|
||||
def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): of shape [B, N, C]
|
||||
|
||||
Return:
|
||||
x (torch.Tensor): of shape [B, C_out, T, H, W]
|
||||
"""
|
||||
|
||||
# N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
|
||||
T_p, H_p, W_p = self.patch_size
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
|
||||
N_t=N_t,
|
||||
N_h=N_h,
|
||||
N_w=N_w,
|
||||
T_p=T_p,
|
||||
H_p=H_p,
|
||||
W_p=W_p,
|
||||
C_out=self.out_channels,
|
||||
)
|
||||
# unpad
|
||||
x = x[:, :, :R_t, :R_h, :R_w]
|
||||
return x
|
||||
Binary file not shown.
@ -1,3 +0,0 @@
|
||||
from .pipeline_cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
|
||||
|
||||
__all__ = ["CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"]
|
||||
Binary file not shown.
Binary file not shown.
@ -1,813 +0,0 @@
|
||||
# Adapted from CogVideo
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# CogVideo: https://github.com/THUDM/CogVideo
|
||||
# diffusers: https://github.com/huggingface/diffusers
|
||||
# --------------------------------------------------------
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
|
||||
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
|
||||
from videosys.models.autoencoders.autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from videosys.models.modules.embeddings import get_3d_rotary_pos_embed
|
||||
from videosys.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
||||
from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
|
||||
from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler
|
||||
from videosys.utils.logging import logger
|
||||
from videosys.utils.utils import save_video
|
||||
|
||||
|
||||
class CogVideoXPABConfig(PABConfig):
|
||||
def __init__(
|
||||
self,
|
||||
steps: int = 50,
|
||||
spatial_broadcast: bool = True,
|
||||
spatial_threshold: list = [100, 850],
|
||||
spatial_range: int = 2,
|
||||
):
|
||||
super().__init__(
|
||||
steps=steps,
|
||||
spatial_broadcast=spatial_broadcast,
|
||||
spatial_threshold=spatial_threshold,
|
||||
spatial_range=spatial_range,
|
||||
)
|
||||
|
||||
|
||||
class CogVideoXConfig:
|
||||
"""
|
||||
This config is to instantiate a `CogVideoXPipeline` class for video generation.
|
||||
|
||||
To be specific, this config will be passed to engine by `VideoSysEngine(config)`.
|
||||
In the engine, it will be used to instantiate the corresponding pipeline class.
|
||||
And the engine will call the `generate` function of the pipeline to generate the video.
|
||||
If you want to explore the detail of generation, please refer to the pipeline class below.
|
||||
|
||||
Args:
|
||||
model_path (str):
|
||||
A path to the pretrained pipeline. Defaults to "THUDM/CogVideoX-2b".
|
||||
num_gpus (int):
|
||||
The number of GPUs to use. Defaults to 1.
|
||||
cpu_offload (bool):
|
||||
Whether to enable CPU offload. Defaults to False.
|
||||
vae_tiling (bool):
|
||||
Whether to enable tiling for the VAE. Defaults to True.
|
||||
enable_pab (bool):
|
||||
Whether to enable Pyramid Attention Broadcast. Defaults to False.
|
||||
pab_config (CogVideoXPABConfig):
|
||||
The configuration for Pyramid Attention Broadcast. Defaults to `CogVideoXPABConfig()`.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
from videosys import CogVideoXConfig, VideoSysEngine
|
||||
|
||||
# models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
||||
# change num_gpus for multi-gpu inference
|
||||
config = CogVideoXConfig("THUDM/CogVideoX-2b", num_gpus=1)
|
||||
engine = VideoSysEngine(config)
|
||||
|
||||
prompt = "Sunset over the sea."
|
||||
# num frames should be <= 49. resolution is fixed to 720p.
|
||||
video = engine.generate(
|
||||
prompt=prompt,
|
||||
guidance_scale=6,
|
||||
num_inference_steps=50,
|
||||
num_frames=49,
|
||||
).video[0]
|
||||
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = "THUDM/CogVideoX-2b",
|
||||
# ======= distributed ========
|
||||
num_gpus: int = 1,
|
||||
# ======= memory =======
|
||||
cpu_offload: bool = False,
|
||||
vae_tiling: bool = True,
|
||||
# ======= pab ========
|
||||
enable_pab: bool = False,
|
||||
pab_config=CogVideoXPABConfig(),
|
||||
):
|
||||
self.model_path = model_path
|
||||
self.pipeline_cls = CogVideoXPipeline
|
||||
# ======= distributed ========
|
||||
self.num_gpus = num_gpus
|
||||
# ======= memory ========
|
||||
self.cpu_offload = cpu_offload
|
||||
self.vae_tiling = vae_tiling
|
||||
# ======= pab ========
|
||||
self.enable_pab = enable_pab
|
||||
self.pab_config = pab_config
|
||||
|
||||
|
||||
class CogVideoXPipeline(VideoSysPipeline):
|
||||
_optional_components = ["text_encoder", "tokenizer"]
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CogVideoXConfig,
|
||||
tokenizer: Optional[T5Tokenizer] = None,
|
||||
text_encoder: Optional[T5EncoderModel] = None,
|
||||
vae: Optional[AutoencoderKLCogVideoX] = None,
|
||||
transformer: Optional[CogVideoXTransformer3DModel] = None,
|
||||
scheduler: Optional[CogVideoXDDIMScheduler] = None,
|
||||
device: torch.device = torch.device("cuda"),
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
self._config = config
|
||||
self._device = device
|
||||
if config.model_path == "THUDM/CogVideoX-2b":
|
||||
dtype = torch.float16
|
||||
self._dtype = dtype
|
||||
|
||||
if transformer is None:
|
||||
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
config.model_path, subfolder="transformer", torch_dtype=self._dtype
|
||||
)
|
||||
if vae is None:
|
||||
vae = AutoencoderKLCogVideoX.from_pretrained(config.model_path, subfolder="vae", torch_dtype=self._dtype)
|
||||
if tokenizer is None:
|
||||
tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
|
||||
if text_encoder is None:
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
config.model_path, subfolder="text_encoder", torch_dtype=self._dtype
|
||||
)
|
||||
if scheduler is None:
|
||||
scheduler = CogVideoXDDIMScheduler.from_pretrained(
|
||||
config.model_path,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
# set eval and device
|
||||
self.set_eval_and_device(self._device, vae, transformer)
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
# cpu offload
|
||||
if config.cpu_offload:
|
||||
self.enable_model_cpu_offload()
|
||||
else:
|
||||
self.set_eval_and_device(self._device, text_encoder)
|
||||
|
||||
# vae tiling
|
||||
if config.vae_tiling:
|
||||
vae.enable_tiling()
|
||||
|
||||
# pab
|
||||
if config.enable_pab:
|
||||
set_pab_manager(config.pab_config)
|
||||
|
||||
self.vae_scale_factor_spatial = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
frames = self.vae.decode(latents).sample
|
||||
return frames
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def fuse_qkv_projections(self) -> None:
|
||||
r"""Enables fused QKV projections."""
|
||||
self.fusing_transformer = True
|
||||
self.transformer.fuse_qkv_projections()
|
||||
|
||||
def unfuse_qkv_projections(self) -> None:
|
||||
r"""Disable QKV projection fusion if enabled."""
|
||||
if not self.fusing_transformer:
|
||||
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.transformer.unfuse_qkv_projections()
|
||||
self.fusing_transformer = False
|
||||
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
|
||||
grid_crops_coords = get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
use_real=True,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
num_frames: int = 49,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 6,
|
||||
use_dynamic_cfg: bool = False,
|
||||
num_videos_per_prompt: int = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 226,
|
||||
verbose=True,
|
||||
) -> Union[VideoSysPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_frames (`int`, defaults to `48`):
|
||||
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
||||
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
||||
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
|
||||
needs to be satisfied is that of divisibility mentioned above.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `226`):
|
||||
Maximum sequence length in encoded prompt. Must be consistent with
|
||||
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if num_frames > 49:
|
||||
raise ValueError(
|
||||
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
|
||||
)
|
||||
update_steps(num_inference_steps)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
latent_channels,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Create rotary embeds if required
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
|
||||
# with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
# for DPM-solver++
|
||||
old_pred_original_sample = None
|
||||
for i, t in progress_wrap(list(enumerate(timesteps))):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
return_dict=False,
|
||||
all_timesteps=timesteps,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
if use_dynamic_cfg:
|
||||
self._guidance_scale = 1 + guidance_scale * (
|
||||
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
else:
|
||||
latents, old_pred_original_sample = self.scheduler.step(
|
||||
noise_pred,
|
||||
old_pred_original_sample,
|
||||
t,
|
||||
timesteps[i - 1] if i > 0 else None,
|
||||
latents,
|
||||
**extra_step_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
# progress_bar.update()
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return VideoSysPipelineOutput(video=video)
|
||||
|
||||
def save_video(self, video, output_path):
|
||||
save_video(video, output_path, fps=8)
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
@ -1,3 +0,0 @@
|
||||
from .pipeline_latte import LatteConfig, LattePABConfig, LattePipeline
|
||||
|
||||
__all__ = ["LatteConfig", "LattePipeline", "LattePABConfig"]
|
||||
Binary file not shown.
Binary file not shown.
@ -1,929 +0,0 @@
|
||||
# Adapted from Latte
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# Latte: https://github.com/Vchitect/Latte
|
||||
# --------------------------------------------------------
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import ftfy
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import tqdm
|
||||
from bs4 import BeautifulSoup
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
|
||||
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
|
||||
from videosys.models.transformers.latte_transformer_3d import LatteT2V
|
||||
from videosys.utils.logging import logger
|
||||
from videosys.utils.utils import save_video
|
||||
|
||||
|
||||
class LattePABConfig(PABConfig):
|
||||
def __init__(
|
||||
self,
|
||||
steps: int = 50,
|
||||
spatial_broadcast: bool = True,
|
||||
spatial_threshold: list = [100, 800],
|
||||
spatial_range: int = 2,
|
||||
temporal_broadcast: bool = True,
|
||||
temporal_threshold: list = [100, 800],
|
||||
temporal_range: int = 3,
|
||||
cross_broadcast: bool = True,
|
||||
cross_threshold: list = [100, 800],
|
||||
cross_range: int = 6,
|
||||
mlp_broadcast: bool = True,
|
||||
mlp_spatial_broadcast_config: dict = {
|
||||
720: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
640: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
560: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
480: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
400: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
},
|
||||
mlp_temporal_broadcast_config: dict = {
|
||||
720: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
640: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
560: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
480: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
400: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
},
|
||||
):
|
||||
super().__init__(
|
||||
steps=steps,
|
||||
spatial_broadcast=spatial_broadcast,
|
||||
spatial_threshold=spatial_threshold,
|
||||
spatial_range=spatial_range,
|
||||
temporal_broadcast=temporal_broadcast,
|
||||
temporal_threshold=temporal_threshold,
|
||||
temporal_range=temporal_range,
|
||||
cross_broadcast=cross_broadcast,
|
||||
cross_threshold=cross_threshold,
|
||||
cross_range=cross_range,
|
||||
mlp_broadcast=mlp_broadcast,
|
||||
mlp_spatial_broadcast_config=mlp_spatial_broadcast_config,
|
||||
mlp_temporal_broadcast_config=mlp_temporal_broadcast_config,
|
||||
)
|
||||
|
||||
|
||||
class LatteConfig:
|
||||
"""
|
||||
This config is to instantiate a `LattePipeline` class for video generation.
|
||||
|
||||
To be specific, this config will be passed to engine by `VideoSysEngine(config)`.
|
||||
In the engine, it will be used to instantiate the corresponding pipeline class.
|
||||
And the engine will call the `generate` function of the pipeline to generate the video.
|
||||
If you want to explore the detail of generation, please refer to the pipeline class below.
|
||||
|
||||
Args:
|
||||
model_path (str):
|
||||
A path to the pretrained pipeline. Defaults to "maxin-cn/Latte-1".
|
||||
num_gpus (int):
|
||||
The number of GPUs to use. Defaults to 1.
|
||||
enable_vae_temporal_decoder (bool):
|
||||
Whether to enable VAE Temporal Decoder. Defaults to True.
|
||||
beta_start (float):
|
||||
The initial value of beta for DDIM. Defaults to 0.0001.
|
||||
beta_end (float):
|
||||
The final value of beta for DDIM. Defaults to 0.02.
|
||||
beta_schedule (str):
|
||||
The schedule of beta for DDIM. Defaults to "linear".
|
||||
variance_type (str):
|
||||
The type of variance for DDIM. Defaults to "learned_range".
|
||||
enable_pab (bool):
|
||||
Whether to enable Pyramid Attention Broadcast. Defaults to False.
|
||||
pab_config (CogVideoXPABConfig):
|
||||
The configuration for Pyramid Attention Broadcast. Defaults to `LattePABConfig()`.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
from videosys import LatteConfig, VideoSysEngine
|
||||
|
||||
# change num_gpus for multi-gpu inference
|
||||
config = LatteConfig("maxin-cn/Latte-1", num_gpus=1)
|
||||
engine = VideoSysEngine(config)
|
||||
|
||||
prompt = "Sunset over the sea."
|
||||
# video size is fixed to 16 frames, 512x512.
|
||||
video = engine.generate(
|
||||
prompt=prompt,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=50,
|
||||
).video[0]
|
||||
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = "maxin-cn/Latte-1",
|
||||
# ======= distributed =======
|
||||
num_gpus: int = 1,
|
||||
# ======= vae ========
|
||||
enable_vae_temporal_decoder: bool = True,
|
||||
# ======= scheduler ========
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
variance_type: str = "learned_range",
|
||||
# ======= memory =======
|
||||
cpu_offload: bool = False,
|
||||
# ======= pab ========
|
||||
enable_pab: bool = False,
|
||||
pab_config: PABConfig = LattePABConfig(),
|
||||
):
|
||||
self.model_path = model_path
|
||||
self.pipeline_cls = LattePipeline
|
||||
# ======= distributed =======
|
||||
self.num_gpus = num_gpus
|
||||
# ======= vae ========
|
||||
self.enable_vae_temporal_decoder = enable_vae_temporal_decoder
|
||||
# ======= memory ========
|
||||
self.cpu_offload = cpu_offload
|
||||
# ======= scheduler ========
|
||||
self.beta_start = beta_start
|
||||
self.beta_end = beta_end
|
||||
self.beta_schedule = beta_schedule
|
||||
self.variance_type = variance_type
|
||||
# ======= pab ========
|
||||
self.enable_pab = enable_pab
|
||||
self.pab_config = pab_config
|
||||
|
||||
|
||||
class LattePipeline(VideoSysPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using PixArt-Alpha.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. PixArt-Alpha uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`Transformer2DModel`]):
|
||||
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
bad_punct_regex = re.compile(
|
||||
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LatteConfig,
|
||||
tokenizer: Optional[T5Tokenizer] = None,
|
||||
text_encoder: Optional[T5EncoderModel] = None,
|
||||
vae: Optional[AutoencoderKL] = None,
|
||||
transformer: Optional[LatteT2V] = None,
|
||||
scheduler: Optional[DDIMScheduler] = None,
|
||||
device: torch.device = torch.device("cuda"),
|
||||
dtype: torch.dtype = torch.float16,
|
||||
):
|
||||
super().__init__()
|
||||
self._config = config
|
||||
|
||||
# initialize the model if not provided
|
||||
if transformer is None:
|
||||
transformer = LatteT2V.from_pretrained(config.model_path, subfolder="transformer", video_length=16).to(
|
||||
dtype=dtype
|
||||
)
|
||||
if vae is None:
|
||||
if config.enable_vae_temporal_decoder:
|
||||
vae = AutoencoderKLTemporalDecoder.from_pretrained(
|
||||
config.model_path, subfolder="vae_temporal_decoder", torch_dtype=dtype
|
||||
)
|
||||
else:
|
||||
vae = AutoencoderKL.from_pretrained(config.model_path, subfolder="vae", torch_dtype=dtype)
|
||||
if tokenizer is None:
|
||||
tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
|
||||
if text_encoder is None:
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
config.model_path, subfolder="text_encoder", torch_dtype=dtype
|
||||
)
|
||||
if scheduler is None:
|
||||
scheduler = DDIMScheduler.from_pretrained(
|
||||
config.model_path,
|
||||
subfolder="scheduler",
|
||||
beta_start=config.beta_start,
|
||||
beta_end=config.beta_end,
|
||||
beta_schedule=config.beta_schedule,
|
||||
variance_type=config.variance_type,
|
||||
clip_sample=False,
|
||||
)
|
||||
|
||||
# pab
|
||||
if config.enable_pab:
|
||||
set_pab_manager(config.pab_config)
|
||||
|
||||
# set eval and device
|
||||
self.set_eval_and_device(device, vae, transformer)
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
# cpu offload
|
||||
if config.cpu_offload:
|
||||
self.enable_model_cpu_offload()
|
||||
else:
|
||||
self.set_eval_and_device(device, text_encoder)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
||||
def mask_text_embeddings(self, emb, mask):
|
||||
if emb.shape[0] == 1:
|
||||
keep_index = mask.sum().item()
|
||||
return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096
|
||||
else:
|
||||
masked_feature = emb * mask[:, None, :, None] # 1 120 4096
|
||||
return masked_feature, emb.shape[2]
|
||||
|
||||
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str = "",
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
mask_feature: bool = True,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
||||
PixArt-Alpha, this should be "".
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
|
||||
string.
|
||||
clean_caption (bool, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
mask_feature: (bool, defaults to `True`):
|
||||
If `True`, the function will mask the text embeddings.
|
||||
"""
|
||||
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# See Section 3.1. of the paper.
|
||||
max_length = 120
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
prompt_embeds_attention_mask = attention_mask
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
elif self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
|
||||
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens = [negative_prompt] * batch_size
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
else:
|
||||
negative_prompt_embeds = None
|
||||
|
||||
# Perform additional masking.
|
||||
if mask_feature and not embeds_initially_provided:
|
||||
prompt_embeds = prompt_embeds.unsqueeze(1)
|
||||
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
|
||||
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
|
||||
masked_negative_prompt_embeds = (
|
||||
negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
|
||||
)
|
||||
|
||||
# import torch.nn.functional as F
|
||||
|
||||
# padding = (0, 0, 0, 113) # (左, 右, 下, 上)
|
||||
# masked_prompt_embeds_ = F.pad(masked_prompt_embeds, padding, "constant", 0)
|
||||
# masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0)
|
||||
|
||||
# print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...])
|
||||
|
||||
return masked_prompt_embeds, masked_negative_prompt_embeds
|
||||
# return masked_prompt_embeds_, masked_negative_prompt_embeds_
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_steps,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
video_length,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
prompt: str = None,
|
||||
negative_prompt: str = "",
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
clean_caption: bool = True,
|
||||
mask_feature: bool = True,
|
||||
enable_temporal_attentions: bool = True,
|
||||
verbose: bool = True,
|
||||
) -> Union[VideoSysPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Latte can only generate video of 16 frames 512x512.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
|
||||
enable_temporal_attentions (`bool`, defaults to `True`):
|
||||
If `True`, the model will use temporal attentions to generate the video.
|
||||
verbose (`bool`, *optional*, defaults to `True`):
|
||||
Whether to print progress bars and other information during inference.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images
|
||||
"""
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
video_length = 16
|
||||
height = 512
|
||||
width = 512
|
||||
update_steps(num_inference_steps)
|
||||
self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
clean_caption=clean_caption,
|
||||
mask_feature=mask_feature,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
video_length,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs.
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Prepare micro-conditions.
|
||||
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
||||
if self.transformer.config.sample_size == 128:
|
||||
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
||||
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
||||
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
|
||||
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
|
||||
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
|
||||
for i, t in progress_wrap(list(enumerate(timesteps))):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
current_timestep = t
|
||||
if not torch.is_tensor(current_timestep):
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = latent_model_input.device.type == "mps"
|
||||
if isinstance(current_timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
|
||||
elif len(current_timestep.shape) == 0:
|
||||
current_timestep = current_timestep[None].to(latent_model_input.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
all_timesteps=timesteps,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=current_timestep,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
enable_temporal_attentions=enable_temporal_attentions,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# learned sigma
|
||||
if self.transformer.config.out_channels // 2 == latent_channels:
|
||||
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
||||
else:
|
||||
noise_pred = noise_pred
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latents":
|
||||
if latents.shape[2] == 1: # image
|
||||
video = self.decode_latents_image(latents)
|
||||
else: # video
|
||||
if self._config.enable_vae_temporal_decoder:
|
||||
video = self.decode_latents_with_temporal_decoder(latents)
|
||||
else:
|
||||
video = self.decode_latents(latents)
|
||||
else:
|
||||
video = latents
|
||||
return VideoSysPipelineOutput(video=video)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return VideoSysPipelineOutput(video=video)
|
||||
|
||||
def decode_latents_image(self, latents):
|
||||
video_length = latents.shape[2]
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
|
||||
video = []
|
||||
for frame_idx in range(latents.shape[0]):
|
||||
video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
|
||||
video = torch.cat(video)
|
||||
video = einops.rearrange(video, "(b f) c h w -> b f c h w", f=video_length)
|
||||
video = (video / 2.0 + 0.5).clamp(0, 1)
|
||||
return video
|
||||
|
||||
def decode_latents(self, latents):
|
||||
video_length = latents.shape[2]
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
|
||||
video = []
|
||||
for frame_idx in range(latents.shape[0]):
|
||||
video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
|
||||
video = torch.cat(video)
|
||||
video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
|
||||
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
return video
|
||||
|
||||
def decode_latents_with_temporal_decoder(self, latents):
|
||||
video_length = latents.shape[2]
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
|
||||
video = []
|
||||
|
||||
decode_chunk_size = 14
|
||||
for frame_idx in range(0, latents.shape[0], decode_chunk_size):
|
||||
num_frames_in = latents[frame_idx : frame_idx + decode_chunk_size].shape[0]
|
||||
|
||||
decode_kwargs = {}
|
||||
decode_kwargs["num_frames"] = num_frames_in
|
||||
|
||||
video.append(self.vae.decode(latents[frame_idx : frame_idx + decode_chunk_size], **decode_kwargs).sample)
|
||||
|
||||
video = torch.cat(video)
|
||||
video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
|
||||
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
return video
|
||||
|
||||
def save_video(self, video, output_path):
|
||||
save_video(video, output_path, fps=8)
|
||||
@ -1,3 +0,0 @@
|
||||
from .pipeline_open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
|
||||
|
||||
__all__ = ["OpenSoraConfig", "OpenSoraPipeline", "OpenSoraPABConfig"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,807 +0,0 @@
|
||||
# Adapted from OpenSora
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# OpenSora: https://github.com/hpcaitech/Open-Sora
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
import numbers
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
||||
from torchvision.io import write_video
|
||||
from torchvision.utils import save_image
|
||||
|
||||
IMG_FPS = 120
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
regex = re.compile(
|
||||
r"^(?:http|ftp)s?://" # http:// or https://
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain...
|
||||
r"localhost|" # localhost...
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
|
||||
r"(?::\d+)?" # optional port
|
||||
r"(?:/?|[/?]\S+)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# H:W
|
||||
ASPECT_RATIO_MAP = {
|
||||
"3:8": "0.38",
|
||||
"9:21": "0.43",
|
||||
"12:25": "0.48",
|
||||
"1:2": "0.50",
|
||||
"9:17": "0.53",
|
||||
"27:50": "0.54",
|
||||
"9:16": "0.56",
|
||||
"5:8": "0.62",
|
||||
"2:3": "0.67",
|
||||
"3:4": "0.75",
|
||||
"1:1": "1.00",
|
||||
"4:3": "1.33",
|
||||
"3:2": "1.50",
|
||||
"16:9": "1.78",
|
||||
"17:9": "1.89",
|
||||
"2:1": "2.00",
|
||||
"50:27": "2.08",
|
||||
}
|
||||
|
||||
|
||||
# computed from above code
|
||||
# S = 8294400
|
||||
ASPECT_RATIO_4K = {
|
||||
"0.38": (1764, 4704),
|
||||
"0.43": (1886, 4400),
|
||||
"0.48": (1996, 4158),
|
||||
"0.50": (2036, 4072),
|
||||
"0.53": (2096, 3960),
|
||||
"0.54": (2118, 3918),
|
||||
"0.62": (2276, 3642),
|
||||
"0.56": (2160, 3840), # base
|
||||
"0.67": (2352, 3528),
|
||||
"0.75": (2494, 3326),
|
||||
"1.00": (2880, 2880),
|
||||
"1.33": (3326, 2494),
|
||||
"1.50": (3528, 2352),
|
||||
"1.78": (3840, 2160),
|
||||
"1.89": (3958, 2096),
|
||||
"2.00": (4072, 2036),
|
||||
"2.08": (4156, 1994),
|
||||
}
|
||||
|
||||
# S = 3686400
|
||||
ASPECT_RATIO_2K = {
|
||||
"0.38": (1176, 3136),
|
||||
"0.43": (1256, 2930),
|
||||
"0.48": (1330, 2770),
|
||||
"0.50": (1358, 2716),
|
||||
"0.53": (1398, 2640),
|
||||
"0.54": (1412, 2612),
|
||||
"0.56": (1440, 2560), # base
|
||||
"0.62": (1518, 2428),
|
||||
"0.67": (1568, 2352),
|
||||
"0.75": (1662, 2216),
|
||||
"1.00": (1920, 1920),
|
||||
"1.33": (2218, 1664),
|
||||
"1.50": (2352, 1568),
|
||||
"1.78": (2560, 1440),
|
||||
"1.89": (2638, 1396),
|
||||
"2.00": (2716, 1358),
|
||||
"2.08": (2772, 1330),
|
||||
}
|
||||
|
||||
# S = 2073600
|
||||
ASPECT_RATIO_1080P = {
|
||||
"0.38": (882, 2352),
|
||||
"0.43": (942, 2198),
|
||||
"0.48": (998, 2080),
|
||||
"0.50": (1018, 2036),
|
||||
"0.53": (1048, 1980),
|
||||
"0.54": (1058, 1958),
|
||||
"0.56": (1080, 1920), # base
|
||||
"0.62": (1138, 1820),
|
||||
"0.67": (1176, 1764),
|
||||
"0.75": (1248, 1664),
|
||||
"1.00": (1440, 1440),
|
||||
"1.33": (1662, 1246),
|
||||
"1.50": (1764, 1176),
|
||||
"1.78": (1920, 1080),
|
||||
"1.89": (1980, 1048),
|
||||
"2.00": (2036, 1018),
|
||||
"2.08": (2078, 998),
|
||||
}
|
||||
|
||||
# S = 921600
|
||||
ASPECT_RATIO_720P = {
|
||||
"0.38": (588, 1568),
|
||||
"0.43": (628, 1466),
|
||||
"0.48": (666, 1388),
|
||||
"0.50": (678, 1356),
|
||||
"0.53": (698, 1318),
|
||||
"0.54": (706, 1306),
|
||||
"0.56": (720, 1280), # base
|
||||
"0.62": (758, 1212),
|
||||
"0.67": (784, 1176),
|
||||
"0.75": (832, 1110),
|
||||
"1.00": (960, 960),
|
||||
"1.33": (1108, 832),
|
||||
"1.50": (1176, 784),
|
||||
"1.78": (1280, 720),
|
||||
"1.89": (1320, 698),
|
||||
"2.00": (1358, 680),
|
||||
"2.08": (1386, 666),
|
||||
}
|
||||
|
||||
# S = 409920
|
||||
ASPECT_RATIO_480P = {
|
||||
"0.38": (392, 1046),
|
||||
"0.43": (420, 980),
|
||||
"0.48": (444, 925),
|
||||
"0.50": (452, 904),
|
||||
"0.53": (466, 880),
|
||||
"0.54": (470, 870),
|
||||
"0.56": (480, 854), # base
|
||||
"0.62": (506, 810),
|
||||
"0.67": (522, 784),
|
||||
"0.75": (554, 738),
|
||||
"1.00": (640, 640),
|
||||
"1.33": (740, 555),
|
||||
"1.50": (784, 522),
|
||||
"1.78": (854, 480),
|
||||
"1.89": (880, 466),
|
||||
"2.00": (906, 454),
|
||||
"2.08": (924, 444),
|
||||
}
|
||||
|
||||
# S = 230400
|
||||
ASPECT_RATIO_360P = {
|
||||
"0.38": (294, 784),
|
||||
"0.43": (314, 732),
|
||||
"0.48": (332, 692),
|
||||
"0.50": (340, 680),
|
||||
"0.53": (350, 662),
|
||||
"0.54": (352, 652),
|
||||
"0.56": (360, 640), # base
|
||||
"0.62": (380, 608),
|
||||
"0.67": (392, 588),
|
||||
"0.75": (416, 554),
|
||||
"1.00": (480, 480),
|
||||
"1.33": (554, 416),
|
||||
"1.50": (588, 392),
|
||||
"1.78": (640, 360),
|
||||
"1.89": (660, 350),
|
||||
"2.00": (678, 340),
|
||||
"2.08": (692, 332),
|
||||
}
|
||||
|
||||
# S = 102240
|
||||
ASPECT_RATIO_240P = {
|
||||
"0.38": (196, 522),
|
||||
"0.43": (210, 490),
|
||||
"0.48": (222, 462),
|
||||
"0.50": (226, 452),
|
||||
"0.53": (232, 438),
|
||||
"0.54": (236, 436),
|
||||
"0.56": (240, 426), # base
|
||||
"0.62": (252, 404),
|
||||
"0.67": (262, 393),
|
||||
"0.75": (276, 368),
|
||||
"1.00": (320, 320),
|
||||
"1.33": (370, 278),
|
||||
"1.50": (392, 262),
|
||||
"1.78": (426, 240),
|
||||
"1.89": (440, 232),
|
||||
"2.00": (452, 226),
|
||||
"2.08": (462, 222),
|
||||
}
|
||||
|
||||
# S = 36864
|
||||
ASPECT_RATIO_144P = {
|
||||
"0.38": (117, 312),
|
||||
"0.43": (125, 291),
|
||||
"0.48": (133, 277),
|
||||
"0.50": (135, 270),
|
||||
"0.53": (139, 262),
|
||||
"0.54": (141, 260),
|
||||
"0.56": (144, 256), # base
|
||||
"0.62": (151, 241),
|
||||
"0.67": (156, 234),
|
||||
"0.75": (166, 221),
|
||||
"1.00": (192, 192),
|
||||
"1.33": (221, 165),
|
||||
"1.50": (235, 156),
|
||||
"1.78": (256, 144),
|
||||
"1.89": (263, 139),
|
||||
"2.00": (271, 135),
|
||||
"2.08": (277, 132),
|
||||
}
|
||||
|
||||
# from PixArt
|
||||
# S = 8294400
|
||||
ASPECT_RATIO_2880 = {
|
||||
"0.25": (1408, 5760),
|
||||
"0.26": (1408, 5568),
|
||||
"0.27": (1408, 5376),
|
||||
"0.28": (1408, 5184),
|
||||
"0.32": (1600, 4992),
|
||||
"0.33": (1600, 4800),
|
||||
"0.34": (1600, 4672),
|
||||
"0.4": (1792, 4480),
|
||||
"0.42": (1792, 4288),
|
||||
"0.47": (1920, 4096),
|
||||
"0.49": (1920, 3904),
|
||||
"0.51": (1920, 3776),
|
||||
"0.55": (2112, 3840),
|
||||
"0.59": (2112, 3584),
|
||||
"0.68": (2304, 3392),
|
||||
"0.72": (2304, 3200),
|
||||
"0.78": (2496, 3200),
|
||||
"0.83": (2496, 3008),
|
||||
"0.89": (2688, 3008),
|
||||
"0.93": (2688, 2880),
|
||||
"1.0": (2880, 2880),
|
||||
"1.07": (2880, 2688),
|
||||
"1.12": (3008, 2688),
|
||||
"1.21": (3008, 2496),
|
||||
"1.28": (3200, 2496),
|
||||
"1.39": (3200, 2304),
|
||||
"1.47": (3392, 2304),
|
||||
"1.7": (3584, 2112),
|
||||
"1.82": (3840, 2112),
|
||||
"2.03": (3904, 1920),
|
||||
"2.13": (4096, 1920),
|
||||
"2.39": (4288, 1792),
|
||||
"2.5": (4480, 1792),
|
||||
"2.92": (4672, 1600),
|
||||
"3.0": (4800, 1600),
|
||||
"3.12": (4992, 1600),
|
||||
"3.68": (5184, 1408),
|
||||
"3.82": (5376, 1408),
|
||||
"3.95": (5568, 1408),
|
||||
"4.0": (5760, 1408),
|
||||
}
|
||||
|
||||
# S = 4194304
|
||||
ASPECT_RATIO_2048 = {
|
||||
"0.25": (1024, 4096),
|
||||
"0.26": (1024, 3968),
|
||||
"0.27": (1024, 3840),
|
||||
"0.28": (1024, 3712),
|
||||
"0.32": (1152, 3584),
|
||||
"0.33": (1152, 3456),
|
||||
"0.35": (1152, 3328),
|
||||
"0.4": (1280, 3200),
|
||||
"0.42": (1280, 3072),
|
||||
"0.48": (1408, 2944),
|
||||
"0.5": (1408, 2816),
|
||||
"0.52": (1408, 2688),
|
||||
"0.57": (1536, 2688),
|
||||
"0.6": (1536, 2560),
|
||||
"0.68": (1664, 2432),
|
||||
"0.72": (1664, 2304),
|
||||
"0.78": (1792, 2304),
|
||||
"0.82": (1792, 2176),
|
||||
"0.88": (1920, 2176),
|
||||
"0.94": (1920, 2048),
|
||||
"1.0": (2048, 2048),
|
||||
"1.07": (2048, 1920),
|
||||
"1.13": (2176, 1920),
|
||||
"1.21": (2176, 1792),
|
||||
"1.29": (2304, 1792),
|
||||
"1.38": (2304, 1664),
|
||||
"1.46": (2432, 1664),
|
||||
"1.67": (2560, 1536),
|
||||
"1.75": (2688, 1536),
|
||||
"2.0": (2816, 1408),
|
||||
"2.09": (2944, 1408),
|
||||
"2.4": (3072, 1280),
|
||||
"2.5": (3200, 1280),
|
||||
"2.89": (3328, 1152),
|
||||
"3.0": (3456, 1152),
|
||||
"3.11": (3584, 1152),
|
||||
"3.62": (3712, 1024),
|
||||
"3.75": (3840, 1024),
|
||||
"3.88": (3968, 1024),
|
||||
"4.0": (4096, 1024),
|
||||
}
|
||||
|
||||
# S = 1048576
|
||||
ASPECT_RATIO_1024 = {
|
||||
"0.25": (512, 2048),
|
||||
"0.26": (512, 1984),
|
||||
"0.27": (512, 1920),
|
||||
"0.28": (512, 1856),
|
||||
"0.32": (576, 1792),
|
||||
"0.33": (576, 1728),
|
||||
"0.35": (576, 1664),
|
||||
"0.4": (640, 1600),
|
||||
"0.42": (640, 1536),
|
||||
"0.48": (704, 1472),
|
||||
"0.5": (704, 1408),
|
||||
"0.52": (704, 1344),
|
||||
"0.57": (768, 1344),
|
||||
"0.6": (768, 1280),
|
||||
"0.68": (832, 1216),
|
||||
"0.72": (832, 1152),
|
||||
"0.78": (896, 1152),
|
||||
"0.82": (896, 1088),
|
||||
"0.88": (960, 1088),
|
||||
"0.94": (960, 1024),
|
||||
"1.0": (1024, 1024),
|
||||
"1.07": (1024, 960),
|
||||
"1.13": (1088, 960),
|
||||
"1.21": (1088, 896),
|
||||
"1.29": (1152, 896),
|
||||
"1.38": (1152, 832),
|
||||
"1.46": (1216, 832),
|
||||
"1.67": (1280, 768),
|
||||
"1.75": (1344, 768),
|
||||
"2.0": (1408, 704),
|
||||
"2.09": (1472, 704),
|
||||
"2.4": (1536, 640),
|
||||
"2.5": (1600, 640),
|
||||
"2.89": (1664, 576),
|
||||
"3.0": (1728, 576),
|
||||
"3.11": (1792, 576),
|
||||
"3.62": (1856, 512),
|
||||
"3.75": (1920, 512),
|
||||
"3.88": (1984, 512),
|
||||
"4.0": (2048, 512),
|
||||
}
|
||||
|
||||
# S = 262144
|
||||
ASPECT_RATIO_512 = {
|
||||
"0.25": (256, 1024),
|
||||
"0.26": (256, 992),
|
||||
"0.27": (256, 960),
|
||||
"0.28": (256, 928),
|
||||
"0.32": (288, 896),
|
||||
"0.33": (288, 864),
|
||||
"0.35": (288, 832),
|
||||
"0.4": (320, 800),
|
||||
"0.42": (320, 768),
|
||||
"0.48": (352, 736),
|
||||
"0.5": (352, 704),
|
||||
"0.52": (352, 672),
|
||||
"0.57": (384, 672),
|
||||
"0.6": (384, 640),
|
||||
"0.68": (416, 608),
|
||||
"0.72": (416, 576),
|
||||
"0.78": (448, 576),
|
||||
"0.82": (448, 544),
|
||||
"0.88": (480, 544),
|
||||
"0.94": (480, 512),
|
||||
"1.0": (512, 512),
|
||||
"1.07": (512, 480),
|
||||
"1.13": (544, 480),
|
||||
"1.21": (544, 448),
|
||||
"1.29": (576, 448),
|
||||
"1.38": (576, 416),
|
||||
"1.46": (608, 416),
|
||||
"1.67": (640, 384),
|
||||
"1.75": (672, 384),
|
||||
"2.0": (704, 352),
|
||||
"2.09": (736, 352),
|
||||
"2.4": (768, 320),
|
||||
"2.5": (800, 320),
|
||||
"2.89": (832, 288),
|
||||
"3.0": (864, 288),
|
||||
"3.11": (896, 288),
|
||||
"3.62": (928, 256),
|
||||
"3.75": (960, 256),
|
||||
"3.88": (992, 256),
|
||||
"4.0": (1024, 256),
|
||||
}
|
||||
|
||||
# S = 65536
|
||||
ASPECT_RATIO_256 = {
|
||||
"0.25": (128, 512),
|
||||
"0.26": (128, 496),
|
||||
"0.27": (128, 480),
|
||||
"0.28": (128, 464),
|
||||
"0.32": (144, 448),
|
||||
"0.33": (144, 432),
|
||||
"0.35": (144, 416),
|
||||
"0.4": (160, 400),
|
||||
"0.42": (160, 384),
|
||||
"0.48": (176, 368),
|
||||
"0.5": (176, 352),
|
||||
"0.52": (176, 336),
|
||||
"0.57": (192, 336),
|
||||
"0.6": (192, 320),
|
||||
"0.68": (208, 304),
|
||||
"0.72": (208, 288),
|
||||
"0.78": (224, 288),
|
||||
"0.82": (224, 272),
|
||||
"0.88": (240, 272),
|
||||
"0.94": (240, 256),
|
||||
"1.0": (256, 256),
|
||||
"1.07": (256, 240),
|
||||
"1.13": (272, 240),
|
||||
"1.21": (272, 224),
|
||||
"1.29": (288, 224),
|
||||
"1.38": (288, 208),
|
||||
"1.46": (304, 208),
|
||||
"1.67": (320, 192),
|
||||
"1.75": (336, 192),
|
||||
"2.0": (352, 176),
|
||||
"2.09": (368, 176),
|
||||
"2.4": (384, 160),
|
||||
"2.5": (400, 160),
|
||||
"2.89": (416, 144),
|
||||
"3.0": (432, 144),
|
||||
"3.11": (448, 144),
|
||||
"3.62": (464, 128),
|
||||
"3.75": (480, 128),
|
||||
"3.88": (496, 128),
|
||||
"4.0": (512, 128),
|
||||
}
|
||||
|
||||
|
||||
def get_closest_ratio(height: float, width: float, ratios: dict):
|
||||
aspect_ratio = height / width
|
||||
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
||||
return closest_ratio
|
||||
|
||||
|
||||
ASPECT_RATIOS = {
|
||||
"144p": (36864, ASPECT_RATIO_144P),
|
||||
"256": (65536, ASPECT_RATIO_256),
|
||||
"240p": (102240, ASPECT_RATIO_240P),
|
||||
"360p": (230400, ASPECT_RATIO_360P),
|
||||
"512": (262144, ASPECT_RATIO_512),
|
||||
"480p": (409920, ASPECT_RATIO_480P),
|
||||
"720p": (921600, ASPECT_RATIO_720P),
|
||||
"1024": (1048576, ASPECT_RATIO_1024),
|
||||
"1080p": (2073600, ASPECT_RATIO_1080P),
|
||||
"2k": (3686400, ASPECT_RATIO_2K),
|
||||
"2048": (4194304, ASPECT_RATIO_2048),
|
||||
"2880": (8294400, ASPECT_RATIO_2880),
|
||||
"4k": (8294400, ASPECT_RATIO_4K),
|
||||
}
|
||||
|
||||
|
||||
def get_image_size(resolution, ar_ratio):
|
||||
ar_key = ASPECT_RATIO_MAP[ar_ratio]
|
||||
rs_dict = ASPECT_RATIOS[resolution][1]
|
||||
assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}"
|
||||
return rs_dict[ar_key]
|
||||
|
||||
|
||||
NUM_FRAMES_MAP = {
|
||||
"1x": 51,
|
||||
"2x": 102,
|
||||
"4x": 204,
|
||||
"8x": 408,
|
||||
"16x": 816,
|
||||
"2s": 51,
|
||||
"4s": 102,
|
||||
"8s": 204,
|
||||
"16s": 408,
|
||||
"32s": 816,
|
||||
}
|
||||
|
||||
|
||||
def get_num_frames(num_frames):
|
||||
if num_frames in NUM_FRAMES_MAP:
|
||||
return NUM_FRAMES_MAP[num_frames]
|
||||
else:
|
||||
return int(num_frames)
|
||||
|
||||
|
||||
def save_sample(x, save_path=None, fps=8, normalize=True, value_range=(-1, 1), force_video=False, verbose=True):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): shape [C, T, H, W]
|
||||
"""
|
||||
assert x.ndim == 4
|
||||
|
||||
if not force_video and x.shape[1] == 1: # T = 1: save as image
|
||||
save_path += ".png"
|
||||
x = x.squeeze(1)
|
||||
save_image([x], save_path, normalize=normalize, value_range=value_range)
|
||||
else:
|
||||
save_path += ".mp4"
|
||||
if normalize:
|
||||
low, high = value_range
|
||||
x.clamp_(min=low, max=high)
|
||||
x.sub_(low).div_(max(high - low, 1e-5))
|
||||
|
||||
x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
|
||||
write_video(save_path, x, fps=fps, video_codec="h264")
|
||||
if verbose:
|
||||
print(f"Saved to {save_path}")
|
||||
return save_path
|
||||
|
||||
|
||||
def is_url(url):
|
||||
return re.match(regex, url) is not None
|
||||
|
||||
|
||||
def download_url(input_path):
|
||||
output_dir = "cache"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
base_name = os.path.basename(input_path)
|
||||
output_path = os.path.join(output_dir, base_name)
|
||||
img_data = requests.get(input_path).content
|
||||
with open(output_path, "wb") as handler:
|
||||
handler.write(img_data)
|
||||
print(f"URL {input_path} downloaded to {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def get_transforms_video(name="center", image_size=(256, 256)):
|
||||
if name is None:
|
||||
return None
|
||||
elif name == "center":
|
||||
assert image_size[0] == image_size[1], "image_size must be square for center crop"
|
||||
transform_video = transforms.Compose(
|
||||
[
|
||||
ToTensorVideo(), # TCHW
|
||||
# video_transforms.RandomHorizontalFlipVideo(),
|
||||
UCFCenterCropVideo(image_size[0]),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
elif name == "resize_crop":
|
||||
transform_video = transforms.Compose(
|
||||
[
|
||||
ToTensorVideo(), # TCHW
|
||||
ResizeCrop(image_size),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Transform {name} not implemented")
|
||||
return transform_video
|
||||
|
||||
|
||||
def crop(clip, i, j, h, w):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
||||
"""
|
||||
if len(clip.size()) != 4:
|
||||
raise ValueError("clip should be a 4D tensor")
|
||||
return clip[..., i : i + h, j : j + w]
|
||||
|
||||
|
||||
def center_crop(clip, crop_size):
|
||||
if not _is_tensor_video_clip(clip):
|
||||
raise ValueError("clip should be a 4D torch.tensor")
|
||||
h, w = clip.size(-2), clip.size(-1)
|
||||
th, tw = crop_size
|
||||
if h < th or w < tw:
|
||||
raise ValueError("height and width must be no smaller than crop_size")
|
||||
|
||||
i = int(round((h - th) / 2.0))
|
||||
j = int(round((w - tw) / 2.0))
|
||||
return crop(clip, i, j, th, tw)
|
||||
|
||||
|
||||
def resize_scale(clip, target_size, interpolation_mode):
|
||||
if len(target_size) != 2:
|
||||
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
||||
H, W = clip.size(-2), clip.size(-1)
|
||||
scale_ = target_size[0] / min(H, W)
|
||||
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
|
||||
|
||||
|
||||
class UCFCenterCropVideo:
|
||||
"""
|
||||
First scale to the specified size in equal proportion to the short edge,
|
||||
then center cropping
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
interpolation_mode="bilinear",
|
||||
):
|
||||
if isinstance(size, tuple):
|
||||
if len(size) != 2:
|
||||
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
|
||||
self.interpolation_mode = interpolation_mode
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
||||
Returns:
|
||||
torch.tensor: scale resized / center cropped video clip.
|
||||
size is (T, C, crop_size, crop_size)
|
||||
"""
|
||||
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
||||
clip_center_crop = center_crop(clip_resize, self.size)
|
||||
return clip_center_crop
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
||||
|
||||
|
||||
def _is_tensor_video_clip(clip):
|
||||
if not torch.is_tensor(clip):
|
||||
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
||||
|
||||
if not clip.ndimension() == 4:
|
||||
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def to_tensor(clip):
|
||||
"""
|
||||
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
||||
permute the dimensions of clip tensor
|
||||
Args:
|
||||
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
||||
Return:
|
||||
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
||||
"""
|
||||
_is_tensor_video_clip(clip)
|
||||
if not clip.dtype == torch.uint8:
|
||||
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
|
||||
# return clip.float().permute(3, 0, 1, 2) / 255.0
|
||||
return clip.float() / 255.0
|
||||
|
||||
|
||||
class ToTensorVideo:
|
||||
"""
|
||||
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
||||
permute the dimensions of clip tensor
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
||||
Return:
|
||||
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
||||
"""
|
||||
return to_tensor(clip)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
|
||||
class ResizeCrop:
|
||||
def __init__(self, size):
|
||||
if isinstance(size, numbers.Number):
|
||||
self.size = (int(size), int(size))
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, clip):
|
||||
clip = resize_crop_to_fill(clip, self.size)
|
||||
return clip
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size})"
|
||||
|
||||
|
||||
def get_transforms_image(name="center", image_size=(256, 256)):
|
||||
if name is None:
|
||||
return None
|
||||
elif name == "center":
|
||||
assert image_size[0] == image_size[1], "Image size must be square for center crop"
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
|
||||
# transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
elif name == "resize_crop":
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Transform {name} not implemented")
|
||||
return transform
|
||||
|
||||
|
||||
def center_crop_arr(pil_image, image_size):
|
||||
"""
|
||||
Center cropping implementation from ADM.
|
||||
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
||||
"""
|
||||
while min(*pil_image.size) >= 2 * image_size:
|
||||
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
|
||||
|
||||
scale = image_size / min(*pil_image.size)
|
||||
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
|
||||
|
||||
arr = np.array(pil_image)
|
||||
crop_y = (arr.shape[0] - image_size) // 2
|
||||
crop_x = (arr.shape[1] - image_size) // 2
|
||||
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
|
||||
|
||||
|
||||
def resize_crop_to_fill(pil_image, image_size):
|
||||
w, h = pil_image.size # PIL is (W, H)
|
||||
th, tw = image_size
|
||||
rh, rw = th / h, tw / w
|
||||
if rh > rw:
|
||||
sh, sw = th, round(w * rh)
|
||||
image = pil_image.resize((sw, sh), Image.BICUBIC)
|
||||
i = 0
|
||||
j = int(round((sw - tw) / 2.0))
|
||||
else:
|
||||
sh, sw = round(h * rw), tw
|
||||
image = pil_image.resize((sw, sh), Image.BICUBIC)
|
||||
i = int(round((sh - th) / 2.0))
|
||||
j = 0
|
||||
arr = np.array(image)
|
||||
assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
|
||||
return Image.fromarray(arr[i : i + th, j : j + tw])
|
||||
|
||||
|
||||
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
|
||||
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
if transform is None:
|
||||
transform = get_transforms_video(image_size=image_size, name=transform_name)
|
||||
video = transform(vframes) # T C H W
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
return video
|
||||
|
||||
|
||||
def read_from_path(path, image_size, transform_name="center"):
|
||||
if is_url(path):
|
||||
path = download_url(path)
|
||||
ext = os.path.splitext(path)[-1].lower()
|
||||
if ext.lower() in VID_EXTENSIONS:
|
||||
return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
|
||||
else:
|
||||
assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
|
||||
return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
|
||||
|
||||
|
||||
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
|
||||
image = pil_loader(path)
|
||||
if transform is None:
|
||||
transform = get_transforms_image(image_size=image_size, name=transform_name)
|
||||
image = transform(image)
|
||||
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
return video
|
||||
|
||||
|
||||
def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
|
||||
if info_type is None:
|
||||
return dict()
|
||||
elif info_type == "PixArtMS":
|
||||
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
|
||||
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
|
||||
return dict(ar=ar, hw=hw)
|
||||
elif info_type in ["STDiT2", "OpenSora"]:
|
||||
fps = fps if num_frames > 1 else IMG_FPS
|
||||
fps = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size)
|
||||
height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(batch_size)
|
||||
width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(batch_size)
|
||||
num_frames = torch.tensor([num_frames], device=device, dtype=dtype).repeat(batch_size)
|
||||
ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(batch_size)
|
||||
return dict(height=height, width=width, num_frames=num_frames, ar=ar, fps=fps)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
@ -1,958 +0,0 @@
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import ftfy
|
||||
import torch
|
||||
from diffusers.models import AutoencoderKL
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from videosys.core.pab_mgr import PABConfig, set_pab_manager
|
||||
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
|
||||
from videosys.models.autoencoders.autoencoder_kl_open_sora import OpenSoraVAE_V1_2
|
||||
from videosys.models.transformers.open_sora_transformer_3d import STDiT3
|
||||
from videosys.schedulers.scheduling_rflow_open_sora import RFLOW
|
||||
from videosys.utils.utils import save_video
|
||||
|
||||
from .data_process import get_image_size, get_num_frames, prepare_multi_resolution_info, read_from_path
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
|
||||
BAD_PUNCT_REGEX = re.compile(
|
||||
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
|
||||
class OpenSoraPABConfig(PABConfig):
|
||||
def __init__(
|
||||
self,
|
||||
steps: int = 50,
|
||||
spatial_broadcast: bool = True,
|
||||
spatial_threshold: list = [450, 930],
|
||||
spatial_range: int = 2,
|
||||
temporal_broadcast: bool = True,
|
||||
temporal_threshold: list = [450, 930],
|
||||
temporal_range: int = 4,
|
||||
cross_broadcast: bool = True,
|
||||
cross_threshold: list = [450, 930],
|
||||
cross_range: int = 6,
|
||||
mlp_broadcast: bool = True,
|
||||
mlp_spatial_broadcast_config: dict = {
|
||||
676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
},
|
||||
mlp_temporal_broadcast_config: dict = {
|
||||
676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
|
||||
},
|
||||
):
|
||||
super().__init__(
|
||||
steps=steps,
|
||||
spatial_broadcast=spatial_broadcast,
|
||||
spatial_threshold=spatial_threshold,
|
||||
spatial_range=spatial_range,
|
||||
temporal_broadcast=temporal_broadcast,
|
||||
temporal_threshold=temporal_threshold,
|
||||
temporal_range=temporal_range,
|
||||
cross_broadcast=cross_broadcast,
|
||||
cross_threshold=cross_threshold,
|
||||
cross_range=cross_range,
|
||||
mlp_broadcast=mlp_broadcast,
|
||||
mlp_spatial_broadcast_config=mlp_spatial_broadcast_config,
|
||||
mlp_temporal_broadcast_config=mlp_temporal_broadcast_config,
|
||||
)
|
||||
|
||||
|
||||
class OpenSoraConfig:
|
||||
"""
|
||||
This config is to instantiate a `OpenSoraPipeline` class for video generation.
|
||||
|
||||
To be specific, this config will be passed to engine by `VideoSysEngine(config)`.
|
||||
In the engine, it will be used to instantiate the corresponding pipeline class.
|
||||
And the engine will call the `generate` function of the pipeline to generate the video.
|
||||
If you want to explore the detail of generation, please refer to the pipeline class below.
|
||||
|
||||
Args:
|
||||
transformer (str):
|
||||
The transformer model to use. Defaults to "hpcai-tech/OpenSora-STDiT-v3".
|
||||
vae (str):
|
||||
The VAE model to use. Defaults to "hpcai-tech/OpenSora-VAE-v1.2".
|
||||
text_encoder (str):
|
||||
The text encoder model to use. Defaults to "DeepFloyd/t5-v1_1-xxl".
|
||||
num_gpus (int):
|
||||
The number of GPUs to use. Defaults to 1.
|
||||
num_sampling_steps (int):
|
||||
The number of sampling steps. Defaults to 30.
|
||||
cfg_scale (float):
|
||||
The configuration scale. Defaults to 7.0.
|
||||
tiling_size (int):
|
||||
The tiling size. Defaults to 4.
|
||||
enable_flash_attn (bool):
|
||||
Whether to enable Flash Attention. Defaults to False.
|
||||
enable_pab (bool):
|
||||
Whether to enable Pyramid Attention Broadcast. Defaults to False.
|
||||
pab_config (CogVideoXPABConfig):
|
||||
The configuration for Pyramid Attention Broadcast. Defaults to `LattePABConfig()`.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
from videosys import OpenSoraConfig, VideoSysEngine
|
||||
|
||||
# change num_gpus for multi-gpu inference
|
||||
# sampling parameters are defined in the config
|
||||
config = OpenSoraConfig(num_sampling_steps=30, cfg_scale=7.0, num_gpus=1)
|
||||
engine = VideoSysEngine(config)
|
||||
|
||||
prompt = "Sunset over the sea."
|
||||
# num frames: 2s, 4s, 8s, 16s
|
||||
# resolution: 144p, 240p, 360p, 480p, 720p
|
||||
# aspect ratio: 9:16, 16:9, 3:4, 4:3, 1:1
|
||||
video = engine.generate(
|
||||
prompt=prompt,
|
||||
resolution="480p",
|
||||
aspect_ratio="9:16",
|
||||
num_frames="2s",
|
||||
).video[0]
|
||||
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: str = "hpcai-tech/OpenSora-STDiT-v3",
|
||||
vae: str = "hpcai-tech/OpenSora-VAE-v1.2",
|
||||
text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
|
||||
# ======== distributed ========
|
||||
num_gpus: int = 1,
|
||||
# ======== scheduler ========
|
||||
num_sampling_steps: int = 30,
|
||||
cfg_scale: float = 7.0,
|
||||
# ======= memory =======
|
||||
cpu_offload: bool = False,
|
||||
# ======== vae ========
|
||||
tiling_size: int = 4,
|
||||
# ======== speedup ========
|
||||
enable_flash_attn: bool = False,
|
||||
# ======== pab ========
|
||||
enable_pab: bool = False,
|
||||
pab_config: PABConfig = OpenSoraPABConfig(),
|
||||
):
|
||||
self.pipeline_cls = OpenSoraPipeline
|
||||
self.transformer = transformer
|
||||
self.vae = vae
|
||||
self.text_encoder = text_encoder
|
||||
# ======== distributed ========
|
||||
self.num_gpus = num_gpus
|
||||
# ======== scheduler ========
|
||||
self.num_sampling_steps = num_sampling_steps
|
||||
self.cfg_scale = cfg_scale
|
||||
# ======== vae ========
|
||||
self.tiling_size = tiling_size
|
||||
# ======= memory ========
|
||||
self.cpu_offload = cpu_offload
|
||||
# ======== speedup ========
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
# ======== pab ========
|
||||
self.enable_pab = enable_pab
|
||||
self.pab_config = pab_config
|
||||
|
||||
|
||||
class OpenSoraPipeline(VideoSysPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using PixArt-Alpha.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. PixArt-Alpha uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`STDiT3`]):
|
||||
A text conditioned `STDiT3` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
bad_punct_regex = re.compile(
|
||||
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
_optional_components = [
|
||||
"text_encoder",
|
||||
"tokenizer",
|
||||
]
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OpenSoraConfig,
|
||||
text_encoder: Optional[T5EncoderModel] = None,
|
||||
tokenizer: Optional[AutoTokenizer] = None,
|
||||
vae: Optional[AutoencoderKL] = None,
|
||||
transformer: Optional[STDiT3] = None,
|
||||
scheduler: Optional[RFLOW] = None,
|
||||
device: torch.device = torch.device("cuda"),
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
self._config = config
|
||||
self._device = device
|
||||
self._dtype = dtype
|
||||
|
||||
# initialize the model if not provided
|
||||
if text_encoder is None:
|
||||
text_encoder = T5EncoderModel.from_pretrained(config.text_encoder).to(dtype)
|
||||
if tokenizer is None:
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.text_encoder)
|
||||
if vae is None:
|
||||
vae = OpenSoraVAE_V1_2(
|
||||
from_pretrained=config.vae,
|
||||
micro_frame_size=17,
|
||||
micro_batch_size=config.tiling_size,
|
||||
).to(dtype)
|
||||
if transformer is None:
|
||||
transformer = STDiT3.from_pretrained(config.transformer, enable_flash_attn=config.enable_flash_attn).to(
|
||||
dtype
|
||||
)
|
||||
if scheduler is None:
|
||||
scheduler = RFLOW(
|
||||
use_timestep_transform=True, num_sampling_steps=config.num_sampling_steps, cfg_scale=config.cfg_scale
|
||||
)
|
||||
|
||||
# pab
|
||||
if config.enable_pab:
|
||||
set_pab_manager(config.pab_config)
|
||||
|
||||
# set eval and device
|
||||
self.set_eval_and_device(device, vae, transformer)
|
||||
|
||||
self.register_modules(
|
||||
text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
# cpu offload
|
||||
if config.cpu_offload:
|
||||
self.enable_model_cpu_offload()
|
||||
else:
|
||||
self.set_eval_and_device(self._device, text_encoder)
|
||||
|
||||
def get_text_embeddings(self, texts):
|
||||
text_tokens_and_mask = self.tokenizer(
|
||||
texts,
|
||||
max_length=300,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
device = self._execution_device
|
||||
input_ids = text_tokens_and_mask["input_ids"].to(device)
|
||||
attention_mask = text_tokens_and_mask["attention_mask"].to(device)
|
||||
with torch.no_grad():
|
||||
text_encoder_embs = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)["last_hidden_state"].detach()
|
||||
return text_encoder_embs, attention_mask
|
||||
|
||||
def encode_prompt(self, text):
|
||||
caption_embs, emb_masks = self.get_text_embeddings(text)
|
||||
caption_embs = caption_embs[:, None]
|
||||
return dict(y=caption_embs, mask=emb_masks)
|
||||
|
||||
def null_embed(self, n):
|
||||
null_y = self.transformer.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None].to(self._execution_device)
|
||||
return null_y
|
||||
|
||||
@staticmethod
|
||||
def _basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
def _clean_caption(self, caption):
|
||||
import urllib.parse as ul
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = self._basic_clean(caption)
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
def text_preprocessing(self, text, use_text_preprocessing: bool = True):
|
||||
if use_text_preprocessing:
|
||||
# The exact text cleaning as was in the training stage:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
return text
|
||||
else:
|
||||
return text.lower().strip()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
resolution="480p",
|
||||
aspect_ratio="9:16",
|
||||
num_frames: int = 51,
|
||||
loop: int = 1,
|
||||
llm_refine: bool = False,
|
||||
negative_prompt: str = "",
|
||||
ms: Optional[str] = "",
|
||||
refs: Optional[str] = "",
|
||||
aes: float = 6.5,
|
||||
flow: Optional[float] = None,
|
||||
camera_motion: Optional[float] = None,
|
||||
condition_frame_length: int = 5,
|
||||
align: int = 5,
|
||||
condition_frame_edit: float = 0.0,
|
||||
return_dict: bool = True,
|
||||
verbose: bool = True,
|
||||
) -> Union[VideoSysPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
resolution (`str`, *optional*, defaults to `"480p"`):
|
||||
The resolution of the generated video.
|
||||
aspect_ratio (`str`, *optional*, defaults to `"9:16"`):
|
||||
The aspect ratio of the generated video.
|
||||
num_frames (`int`, *optional*, defaults to 51):
|
||||
The number of frames to generate.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The width in pixels of the generated image.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images
|
||||
"""
|
||||
# == basic ==
|
||||
fps = 24
|
||||
image_size = get_image_size(resolution, aspect_ratio)
|
||||
num_frames = get_num_frames(num_frames)
|
||||
|
||||
# == prepare batch prompts ==
|
||||
batch_prompts = [prompt]
|
||||
ms = [ms]
|
||||
refs = [refs]
|
||||
|
||||
# == get json from prompts ==
|
||||
batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)
|
||||
|
||||
# == get reference for condition ==
|
||||
refs = collect_references_batch(refs, self.vae, image_size)
|
||||
|
||||
# == multi-resolution info ==
|
||||
model_args = prepare_multi_resolution_info(
|
||||
"OpenSora", len(batch_prompts), image_size, num_frames, fps, self._device, self._dtype
|
||||
)
|
||||
|
||||
# == process prompts step by step ==
|
||||
# 0. split prompt
|
||||
# each element in the list is [prompt_segment_list, loop_idx_list]
|
||||
batched_prompt_segment_list = []
|
||||
batched_loop_idx_list = []
|
||||
for prompt in batch_prompts:
|
||||
prompt_segment_list, loop_idx_list = split_prompt(prompt)
|
||||
batched_prompt_segment_list.append(prompt_segment_list)
|
||||
batched_loop_idx_list.append(loop_idx_list)
|
||||
|
||||
# 1. refine prompt by openai
|
||||
# if llm_refine:
|
||||
# only call openai API when
|
||||
# 1. seq parallel is not enabled
|
||||
# 2. seq parallel is enabled and the process is rank 0
|
||||
# if not enable_sequence_parallelism or (enable_sequence_parallelism and coordinator.is_master()):
|
||||
# for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
|
||||
# batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list)
|
||||
|
||||
# # sync the prompt if using seq parallel
|
||||
# if enable_sequence_parallelism:
|
||||
# coordinator.block_all()
|
||||
# prompt_segment_length = [
|
||||
# len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list
|
||||
# ]
|
||||
|
||||
# # flatten the prompt segment list
|
||||
# batched_prompt_segment_list = [
|
||||
# prompt_segment
|
||||
# for prompt_segment_list in batched_prompt_segment_list
|
||||
# for prompt_segment in prompt_segment_list
|
||||
# ]
|
||||
|
||||
# # create a list of size equal to world size
|
||||
# broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size
|
||||
# dist.broadcast_object_list(broadcast_obj_list, 0)
|
||||
|
||||
# # recover the prompt list
|
||||
# batched_prompt_segment_list = []
|
||||
# segment_start_idx = 0
|
||||
# all_prompts = broadcast_obj_list[0]
|
||||
# for num_segment in prompt_segment_length:
|
||||
# batched_prompt_segment_list.append(
|
||||
# all_prompts[segment_start_idx : segment_start_idx + num_segment]
|
||||
# )
|
||||
# segment_start_idx += num_segment
|
||||
|
||||
# 2. append score
|
||||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
|
||||
batched_prompt_segment_list[idx] = append_score_to_prompts(
|
||||
prompt_segment_list,
|
||||
aes=aes,
|
||||
flow=flow,
|
||||
camera_motion=camera_motion,
|
||||
)
|
||||
|
||||
# 3. clean prompt with T5
|
||||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
|
||||
batched_prompt_segment_list[idx] = [self.text_preprocessing(prompt) for prompt in prompt_segment_list]
|
||||
|
||||
# 4. merge to obtain the final prompt
|
||||
batch_prompts = []
|
||||
for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list):
|
||||
batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list))
|
||||
|
||||
# == Iter over loop generation ==
|
||||
video_clips = []
|
||||
for loop_i in range(loop):
|
||||
# == get prompt for loop i ==
|
||||
batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)
|
||||
|
||||
# == add condition frames for loop ==
|
||||
if loop_i > 0:
|
||||
refs, ms = append_generated(
|
||||
self.vae, video_clips[-1], refs, ms, loop_i, condition_frame_length, condition_frame_edit
|
||||
)
|
||||
|
||||
# == sampling ==
|
||||
input_size = (num_frames, *image_size)
|
||||
latent_size = self.vae.get_latent_size(input_size)
|
||||
z = torch.randn(
|
||||
len(batch_prompts), self.vae.out_channels, *latent_size, device=self._device, dtype=self._dtype
|
||||
)
|
||||
model_args.update(self.encode_prompt(batch_prompts_loop))
|
||||
y_null = self.null_embed(len(batch_prompts_loop))
|
||||
|
||||
masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)
|
||||
samples = self.scheduler.sample(
|
||||
self.transformer,
|
||||
z=z,
|
||||
model_args=model_args,
|
||||
y_null=y_null,
|
||||
device=self._device,
|
||||
progress=verbose,
|
||||
mask=masks,
|
||||
)
|
||||
samples = self.vae.decode(samples.to(self._dtype), num_frames=num_frames)
|
||||
video_clips.append(samples)
|
||||
|
||||
for i in range(1, loop):
|
||||
video_clips[i] = video_clips[i][:, dframe_to_frame(condition_frame_length) :]
|
||||
video = torch.cat(video_clips, dim=1)
|
||||
|
||||
low, high = -1, 1
|
||||
video.clamp_(min=low, max=high)
|
||||
video.sub_(low).div_(max(high - low, 1e-5))
|
||||
video = video.mul(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 4, 1).to("cpu", torch.uint8)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return VideoSysPipelineOutput(video=video)
|
||||
|
||||
def save_video(self, video, output_path):
|
||||
save_video(video, output_path, fps=24)
|
||||
|
||||
|
||||
def load_prompts(prompt_path, start_idx=None, end_idx=None):
|
||||
with open(prompt_path, "r") as f:
|
||||
prompts = [line.strip() for line in f.readlines()]
|
||||
prompts = prompts[start_idx:end_idx]
|
||||
return prompts
|
||||
|
||||
|
||||
def get_save_path_name(
|
||||
save_dir,
|
||||
sample_name=None, # prefix
|
||||
sample_idx=None, # sample index
|
||||
prompt=None, # used prompt
|
||||
prompt_as_path=False, # use prompt as path
|
||||
num_sample=1, # number of samples to generate for one prompt
|
||||
k=None, # kth sample
|
||||
):
|
||||
if sample_name is None:
|
||||
sample_name = "" if prompt_as_path else "sample"
|
||||
sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx:04d}"
|
||||
save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix[:50]}")
|
||||
if num_sample != 1:
|
||||
save_path = f"{save_path}-{k}"
|
||||
return save_path
|
||||
|
||||
|
||||
def get_eval_save_path_name(
|
||||
save_dir,
|
||||
id, # add id parameter
|
||||
sample_name=None, # prefix
|
||||
sample_idx=None, # sample index
|
||||
prompt=None, # used prompt
|
||||
prompt_as_path=False, # use prompt as path
|
||||
num_sample=1, # number of samples to generate for one prompt
|
||||
k=None, # kth sample
|
||||
):
|
||||
if sample_name is None:
|
||||
sample_name = "" if prompt_as_path else "sample"
|
||||
save_path = os.path.join(save_dir, f"{id}")
|
||||
if num_sample != 1:
|
||||
save_path = f"{save_path}-{k}"
|
||||
return save_path
|
||||
|
||||
|
||||
def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None):
|
||||
new_prompts = []
|
||||
for prompt in prompts:
|
||||
new_prompt = prompt
|
||||
if aes is not None and "aesthetic score:" not in prompt:
|
||||
new_prompt = f"{new_prompt} aesthetic score: {aes:.1f}."
|
||||
if flow is not None and "motion score:" not in prompt:
|
||||
new_prompt = f"{new_prompt} motion score: {flow:.1f}."
|
||||
if camera_motion is not None and "camera motion:" not in prompt:
|
||||
new_prompt = f"{new_prompt} camera motion: {camera_motion}."
|
||||
new_prompts.append(new_prompt)
|
||||
return new_prompts
|
||||
|
||||
|
||||
def extract_json_from_prompts(prompts, reference, mask_strategy):
|
||||
ret_prompts = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
parts = re.split(r"(?=[{])", prompt)
|
||||
assert len(parts) <= 2, f"Invalid prompt: {prompt}"
|
||||
ret_prompts.append(parts[0])
|
||||
if len(parts) > 1:
|
||||
additional_info = json.loads(parts[1])
|
||||
for key in additional_info:
|
||||
assert key in ["reference_path", "mask_strategy"], f"Invalid key: {key}"
|
||||
if key == "reference_path":
|
||||
reference[i] = additional_info[key]
|
||||
elif key == "mask_strategy":
|
||||
mask_strategy[i] = additional_info[key]
|
||||
return ret_prompts, reference, mask_strategy
|
||||
|
||||
|
||||
def collect_references_batch(reference_paths, vae, image_size):
|
||||
refs_x = [] # refs_x: [batch, ref_num, C, T, H, W]
|
||||
for reference_path in reference_paths:
|
||||
if reference_path == "":
|
||||
refs_x.append([])
|
||||
continue
|
||||
ref_path = reference_path.split(";")
|
||||
ref = []
|
||||
for r_path in ref_path:
|
||||
r = read_from_path(r_path, image_size, transform_name="resize_crop")
|
||||
r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
|
||||
r_x = r_x.squeeze(0)
|
||||
ref.append(r_x)
|
||||
refs_x.append(ref)
|
||||
return refs_x
|
||||
|
||||
|
||||
def extract_prompts_loop(prompts, num_loop):
|
||||
ret_prompts = []
|
||||
for prompt in prompts:
|
||||
if prompt.startswith("|0|"):
|
||||
prompt_list = prompt.split("|")[1:]
|
||||
text_list = []
|
||||
for i in range(0, len(prompt_list), 2):
|
||||
start_loop = int(prompt_list[i])
|
||||
text = prompt_list[i + 1]
|
||||
end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop + 1
|
||||
text_list.extend([text] * (end_loop - start_loop))
|
||||
prompt = text_list[num_loop]
|
||||
ret_prompts.append(prompt)
|
||||
return ret_prompts
|
||||
|
||||
|
||||
def split_prompt(prompt_text):
|
||||
if prompt_text.startswith("|0|"):
|
||||
# this is for prompts which look like
|
||||
# |0| a beautiful day |1| a sunny day |2| a rainy day
|
||||
# we want to parse it into a list of prompts with the loop index
|
||||
prompt_list = prompt_text.split("|")[1:]
|
||||
text_list = []
|
||||
loop_idx = []
|
||||
for i in range(0, len(prompt_list), 2):
|
||||
start_loop = int(prompt_list[i])
|
||||
text = prompt_list[i + 1].strip()
|
||||
text_list.append(text)
|
||||
loop_idx.append(start_loop)
|
||||
return text_list, loop_idx
|
||||
else:
|
||||
return [prompt_text], None
|
||||
|
||||
|
||||
def merge_prompt(text_list, loop_idx_list=None):
|
||||
if loop_idx_list is None:
|
||||
return text_list[0]
|
||||
else:
|
||||
prompt = ""
|
||||
for i, text in enumerate(text_list):
|
||||
prompt += f"|{loop_idx_list[i]}|{text}"
|
||||
return prompt
|
||||
|
||||
|
||||
MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]
|
||||
|
||||
|
||||
def parse_mask_strategy(mask_strategy):
|
||||
mask_batch = []
|
||||
if mask_strategy == "" or mask_strategy is None:
|
||||
return mask_batch
|
||||
|
||||
mask_strategy = mask_strategy.split(";")
|
||||
for mask in mask_strategy:
|
||||
mask_group = mask.split(",")
|
||||
num_group = len(mask_group)
|
||||
assert num_group >= 1 and num_group <= 6, f"Invalid mask strategy: {mask}"
|
||||
mask_group.extend(MASK_DEFAULT[num_group:])
|
||||
for i in range(5):
|
||||
mask_group[i] = int(mask_group[i])
|
||||
mask_group[5] = float(mask_group[5])
|
||||
mask_batch.append(mask_group)
|
||||
return mask_batch
|
||||
|
||||
|
||||
def find_nearest_point(value, point, max_value):
|
||||
t = value // point
|
||||
if value % point > point / 2 and t < max_value // point - 1:
|
||||
t += 1
|
||||
return t * point
|
||||
|
||||
|
||||
def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
|
||||
masks = []
|
||||
no_mask = True
|
||||
for i, mask_strategy in enumerate(mask_strategys):
|
||||
no_mask = False
|
||||
mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
|
||||
mask_strategy = parse_mask_strategy(mask_strategy)
|
||||
for mst in mask_strategy:
|
||||
loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst
|
||||
if loop_id != loop_i:
|
||||
continue
|
||||
ref = refs_x[i][m_id]
|
||||
|
||||
if m_ref_start < 0:
|
||||
# ref: [C, T, H, W]
|
||||
m_ref_start = ref.shape[1] + m_ref_start
|
||||
if m_target_start < 0:
|
||||
# z: [B, C, T, H, W]
|
||||
m_target_start = z.shape[2] + m_target_start
|
||||
if align is not None:
|
||||
m_ref_start = find_nearest_point(m_ref_start, align, ref.shape[1])
|
||||
m_target_start = find_nearest_point(m_target_start, align, z.shape[2])
|
||||
m_length = min(m_length, z.shape[2] - m_target_start, ref.shape[1] - m_ref_start)
|
||||
z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
|
||||
mask[m_target_start : m_target_start + m_length] = edit_ratio
|
||||
masks.append(mask)
|
||||
if no_mask:
|
||||
return None
|
||||
masks = torch.stack(masks)
|
||||
return masks
|
||||
|
||||
|
||||
def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i, condition_frame_length, condition_frame_edit):
|
||||
ref_x = vae.encode(generated_video)
|
||||
for j, refs in enumerate(refs_x):
|
||||
if refs is None:
|
||||
refs_x[j] = [ref_x[j]]
|
||||
else:
|
||||
refs.append(ref_x[j])
|
||||
if mask_strategy[j] is None or mask_strategy[j] == "":
|
||||
mask_strategy[j] = ""
|
||||
else:
|
||||
mask_strategy[j] += ";"
|
||||
mask_strategy[
|
||||
j
|
||||
] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}"
|
||||
return refs_x, mask_strategy
|
||||
|
||||
|
||||
def dframe_to_frame(num):
|
||||
assert num % 5 == 0, f"Invalid num: {num}"
|
||||
return num // 5 * 17
|
||||
|
||||
|
||||
OPENAI_CLIENT = None
|
||||
REFINE_PROMPTS = None
|
||||
REFINE_PROMPTS_PATH = "assets/texts/t2v_pllava.txt"
|
||||
REFINE_PROMPTS_TEMPLATE = """
|
||||
You need to refine user's input prompt. The user's input prompt is used for video generation task. You need to refine the user's prompt to make it more suitable for the task. Here are some examples of refined prompts:
|
||||
{}
|
||||
|
||||
The refined prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The refined prompt should be in English.
|
||||
"""
|
||||
RANDOM_PROMPTS = None
|
||||
RANDOM_PROMPTS_TEMPLATE = """
|
||||
You need to generate one input prompt for video generation task. The prompt should be suitable for the task. Here are some examples of refined prompts:
|
||||
{}
|
||||
|
||||
The prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The prompt should be in English.
|
||||
"""
|
||||
|
||||
|
||||
def get_openai_response(sys_prompt, usr_prompt, model="gpt-4o"):
|
||||
global OPENAI_CLIENT
|
||||
if OPENAI_CLIENT is None:
|
||||
from openai import OpenAI
|
||||
|
||||
OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
completion = OPENAI_CLIENT.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": sys_prompt,
|
||||
}, # <-- This is the system message that provides context to the model
|
||||
{
|
||||
"role": "user",
|
||||
"content": usr_prompt,
|
||||
}, # <-- This is the user message for which the model will generate a response
|
||||
],
|
||||
)
|
||||
|
||||
return completion.choices[0].message.content
|
||||
|
||||
|
||||
def get_random_prompt_by_openai():
|
||||
global RANDOM_PROMPTS
|
||||
if RANDOM_PROMPTS is None:
|
||||
examples = load_prompts(REFINE_PROMPTS_PATH)
|
||||
RANDOM_PROMPTS = RANDOM_PROMPTS_TEMPLATE.format("\n".join(examples))
|
||||
|
||||
response = get_openai_response(RANDOM_PROMPTS, "Generate one example.")
|
||||
return response
|
||||
|
||||
|
||||
def refine_prompt_by_openai(prompt):
|
||||
global REFINE_PROMPTS
|
||||
if REFINE_PROMPTS is None:
|
||||
examples = load_prompts(REFINE_PROMPTS_PATH)
|
||||
REFINE_PROMPTS = REFINE_PROMPTS_TEMPLATE.format("\n".join(examples))
|
||||
|
||||
response = get_openai_response(REFINE_PROMPTS, prompt)
|
||||
return response
|
||||
|
||||
|
||||
def has_openai_key():
|
||||
return "OPENAI_API_KEY" in os.environ
|
||||
|
||||
|
||||
def refine_prompts_by_openai(prompts):
|
||||
new_prompts = []
|
||||
for prompt in prompts:
|
||||
try:
|
||||
if prompt.strip() == "":
|
||||
new_prompt = get_random_prompt_by_openai()
|
||||
print(f"[Info] Empty prompt detected, generate random prompt: {new_prompt}")
|
||||
else:
|
||||
new_prompt = refine_prompt_by_openai(prompt)
|
||||
print(f"[Info] Refine prompt: {prompt} -> {new_prompt}")
|
||||
new_prompts.append(new_prompt)
|
||||
except Exception as e:
|
||||
print(f"[Warning] Failed to refine prompt: {prompt} due to {e}")
|
||||
new_prompts.append(prompt)
|
||||
return new_prompts
|
||||
|
||||
|
||||
def add_watermark(
|
||||
input_video_path, watermark_image_path="./assets/images/watermark/watermark.png", output_video_path=None
|
||||
):
|
||||
# execute this command in terminal with subprocess
|
||||
# return if the process is successful
|
||||
if output_video_path is None:
|
||||
output_video_path = input_video_path.replace(".mp4", "_watermark.mp4")
|
||||
cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path} -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}'
|
||||
exit_code = os.system(cmd)
|
||||
is_success = exit_code == 0
|
||||
return is_success
|
||||
@ -1,3 +0,0 @@
|
||||
from .pipeline_open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
|
||||
|
||||
__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanPABConfig"]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user