From 30bf3cba889415a153132de26b35ff74f00687f9 Mon Sep 17 00:00:00 2001
From: LiewFeng <1361871897@qq.com>
Date: Thu, 19 Dec 2024 13:03:56 +0800
Subject: [PATCH] support cogvideox

---
 README.md                                      |  22 +-
 eval/teacache/experiments/cogvideox.py         | 234 ++++++++++++++++++
 eval/teacache/experiments/opensora_plan.py     |   3 +
 eval/teacache/experiments/utils.py             |   3 +-
 .../transformers/cogvideox_transformer_3d.py   |   1 +
 .../pipelines/cogvideox/pipeline_cogvideox.py  |  16 +-
 6 files changed, 262 insertions(+), 17 deletions(-)
 create mode 100644 eval/teacache/experiments/cogvideox.py

diff --git a/README.md b/README.md
index 8b0c335..a02f7f9 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,11 @@
 
 ![visualization](./assets/tisser.png)
 
+## Latest News 🔥
+- [2024/12/19] 🔥 Support [CogVideoX](https://github.com/THUDM/CogVideo).
+- [2024/12/06] 🎉 Release the [code](https://github.com/LiewFeng/TeaCache) of TeaCache. Support [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) and [Latte](https://github.com/Vchitect/Latte).
+- [2024/11/28] 🎉 Release the [paper](https://arxiv.org/abs/2411.19108) of TeaCache.
+
 ## Introduction
 
 We introduce Timestep Embedding Aware Cache (TeaCache), a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps. For more details and visual results, please visit our [project page](https://github.com/LiewFeng/TeaCache).
@@ -92,6 +97,7 @@ cd eval/teacache
 python experiments/latte.py
 python experiments/opensora.py
 python experiments/open_sora_plan.py
+python experiments/cogvideox.py
 ```
 
 2. Calculate Vbench score
@@ -116,19 +122,17 @@ python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
 
 ## Citation
 
+If you find TeaCache useful for your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
 ```
-@misc{liu2024timestep,
-      title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
-      author={Feng Liu and Shiwei Zhang and Xiaofeng Wang and Yujie Wei and Haonan Qiu and Yuzhong Zhao and Yingya Zhang and Qixiang Ye and Fang Wan},
-      year={2024},
-      eprint={2411.19108},
-      archivePrefix={arXiv},
-      primaryClass={cs.CV},
-      url={https://arxiv.org/abs/2411.19108}
+@article{liu2024timestep,
+  title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
+  author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
+  journal={arXiv preprint arXiv:2411.19108},
+  year={2024}
 }
 ```
 
 ## Acknowledgement
 
-This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys). Thanks for their contributions!
+This repository is built upon [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte) and [CogVideoX](https://github.com/THUDM/CogVideo). Thanks for their contributions!
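Editorial note: the new evaluation script added below drives TeaCache for CogVideoX through VideoSys. The following is a minimal sketch (not part of the patch) of the same setup, assuming `teacache_forward` as defined in `eval/teacache/experiments/cogvideox.py` below is in scope; the prompt and output path are illustrative only.

```python
# Minimal sketch: enable TeaCache on the CogVideoX transformer via VideoSys.
# Mirrors eval_teacache_slow/eval_teacache_fast in the new eval script below.
from videosys import CogVideoXConfig, VideoSysEngine

engine = VideoSysEngine(CogVideoXConfig())
transformer = engine.driver_worker.transformer

# TeaCache state consumed by the patched forward (teacache_forward) below.
transformer.enable_teacache = True
transformer.rel_l1_thresh = 0.1                 # 0.1 = "slow" (higher quality), 0.2 = "fast"
transformer.accumulated_rel_l1_distance = 0
transformer.previous_modulated_input = None
transformer.previous_residual = None
transformer.previous_residual_encoder = None
transformer.__class__.forward = teacache_forward  # defined in eval/teacache/experiments/cogvideox.py

# Illustrative prompt and output path, following the eval utilities' generate/save calls.
video = engine.generate("a majestic lion walking through tall grass", seed=0).video[0]
engine.save_video(video, "./samples/cogvideox_teacache_demo.mp4")
```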
diff --git a/eval/teacache/experiments/cogvideox.py b/eval/teacache/experiments/cogvideox.py
new file mode 100644
index 0000000..a22a13e
--- /dev/null
+++ b/eval/teacache/experiments/cogvideox.py
@@ -0,0 +1,234 @@
+from utils import generate_func, read_prompt_list
+from videosys import CogVideoXConfig, VideoSysEngine
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+import numpy as np
+from typing import Any, Dict, Optional, Tuple, Union
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
+from videosys.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput
+from videosys.utils.utils import batch_func
+from functools import partial
+from diffusers.utils import is_torch_version
+
+def teacache_forward(
+    self,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    timestep: Union[int, float, torch.LongTensor],
+    timestep_cond: Optional[torch.Tensor] = None,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    return_dict: bool = True,
+    all_timesteps=None
+):
+    if self.parallel_manager.cp_size > 1:
+        (
+            hidden_states,
+            encoder_hidden_states,
+            timestep,
+            timestep_cond,
+            image_rotary_emb,
+        ) = batch_func(
+            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
+            hidden_states,
+            encoder_hidden_states,
+            timestep,
+            timestep_cond,
+            image_rotary_emb,
+        )
+
+    batch_size, num_frames, channels, height, width = hidden_states.shape
+
+    # 1. Time embedding
+    timesteps = timestep
+    org_timestep = timestep
+    t_emb = self.time_proj(timesteps)
+
+    # timesteps does not contain any weights and will always return f32 tensors
+    # but time_embedding might actually be running in fp16. so we need to cast here.
+    # there might be better ways to encapsulate this.
+    t_emb = t_emb.to(dtype=hidden_states.dtype)
+    emb = self.time_embedding(t_emb, timestep_cond)
+
+    # 2. Patch embedding
+    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+
+    # 3. Position embedding
+    text_seq_length = encoder_hidden_states.shape[1]
+    if not self.config.use_rotary_positional_embeddings:
+        seq_length = height * width * num_frames // (self.config.patch_size**2)
+
+        pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
+        hidden_states = hidden_states + pos_embeds
+        hidden_states = self.embedding_dropout(hidden_states)
+
+    encoder_hidden_states = hidden_states[:, :text_seq_length]
+    hidden_states = hidden_states[:, text_seq_length:]
+
+    if self.enable_teacache:
+        inp = hidden_states.clone()
+        encoder_hidden_states_ = encoder_hidden_states.clone()
+        emb_ = emb.clone()
+        _, modulated_inp, _, _ = self.transformer_blocks[0].norm1(inp, encoder_hidden_states_, emb_)
+        if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
+            should_calc = True
+            self.accumulated_rel_l1_distance = 0
+        else:
+            if not self.config.use_rotary_positional_embeddings:
+                # CogVideoX-2B
+                coefficients = [1.42842830e+05, -3.99193393e+04, 3.85937428e+03, -1.49458838e+02, 2.04751119e+00]
+            else:
+                # CogVideoX-5B
+                coefficients = [1.80221813e+05, -5.37021537e+04, 5.61853221e+03, -2.44280388e+02, 3.83458338e+00]
+            rescale_func = np.poly1d(coefficients)
+            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                should_calc = False
+            else:
+                should_calc = True
+                self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = modulated_inp
+
+    if self.enable_teacache:
+        if not should_calc:
+            hidden_states += self.previous_residual
+            encoder_hidden_states += self.previous_residual_encoder
+        else:
+            if self.parallel_manager.sp_size > 1:
+                set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
+                hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
+            ori_hidden_states = hidden_states.clone()
+            ori_encoder_hidden_states = encoder_hidden_states.clone()
+            # 4. Transformer blocks
+            for i, block in enumerate(self.transformer_blocks):
+                if self.training and self.gradient_checkpointing:
+
+                    def create_custom_forward(module):
+                        def custom_forward(*inputs):
+                            return module(*inputs)
+
+                        return custom_forward
+
+                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        emb,
+                        image_rotary_emb,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    hidden_states, encoder_hidden_states = block(
+                        hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        temb=emb,
+                        image_rotary_emb=image_rotary_emb,
+                        timestep=timesteps if False else None,
+                    )
+            self.previous_residual = hidden_states - ori_hidden_states
+            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
+    else:
+        if self.parallel_manager.sp_size > 1:
+            set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
+            hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
+
+        # 4. Transformer blocks
+        for i, block in enumerate(self.transformer_blocks):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    emb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=emb,
+                    image_rotary_emb=image_rotary_emb,
+                    timestep=timesteps if False else None,
+                )
+
+    if self.parallel_manager.sp_size > 1:
+        if self.enable_teacache:
+            if should_calc:
+                hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
+                self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
+        else:
+            hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
+
+    if not self.config.use_rotary_positional_embeddings:
+        # CogVideoX-2B
+        hidden_states = self.norm_final(hidden_states)
+    else:
+        # CogVideoX-5B
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        hidden_states = self.norm_final(hidden_states)
+        hidden_states = hidden_states[:, text_seq_length:]
+
+    # 5. Final block
+    hidden_states = self.norm_out(hidden_states, temb=emb)
+    hidden_states = self.proj_out(hidden_states)
+
+    # 6. Unpatchify
+    p = self.config.patch_size
+    output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+    output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+    if self.parallel_manager.cp_size > 1:
+        output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
+
+    if not return_dict:
+        return (output,)
+    return Transformer2DModelOutput(sample=output)
+
+
+def eval_teacache_slow(prompt_list):
+    config = CogVideoXConfig()
+    engine = VideoSysEngine(config)
+    engine.driver_worker.transformer.enable_teacache = True
+    engine.driver_worker.transformer.rel_l1_thresh = 0.1
+    engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
+    engine.driver_worker.transformer.previous_modulated_input = None
+    engine.driver_worker.transformer.previous_residual = None
+    engine.driver_worker.transformer.previous_residual_encoder = None
+    engine.driver_worker.transformer.__class__.forward = teacache_forward
+    generate_func(engine, prompt_list, "./samples/cogvideox_teacache_slow", loop=5)
+
+def eval_teacache_fast(prompt_list):
+    config = CogVideoXConfig()
+    engine = VideoSysEngine(config)
+    engine.driver_worker.transformer.enable_teacache = True
+    engine.driver_worker.transformer.rel_l1_thresh = 0.2
+    engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
+    engine.driver_worker.transformer.previous_modulated_input = None
+    engine.driver_worker.transformer.previous_residual = None
+    engine.driver_worker.transformer.previous_residual_encoder = None
+    engine.driver_worker.transformer.__class__.forward = teacache_forward
+    generate_func(engine, prompt_list, "./samples/cogvideox_teacache_fast", loop=5)
+
+
+def eval_base(prompt_list):
+    config = CogVideoXConfig()
+    engine = VideoSysEngine(config)
+    generate_func(engine, prompt_list, "./samples/cogvideox_base", loop=5)
+
+
+if __name__ == "__main__":
+    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
+    eval_base(prompt_list)
+    eval_teacache_slow(prompt_list)
+    eval_teacache_fast(prompt_list)
+    
\ No newline at end of file
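Editorial note: the skip decision inside `teacache_forward` above is easy to lose in the diff. The following is a minimal, self-contained restatement of that decision; `should_recompute` and the `state` dict are illustrative names introduced here, not part of the patch, which stores the same values as attributes on the transformer.

```python
import numpy as np
import torch

# Fitted polynomial from the patch above (CogVideoX-2B; the 5B variant uses its own coefficients).
rescale_func = np.poly1d([1.42842830e+05, -3.99193393e+04, 3.85937428e+03, -1.49458838e+02, 2.04751119e+00])

def should_recompute(modulated_inp: torch.Tensor, state: dict, rel_l1_thresh: float = 0.1) -> bool:
    """Return True if the transformer blocks should run, False if the cached residual can be reused."""
    prev = state["previous_modulated_input"]
    if prev is None:
        # First step: always compute (the patch additionally forces the last step).
        state["accumulated_rel_l1_distance"] = 0.0
        decision = True
    else:
        # Relative L1 change of the timestep-modulated input, rescaled to estimate the output change.
        raw = ((modulated_inp - prev).abs().mean() / prev.abs().mean()).item()
        state["accumulated_rel_l1_distance"] += float(rescale_func(raw))
        decision = state["accumulated_rel_l1_distance"] >= rel_l1_thresh
        if decision:
            state["accumulated_rel_l1_distance"] = 0.0  # recompute and reset the accumulator
    state["previous_modulated_input"] = modulated_inp
    return decision
```

The accumulated estimate keeps growing across skipped steps, so errors cannot pile up indefinitely: once the accumulator crosses `rel_l1_thresh` (0.1 for the "slow" setting, 0.2 for "fast"), the blocks are recomputed and the accumulator resets.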
"__main__": + prompt_list = read_prompt_list("vbench/VBench_full_info.json") + eval_base(prompt_list) + eval_teacache_slow(prompt_list) + eval_teacache_fast(prompt_list) + \ No newline at end of file diff --git a/eval/teacache/experiments/opensora_plan.py b/eval/teacache/experiments/opensora_plan.py index b056687..50b2250 100644 --- a/eval/teacache/experiments/opensora_plan.py +++ b/eval/teacache/experiments/opensora_plan.py @@ -6,6 +6,9 @@ from einops import rearrange, repeat import numpy as np from typing import Any, Dict, Optional, Tuple from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence +from videosys.models.transformers.open_sora_plan_v110_transformer_3d import Transformer3DModelOutput +from videosys.utils.utils import batch_func +from functools import partial def teacache_forward( self, diff --git a/eval/teacache/experiments/utils.py b/eval/teacache/experiments/utils.py index cb52309..01e3813 100644 --- a/eval/teacache/experiments/utils.py +++ b/eval/teacache/experiments/utils.py @@ -10,8 +10,7 @@ def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict kwargs["verbose"] = False for prompt in tqdm.tqdm(prompt_list): for l in range(loop): - set_seed(l) - video = pipeline.generate(prompt, **kwargs).video[0] + video = pipeline.generate(prompt, seed=l, **kwargs).video[0] pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4")) diff --git a/videosys/models/transformers/cogvideox_transformer_3d.py b/videosys/models/transformers/cogvideox_transformer_3d.py index e568e06..fe12063 100644 --- a/videosys/models/transformers/cogvideox_transformer_3d.py +++ b/videosys/models/transformers/cogvideox_transformer_3d.py @@ -484,6 +484,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): timestep_cond: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, + all_timesteps=None ): if self.parallel_manager.cp_size > 1: ( diff --git a/videosys/pipelines/cogvideox/pipeline_cogvideox.py b/videosys/pipelines/cogvideox/pipeline_cogvideox.py index e8bd151..357a976 100644 --- a/videosys/pipelines/cogvideox/pipeline_cogvideox.py +++ b/videosys/pipelines/cogvideox/pipeline_cogvideox.py @@ -28,6 +28,7 @@ from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler from videosys.utils.logging import logger from videosys.utils.utils import save_video, set_seed +import tqdm class CogVideoXPABConfig(PABConfig): @@ -511,6 +512,7 @@ class CogVideoXPipeline(VideoSysPipeline): ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 226, + verbose=True ) -> Union[VideoSysPipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -675,10 +677,11 @@ class CogVideoXPipeline(VideoSysPipeline): # 8. 
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            # for DPM-solver++
-            old_pred_original_sample = None
-            for i, t in enumerate(timesteps):
+        # with self.progress_bar(total=num_inference_steps) as progress_bar:
+        progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
+        # for DPM-solver++
+        old_pred_original_sample = None
+        for i, t in progress_wrap(list(enumerate(timesteps))):
                 if self.interrupt:
                     continue
 
@@ -693,6 +696,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
+                    all_timesteps=timesteps,
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
                 )[0]
@@ -733,8 +737,8 @@ class CogVideoXPipeline(VideoSysPipeline):
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
 
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                #     progress_bar.update()
 
         if not output_type == "latent":
             video = self.decode_latents(latents)
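Editorial note: the pipeline changes above do two things: they replace the diffusers progress bar with a rank-0 tqdm wrapper gated by the new `verbose` flag, and they forward `timesteps` as `all_timesteps` so the patched `teacache_forward` can force full computation on the first and last steps. A condensed skeleton of that loop, assuming the torch.distributed process group VideoSys sets up; `transformer`, `scheduler`, `timesteps`, `latents` and `prompt_embeds` are stand-ins for the pipeline's own objects, and guidance, DPM-solver handling and callbacks are omitted:

```python
import torch.distributed as dist
import tqdm

def denoising_loop_sketch(transformer, scheduler, timesteps, latents, prompt_embeds, verbose=True):
    # Only rank 0 shows a progress bar; the eval scripts pass verbose=False to silence it.
    progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
    for i, t in progress_wrap(list(enumerate(timesteps))):
        noise_pred = transformer(
            hidden_states=latents,
            encoder_hidden_states=prompt_embeds,
            timestep=t.expand(latents.shape[0]),
            all_timesteps=timesteps,   # new argument consumed by teacache_forward
            return_dict=False,
        )[0]
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latents
```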