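# Evaluate Open-Sora (VideoSys) on the VBench prompt list, comparing the baseline pipeline against
# two TeaCache settings that monkey-patch the transformer forward defined below.
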
from utils import generate_func, read_prompt_list
from videosys import OpenSoraConfig, VideoSysEngine
import torch
from einops import rearrange
from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
import numpy as np
from videosys.utils.utils import batch_func
from functools import partial


def teacache_forward(
    self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
):
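    # TeaCache forward for the Open-Sora transformer: it tracks how much the timestep-modulated
    # input of the first spatial block changes between denoising steps and, while the accumulated
    # change stays below `rel_l1_thresh`, reuses the cached block residual instead of re-running
    # all spatial/temporal blocks.
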
    # === Split batch ===
    if self.parallel_manager.cp_size > 1:
        x, timestep, y, x_mask, mask = batch_func(
            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
            x,
            timestep,
            y,
            x_mask,
            mask,
        )

    dtype = self.x_embedder.proj.weight.dtype
    B = x.size(0)
    x = x.to(dtype)
    timestep = timestep.to(dtype)
    y = y.to(dtype)

    # === get pos embed ===
    _, _, Tx, Hx, Wx = x.size()
    T, H, W = self.get_dynamic_size(x)
    S = H * W
    base_size = round(S**0.5)
    resolution_sq = (height[0].item() * width[0].item()) ** 0.5
    scale = resolution_sq / self.input_sq_size
    pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)

    # === get timestep embed ===
    t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
    fps = self.fps_embedder(fps.unsqueeze(1), B)
    t = t + fps
    t_mlp = self.t_block(t)
    t0 = t0_mlp = None
    if x_mask is not None:
        t0_timestep = torch.zeros_like(timestep)
        t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
        t0 = t0 + fps
        t0_mlp = self.t_block(t0)

    # === get y embed ===
    if self.config.skip_y_embedder:
        y_lens = mask
        if isinstance(y_lens, torch.Tensor):
            y_lens = y_lens.long().tolist()
    else:
        y, y_lens = self.encode_text(y, mask)

    # === get x embed ===
    x = self.x_embedder(x)  # [B, N, C]
    x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
    x = x + pos_emb

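    # === TeaCache skip decision ===
    # Measure the relative L1 change of the modulated input against the previous step, rescale it,
    # and accumulate it; the first and last timesteps are always computed in full.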
    if self.enable_teacache:
        inp = x.clone()
        inp = rearrange(inp, "B T S C -> B (T S) C", T=T, S=S)
        B, N, C = inp.shape
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.spatial_blocks[0].scale_shift_table[None] + t_mlp.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        modulated_inp = t2i_modulate(self.spatial_blocks[0].norm1(inp), shift_msa, scale_msa)
        if timestep[0] == all_timesteps[0] or timestep[0] == all_timesteps[-1]:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
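            # Rescale the raw relative L1 distance with an empirically fitted polynomial for this
            # model before accumulating it against the skip threshold.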
            coefficients = [2.17546007e+02, -1.18329252e+02, 2.68662585e+01, -4.59364272e-02, 4.84426240e-02]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                (
                    (modulated_inp - self.previous_modulated_input).abs().mean()
                    / self.previous_modulated_input.abs().mean()
                )
                .cpu()
                .item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp

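    # TeaCache reuse vs. full computation: when `should_calc` is False the blocks below are skipped
    # and the cached residual is added back; otherwise the blocks run and the residual is re-cached.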
    # === blocks ===
    if self.enable_teacache:
        if not should_calc:
            x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
            x += self.previous_residual
        else:
            # shard over the sequence dim if sp is enabled
            if self.parallel_manager.sp_size > 1:
                set_pad("temporal", T, self.parallel_manager.sp_group)
                set_pad("spatial", S, self.parallel_manager.sp_group)
                x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
                T = x.shape[1]
                x_mask_org = x_mask
                x_mask = split_sequence(
                    x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
                )
            x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
            origin_x = x.clone().detach()
            for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
                x = auto_grad_checkpoint(
                    spatial_block,
                    x,
                    y,
                    t_mlp,
                    y_lens,
                    x_mask,
                    t0_mlp,
                    T,
                    S,
                    timestep,
                    all_timesteps=all_timesteps,
                )

                x = auto_grad_checkpoint(
                    temporal_block,
                    x,
                    y,
                    t_mlp,
                    y_lens,
                    x_mask,
                    t0_mlp,
                    T,
                    S,
                    timestep,
                    all_timesteps=all_timesteps,
                )
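            # Cache the residual added by the full block pass so subsequent skipped steps can reuse it.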
            self.previous_residual = x - origin_x
    else:
        # shard over the sequence dim if sp is enabled
        if self.parallel_manager.sp_size > 1:
            set_pad("temporal", T, self.parallel_manager.sp_group)
            set_pad("spatial", S, self.parallel_manager.sp_group)
            x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
            T = x.shape[1]
            x_mask_org = x_mask
            x_mask = split_sequence(
                x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
            )
        x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)

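        # Baseline path (TeaCache disabled): run every spatial/temporal block pair at every step.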
        for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
            x = auto_grad_checkpoint(
                spatial_block,
                x,
                y,
                t_mlp,
                y_lens,
                x_mask,
                t0_mlp,
                T,
                S,
                timestep,
                all_timesteps=all_timesteps,
            )

            x = auto_grad_checkpoint(
                temporal_block,
                x,
                y,
                t_mlp,
                y_lens,
                x_mask,
                t0_mlp,
                T,
                S,
                timestep,
                all_timesteps=all_timesteps,
            )

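    # Gather sequence-parallel shards back to the full sequence before the final layer; the cached
    # residual is gathered as well so it can be reused on the unsharded sequence.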
    if self.parallel_manager.sp_size > 1:
        if self.enable_teacache:
            if should_calc:
                x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
                self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S)
                x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
                self.previous_residual = gather_sequence(
                    self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")
                )
                T, S = x.shape[1], x.shape[2]
                x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
                self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S)
                x_mask = x_mask_org
        else:
            x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
            x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
            T, S = x.shape[1], x.shape[2]
            x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
            x_mask = x_mask_org

    # === final layer ===
    x = self.final_layer(x, t, x_mask, t0, T, S)
    x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)

    # cast to float32 for better accuracy
    x = x.to(torch.float32)

    # === Gather Output ===
    if self.parallel_manager.cp_size > 1:
        x = gather_sequence(x, self.parallel_manager.cp_group, dim=0)

    return x


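# Baseline run: the unmodified Open-Sora pipeline, used as the reference for the TeaCache variants.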
def eval_base(prompt_list):
    config = OpenSoraConfig()
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/opensora_base", loop=5)


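# The TeaCache variants patch the transformer class in place: `teacache_forward` replaces the original
# forward, and the caching state (threshold, accumulated distance, cached modulated input and residual)
# is attached as class attributes. A larger `rel_l1_thresh` skips more steps, trading quality for speed.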
def eval_teacache_slow(prompt_list):
    config = OpenSoraConfig()
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/opensora_teacache_slow", loop=5)


def eval_teacache_fast(prompt_list):
    config = OpenSoraConfig()
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5)


if __name__ == "__main__":
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    eval_base(prompt_list)
    eval_teacache_slow(prompt_list)
    eval_teacache_fast(prompt_list)