Support CogVideoX

LiewFeng 2024-12-19 13:03:56 +08:00
parent 21ea48e71a
commit 30bf3cba88
6 changed files with 262 additions and 17 deletions

View File

@@ -53,6 +53,11 @@
![visualization](./assets/tisser.png)
## Latest News 🔥
- [2024/12/19] 🔥 Support [CogVideoX](https://github.com/THUDM/CogVideo).
- [2024/12/06] 🎉 Release the [code](https://github.com/LiewFeng/TeaCache) of TeaCache. Support [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) and [Latte](https://github.com/Vchitect/Latte).
- [2024/11/28] 🎉 Release the [paper](https://arxiv.org/abs/2411.19108) of TeaCache.
## Introduction
We introduce Timestep Embedding Aware Cache (TeaCache), a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps. For more details and visual results, please visit our [project page](https://github.com/LiewFeng/TeaCache).
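For context, the skip rule TeaCache applies at each denoising step can be sketched as follows. This is a minimal illustration with hypothetical helper names, not the repository's API: it accumulates a polynomial-rescaled relative L1 change of the timestep-modulated input and only recomputes the transformer blocks once that accumulated change crosses a threshold.

```
import numpy as np
import torch

def should_recompute(modulated_inp: torch.Tensor,
                     prev_inp: torch.Tensor,
                     state: dict,
                     rel_l1_thresh: float,
                     coefficients: list) -> bool:
    # Polynomial fitted per model (e.g. different coefficients for CogVideoX-2B vs. 5B).
    rescale = np.poly1d(coefficients)
    # Relative L1 change of the modulated input between consecutive timesteps.
    rel_l1 = ((modulated_inp - prev_inp).abs().mean() / prev_inp.abs().mean()).item()
    state["accumulated"] += float(rescale(rel_l1))
    if state["accumulated"] < rel_l1_thresh:
        return False  # skip: reuse the residual cached at the last computed step
    state["accumulated"] = 0.0  # recompute this step and reset the accumulator
    return True
```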
@@ -92,6 +97,7 @@ cd eval/teacache
python experiments/latte.py
python experiments/opensora.py
python experiments/open_sora_plan.py
python experiments/cogvideox.py
```
2. Calculate VBench score
@@ -116,19 +122,17 @@ python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
## Citation
If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
```
@article{liu2024timestep,
  title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
  author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
  journal={arXiv preprint arXiv:2411.19108},
  year={2024}
}
```
## Acknowledgement
This repository is built upon [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte) and [CogVideoX](https://github.com/THUDM/CogVideo). Thanks for their contributions!

View File

@@ -0,0 +1,234 @@
from utils import generate_func, read_prompt_list
from videosys import CogVideoXConfig, VideoSysEngine
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from typing import Any, Dict, Optional, Tuple, Union
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput
from videosys.utils.utils import batch_func
from functools import partial
from diffusers.utils import is_torch_version
def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    return_dict: bool = True,
    all_timesteps=None,
):
    if self.parallel_manager.cp_size > 1:
        (
            hidden_states,
            encoder_hidden_states,
            timestep,
            timestep_cond,
            image_rotary_emb,
        ) = batch_func(
            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
            hidden_states,
            encoder_hidden_states,
            timestep,
            timestep_cond,
            image_rotary_emb,
        )

    batch_size, num_frames, channels, height, width = hidden_states.shape

    # 1. Time embedding
    timesteps = timestep
    org_timestep = timestep
    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    # 2. Patch embedding
    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)

    # 3. Position embedding
    text_seq_length = encoder_hidden_states.shape[1]
    if not self.config.use_rotary_positional_embeddings:
        seq_length = height * width * num_frames // (self.config.patch_size**2)
        pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
        hidden_states = hidden_states + pos_embeds
        hidden_states = self.embedding_dropout(hidden_states)

    encoder_hidden_states = hidden_states[:, :text_seq_length]
    hidden_states = hidden_states[:, text_seq_length:]
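    # TeaCache decision: compare the timestep-modulated input of the first transformer
    # block against the previous step's, accumulate a polynomial-rescaled relative L1
    # change, and only recompute the transformer blocks once the accumulated change
    # exceeds rel_l1_thresh. The first and last timesteps are always computed.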
    if self.enable_teacache:
        inp = hidden_states.clone()
        encoder_hidden_states_ = encoder_hidden_states.clone()
        emb_ = emb.clone()
        _, modulated_inp, _, _ = self.transformer_blocks[0].norm1(inp, encoder_hidden_states_, emb_)
        if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                coefficients = [1.42842830e+05, -3.99193393e+04, 3.85937428e+03, -1.49458838e+02, 2.04751119e+00]
            else:
                # CogVideoX-5B
                coefficients = [1.80221813e+05, -5.37021537e+04, 5.61853221e+03, -2.44280388e+02, 3.83458338e+00]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
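    # Apply the cache: when this step is skipped, reuse the residuals cached from the
    # last fully computed step; otherwise run all transformer blocks and store the new
    # residuals (output minus input) for both the video and text token streams.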
    if self.enable_teacache:
        if not should_calc:
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            if self.parallel_manager.sp_size > 1:
                set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
                hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
            ori_hidden_states = hidden_states.clone()
            ori_encoder_hidden_states = encoder_hidden_states.clone()

            # 4. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        emb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                        timestep=timesteps if False else None,
                    )
            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
    else:
        if self.parallel_manager.sp_size > 1:
            set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
            hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))

        # 4. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    timestep=timesteps if False else None,
                )

    if self.parallel_manager.sp_size > 1:
        if self.enable_teacache:
            if should_calc:
                hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
                self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
        else:
            hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))

    if not self.config.use_rotary_positional_embeddings:
        # CogVideoX-2B
        hidden_states = self.norm_final(hidden_states)
    else:
        # CogVideoX-5B
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]

    # 5. Final block
    hidden_states = self.norm_out(hidden_states, temb=emb)
    hidden_states = self.proj_out(hidden_states)

    # 6. Unpatchify
    p = self.config.patch_size
    output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
    output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

    if self.parallel_manager.cp_size > 1:
        output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
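# The evaluation entry points below monkey-patch teacache_forward onto the VideoSys
# CogVideoX transformer and initialize the cache state. rel_l1_thresh controls the
# speed/quality trade-off: 0.1 (slow) skips fewer steps than 0.2 (fast).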
def eval_teacache_slow(prompt_list):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.enable_teacache = True
    engine.driver_worker.transformer.rel_l1_thresh = 0.1
    engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.previous_modulated_input = None
    engine.driver_worker.transformer.previous_residual = None
    engine.driver_worker.transformer.previous_residual_encoder = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/cogvideox_teacache_slow", loop=5)


def eval_teacache_fast(prompt_list):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.enable_teacache = True
    engine.driver_worker.transformer.rel_l1_thresh = 0.2
    engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.previous_modulated_input = None
    engine.driver_worker.transformer.previous_residual = None
    engine.driver_worker.transformer.previous_residual_encoder = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/cogvideox_teacache_fast", loop=5)


def eval_base(prompt_list):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/cogvideox_base", loop=5)


if __name__ == "__main__":
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    eval_base(prompt_list)
    eval_teacache_slow(prompt_list)
    eval_teacache_fast(prompt_list)
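Outside the VBench evaluation loop above, the same patching pattern can be used for a single prompt. The following is a hypothetical sketch reusing the definitions from this file; the prompt text and output path are placeholders, not part of the commit.

```
# Hypothetical single-prompt usage of the TeaCache-patched engine defined above.
config = CogVideoXConfig()
engine = VideoSysEngine(config)
engine.driver_worker.transformer.enable_teacache = True
engine.driver_worker.transformer.rel_l1_thresh = 0.2  # 0.1 = slower/higher quality, 0.2 = faster
engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.previous_modulated_input = None
engine.driver_worker.transformer.previous_residual = None
engine.driver_worker.transformer.previous_residual_encoder = None
engine.driver_worker.transformer.__class__.forward = teacache_forward

video = engine.generate("a corgi surfing at sunset", seed=0).video[0]
engine.save_video(video, "./samples/cogvideox_teacache_demo.mp4")
```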

View File

@@ -6,6 +6,9 @@ from einops import rearrange, repeat
import numpy as np
from typing import Any, Dict, Optional, Tuple
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.open_sora_plan_v110_transformer_3d import Transformer3DModelOutput
from videosys.utils.utils import batch_func
from functools import partial
def teacache_forward(
    self,

View File

@@ -10,8 +10,7 @@ def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict
    kwargs["verbose"] = False
    for prompt in tqdm.tqdm(prompt_list):
        for l in range(loop):
            video = pipeline.generate(prompt, seed=l, **kwargs).video[0]
            pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))

View File

@@ -484,6 +484,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_dict: bool = True,
        all_timesteps=None,
    ):
        if self.parallel_manager.cp_size > 1:
            (

View File

@@ -28,6 +28,7 @@ from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler
from videosys.utils.logging import logger
from videosys.utils.utils import save_video, set_seed
import tqdm

class CogVideoXPABConfig(PABConfig):
@@ -511,6 +512,7 @@ class CogVideoXPipeline(VideoSysPipeline):
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
        verbose=True,
    ) -> Union[VideoSysPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
@@ -675,10 +677,11 @@ class CogVideoXPipeline(VideoSysPipeline):
        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # with self.progress_bar(total=num_inference_steps) as progress_bar:
        progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)

        # for DPM-solver++
        old_pred_original_sample = None
        for i, t in progress_wrap(list(enumerate(timesteps))):
            if self.interrupt:
                continue
@@ -693,6 +696,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                all_timesteps=timesteps,
                image_rotary_emb=image_rotary_emb,
                return_dict=False,
            )[0]
@@ -733,8 +737,8 @@ class CogVideoXPipeline(VideoSysPipeline):
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            #     progress_bar.update()

        if not output_type == "latent":
            video = self.decode_latents(latents)