mirror of https://git.datalinker.icu/ali-vilab/TeaCache
synced 2025-12-08 20:34:24 +08:00

support cogvideox

This commit is contained in:
parent 21ea48e71a
commit 30bf3cba88

README.md (20 lines changed)
@@ -53,6 +53,11 @@

 ## Latest News 🔥

+- [2024/12/19] 🔥 Support [CogVideoX](https://github.com/THUDM/CogVideo).
 - [2024/12/06] 🎉 Release the [code](https://github.com/LiewFeng/TeaCache) of TeaCache. Support [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) and [Latte](https://github.com/Vchitect/Latte).
 - [2024/11/28] 🎉 Release the [paper](https://arxiv.org/abs/2411.19108) of TeaCache.

 ## Introduction

 We introduce Timestep Embedding Aware Cache (TeaCache), a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps. For more details and visual results, please visit our [project page](https://github.com/LiewFeng/TeaCache).
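The rule this commit implements for CogVideoX is easiest to see in isolation. Below is a minimal, self-contained sketch of the caching decision made by the `teacache_forward` added later in this commit; the CogVideoX-2B polynomial coefficients and the `rel_l1_thresh` name come from that code, while the function name and arguments here are illustrative:

```python
import numpy as np

# CogVideoX-2B rescaling polynomial (from this commit): maps the relative L1
# change of the timestep-modulated input to an estimated output change.
rescale = np.poly1d([1.42842830e05, -3.99193393e04, 3.85937428e03, -1.49458838e02, 2.04751119e00])

def should_compute(modulated, previous, acc, thresh=0.1, first_or_last=False):
    """Return (run_full_forward, new_accumulator) for one denoising step."""
    if first_or_last or previous is None:
        return True, 0.0  # first/last timesteps are always computed in full
    rel_l1 = np.abs(modulated - previous).mean() / np.abs(previous).mean()
    acc += rescale(rel_l1)
    if acc < thresh:
        return False, acc  # small accumulated change: reuse cached residuals
    return True, 0.0       # budget exceeded: recompute and reset
```

When a step is skipped, the residual cached at the last full pass is simply added back onto the hidden states; `rel_l1_thresh` trades speed for fidelity (0.1 in the "slow" setting below, 0.2 in the "fast" one).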
@@ -92,6 +97,7 @@ cd eval/teacache
 python experiments/latte.py
 python experiments/opensora.py
 python experiments/open_sora_plan.py
+python experiments/cogvideox.py
 ```

 2. Calculate Vbench score
@@ -116,19 +122,17 @@ python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb

 ## Citation

 If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.

 ```
-@misc{liu2024timestep,
+@article{liu2024timestep,
   title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
-  author={Feng Liu and Shiwei Zhang and Xiaofeng Wang and Yujie Wei and Haonan Qiu and Yuzhong Zhao and Yingya Zhang and Qixiang Ye and Fang Wan},
-  year={2024},
-  eprint={2411.19108},
-  archivePrefix={arXiv},
-  primaryClass={cs.CV},
-  url={https://arxiv.org/abs/2411.19108}
+  author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
+  journal={arXiv preprint arXiv:2411.19108},
+  year={2024}
 }
 ```

 ## Acknowledgement

-This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys). Thanks for their contributions!
+This repository is built on top of [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte) and [CogVideoX](https://github.com/THUDM/CogVideo). Thanks for their contributions!
eval/teacache/experiments/cogvideox.py (new file, 234 lines)

@@ -0,0 +1,234 @@
```python
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.utils import is_torch_version
from einops import rearrange, repeat

from utils import generate_func, read_prompt_list
from videosys import CogVideoXConfig, VideoSysEngine
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput
from videosys.utils.utils import batch_func

def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    return_dict: bool = True,
    all_timesteps=None,
):
    # Split the batch across the context-parallel group when enabled.
    if self.parallel_manager.cp_size > 1:
        (
            hidden_states,
            encoder_hidden_states,
            timestep,
            timestep_cond,
            image_rotary_emb,
        ) = batch_func(
            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
            hidden_states,
            encoder_hidden_states,
            timestep,
            timestep_cond,
            image_rotary_emb,
        )

    batch_size, num_frames, channels, height, width = hidden_states.shape

    # 1. Time embedding
    timesteps = timestep
    org_timestep = timestep
    t_emb = self.time_proj(timesteps)

    # `timesteps` does not contain any weights and will always return f32 tensors,
    # but time_embedding might actually be running in fp16, so we need to cast here.
    # There might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    # 2. Patch embedding
    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)

    # 3. Position embedding
    text_seq_length = encoder_hidden_states.shape[1]
    if not self.config.use_rotary_positional_embeddings:
        seq_length = height * width * num_frames // (self.config.patch_size**2)

        pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
        hidden_states = hidden_states + pos_embeds
        hidden_states = self.embedding_dropout(hidden_states)

    encoder_hidden_states = hidden_states[:, :text_seq_length]
    hidden_states = hidden_states[:, text_seq_length:]

    if self.enable_teacache:
        # Probe how much the timestep-modulated input changed since the last step.
        inp = hidden_states.clone()
        encoder_hidden_states_ = encoder_hidden_states.clone()
        emb_ = emb.clone()
        _, modulated_inp, _, _ = self.transformer_blocks[0].norm1(inp, encoder_hidden_states_, emb_)
        if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
            # Always compute the first and last timesteps in full.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                coefficients = [1.42842830e05, -3.99193393e04, 3.85937428e03, -1.49458838e02, 2.04751119e00]
            else:
                # CogVideoX-5B
                coefficients = [1.80221813e05, -5.37021537e04, 5.61853221e03, -2.44280388e02, 3.83458338e00]
            # Polynomial rescaling maps the relative L1 change of the modulated
            # input to an estimate of the change in model output.
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean()
                 / self.previous_modulated_input.abs().mean()).cpu().item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
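    # Worked example (illustrative numbers, not from the paper): with
    # rel_l1_thresh = 0.1 and a rescaled per-step change of about 0.04, two
    # consecutive steps are served from cache (0.04, then 0.08 < 0.1); the third
    # crosses the threshold, forcing a full forward pass and resetting the budget.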

    if self.enable_teacache:
        if not should_calc:
            # Cache hit: reuse the residuals saved at the last fully computed step.
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            if self.parallel_manager.sp_size > 1:
                set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
                hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
            ori_hidden_states = hidden_states.clone()
            ori_encoder_hidden_states = encoder_hidden_states.clone()
            # 4. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        emb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                        timestep=None,  # was `timesteps if False else None`, which always evaluates to None
                    )
            # Save the residuals so subsequent cached steps can replay them.
            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
    else:
        if self.parallel_manager.sp_size > 1:
            set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
            hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))

        # 4. Transformer blocks (same as the branch above, without residual caching)
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    timestep=None,  # was `timesteps if False else None`, which always evaluates to None
                )

    if self.parallel_manager.sp_size > 1:
        if self.enable_teacache:
            if should_calc:
                hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
                self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
        else:
            hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))

    if not self.config.use_rotary_positional_embeddings:
        # CogVideoX-2B
        hidden_states = self.norm_final(hidden_states)
    else:
        # CogVideoX-5B
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]

    # 5. Final block
    hidden_states = self.norm_out(hidden_states, temb=emb)
    hidden_states = self.proj_out(hidden_states)

    # 6. Unpatchify
    p = self.config.patch_size
    output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
    output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

    if self.parallel_manager.cp_size > 1:
        output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)

def eval_teacache_slow(prompt_list):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    # Attach TeaCache state to the transformer and swap in the caching forward.
    engine.driver_worker.transformer.enable_teacache = True
    engine.driver_worker.transformer.rel_l1_thresh = 0.1
    engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.previous_modulated_input = None
    engine.driver_worker.transformer.previous_residual = None
    engine.driver_worker.transformer.previous_residual_encoder = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/cogvideox_teacache_slow", loop=5)


def eval_teacache_fast(prompt_list):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.enable_teacache = True
    # Looser threshold than the "slow" setting: skips more steps for more speedup.
    engine.driver_worker.transformer.rel_l1_thresh = 0.2
    engine.driver_worker.transformer.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.previous_modulated_input = None
    engine.driver_worker.transformer.previous_residual = None
    engine.driver_worker.transformer.previous_residual_encoder = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/cogvideox_teacache_fast", loop=5)


def eval_base(prompt_list):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/cogvideox_base", loop=5)


if __name__ == "__main__":
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    eval_base(prompt_list)
    eval_teacache_slow(prompt_list)
    eval_teacache_fast(prompt_list)
```
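Two details of the integration are worth noting: the TeaCache state lives as plain attributes on the transformer instance, and the patched forward is installed at class level (`__class__.forward = teacache_forward`), so no VideoSys sources need to be edited. A quick smoke test can reuse the same helpers; the shortened run below is hypothetical and assumes `read_prompt_list` returns a plain list:

```python
# Hypothetical quick check: TeaCache "slow" setting on the first few prompts only.
prompt_list = read_prompt_list("vbench/VBench_full_info.json")[:5]
eval_teacache_slow(prompt_list)
```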

@@ -6,6 +6,9 @@ from einops import rearrange, repeat
 import numpy as np
 from typing import Any, Dict, Optional, Tuple
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
 from videosys.models.transformers.open_sora_plan_v110_transformer_3d import Transformer3DModelOutput
+from videosys.utils.utils import batch_func
+from functools import partial


 def teacache_forward(
     self,

@@ -10,8 +10,7 @@ def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict
     kwargs["verbose"] = False
     for prompt in tqdm.tqdm(prompt_list):
         for l in range(loop):
-            set_seed(l)
-            video = pipeline.generate(prompt, **kwargs).video[0]
+            video = pipeline.generate(prompt, seed=l, **kwargs).video[0]
             pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))

@@ -484,6 +484,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         return_dict: bool = True,
+        all_timesteps=None,
     ):
         if self.parallel_manager.cp_size > 1:
             (
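Threading `all_timesteps` through the transformer forward is what lets `teacache_forward` recognize the first and last denoising steps (`org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]`), which are always computed in full rather than served from cache.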

@@ -28,6 +28,7 @@ from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
 from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler
 from videosys.utils.logging import logger
 from videosys.utils.utils import save_video, set_seed
+import tqdm


 class CogVideoXPABConfig(PABConfig):

@@ -511,6 +512,7 @@ class CogVideoXPipeline(VideoSysPipeline):
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 226,
+        verbose=True,
     ) -> Union[VideoSysPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.

@@ -675,10 +677,11 @@ class CogVideoXPipeline(VideoSysPipeline):
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        # with self.progress_bar(total=num_inference_steps) as progress_bar:
+        progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
         # for DPM-solver++
         old_pred_original_sample = None
-        for i, t in enumerate(timesteps):
+        for i, t in progress_wrap(list(enumerate(timesteps))):
             if self.interrupt:
                 continue
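This swaps diffusers' managed progress bar for a rank-gated tqdm wrap, so only one bar is drawn under multi-GPU inference; `list(enumerate(timesteps))` materializes the iterator so tqdm can show a total. A standalone sketch of the pattern, assuming `torch.distributed` has been initialized:

```python
import torch.distributed as dist
import tqdm

def progress_wrap(iterable, verbose=True):
    # Draw a bar on rank 0 only; all other ranks iterate silently.
    wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
    return wrap(list(iterable))
```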

@@ -693,6 +696,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
+                    all_timesteps=timesteps,
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
                 )[0]

@@ -733,8 +737,8 @@ class CogVideoXPipeline(VideoSysPipeline):
                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                 negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                progress_bar.update()
+            # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+            #     progress_bar.update()

         if not output_type == "latent":
             video = self.decode_latents(latents)