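"""TeaCache evaluation script for CogVideoX (VideoSys engine).

Runs VBench prompts through a baseline CogVideoX engine and through two
TeaCache variants ("slow" and "fast", which differ only in the
`rel_l1_thresh` skip threshold), saving the generated videos under
./samples/.
"""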

from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.utils import is_torch_version

from videosys import CogVideoXConfig, VideoSysEngine
from videosys.core.comm import gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput
from videosys.utils.utils import batch_func

from utils import generate_func, read_prompt_list


def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    return_dict: bool = True,
    all_timesteps=None,
):
    # Split the batched inputs across the context-parallel (cp) group.
    if self.parallel_manager.cp_size > 1:
        (
            hidden_states,
            encoder_hidden_states,
            timestep,
            timestep_cond,
            image_rotary_emb,
        ) = batch_func(
            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
            hidden_states,
            encoder_hidden_states,
            timestep,
            timestep_cond,
            image_rotary_emb,
        )

    batch_size, num_frames, channels, height, width = hidden_states.shape

    # 1. Time embedding
    timesteps = timestep
    org_timestep = timestep
    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    # 2. Patch embedding
    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)

    # 3. Position embedding
    text_seq_length = encoder_hidden_states.shape[1]
    if not self.config.use_rotary_positional_embeddings:
        seq_length = height * width * num_frames // (self.config.patch_size**2)

        pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
        hidden_states = hidden_states + pos_embeds
        hidden_states = self.embedding_dropout(hidden_states)

    encoder_hidden_states = hidden_states[:, :text_seq_length]
    hidden_states = hidden_states[:, text_seq_length:]

    if self.enable_teacache:
        # Always run the full model on the first and last denoising steps;
        # in between, decide based on how much the modulated input (the time
        # embedding) has drifted since the last full computation.
        if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
            else:
                # CogVideoX-5B
                coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
            rescale_func = np.poly1d(coefficients)
            rel_l1 = ((emb - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += rescale_func(rel_l1)
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = emb

    if self.enable_teacache:
        if not should_calc:
            # Skip the transformer blocks: reuse the residual cached at the
            # last full forward pass.
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            if self.parallel_manager.sp_size > 1:
                set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
                hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
            ori_hidden_states = hidden_states.clone()
            ori_encoder_hidden_states = encoder_hidden_states.clone()

            # 4. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        emb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                        timestep=None,
                    )
            # Cache the residual of this full pass so later steps can reuse it.
            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
    else:
        if self.parallel_manager.sp_size > 1:
            set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
            hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))

        # 4. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    timestep=None,
                )

    if self.parallel_manager.sp_size > 1:
        if self.enable_teacache:
            if should_calc:
                hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
                self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
        else:
            hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))

    if not self.config.use_rotary_positional_embeddings:
        # CogVideoX-2B
        hidden_states = self.norm_final(hidden_states)
    else:
        # CogVideoX-5B
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]

    # 5. Final block
    hidden_states = self.norm_out(hidden_states, temb=emb)
    hidden_states = self.proj_out(hidden_states)

    # 6. Unpatchify
    p = self.config.patch_size
    output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
    output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

    if self.parallel_manager.cp_size > 1:
        output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
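

# The skip/recompute rule in teacache_forward can be read in isolation: the
# relative L1 change of the modulated time embedding is rescaled by a
# model-specific polynomial and accumulated, and the transformer blocks run
# only once the accumulated value crosses `rel_l1_thresh`. The sketch below is
# illustrative only (the function name and plain-numpy arguments are not part
# of the original script).
def _teacache_should_calc_sketch(curr_emb, prev_emb, accumulated, rel_l1_thresh, coefficients):
    """Return (should_calc, new_accumulated) for one denoising step."""
    rescale = np.poly1d(coefficients)
    rel_l1 = np.abs(curr_emb - prev_emb).mean() / np.abs(prev_emb).mean()
    accumulated += rescale(rel_l1)
    if accumulated < rel_l1_thresh:
        return False, accumulated  # cheap step: reuse the cached residual
    return True, 0.0  # full step: run the blocks and reset the accumulator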


def eval_teacache_slow(prompt_list):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    # TeaCache state lives on the transformer class: patch its forward and
    # reset the cache before generation. The lower threshold (0.1) skips
    # fewer steps, trading speed for fidelity.
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.previous_residual_encoder = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/cogvideox_teacache_slow", loop=5)


def eval_teacache_fast(prompt_list):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    # Same setup with a higher threshold (0.2): more steps reuse the cached
    # residual, so generation is faster.
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.previous_residual_encoder = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/cogvideox_teacache_fast", loop=5)
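

# The two functions above differ only in `rel_l1_thresh`. As a sketch (the
# helper below is hypothetical and unused, not part of the original script),
# the shared setup could be factored out like this:
def _make_teacache_engine_sketch(rel_l1_thresh):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    transformer_cls = engine.driver_worker.transformer.__class__
    transformer_cls.enable_teacache = True
    transformer_cls.rel_l1_thresh = rel_l1_thresh
    transformer_cls.accumulated_rel_l1_distance = 0
    transformer_cls.previous_modulated_input = None
    transformer_cls.previous_residual = None
    transformer_cls.previous_residual_encoder = None
    transformer_cls.forward = teacache_forward
    return engine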


def eval_base(prompt_list):
    config = CogVideoXConfig()
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/cogvideox_base", loop=5)


if __name__ == "__main__":
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    eval_base(prompt_list)
    eval_teacache_slow(prompt_list)
    eval_teacache_fast(prompt_list)