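"""TeaCache evaluation for Open-Sora on VBench prompts, built on VideoSys.

The script monkey-patches the Open-Sora transformer with a TeaCache-aware
`forward`: at each denoising step it measures the relative L1 change of the
timestep-modulated input, rescales it with a fitted polynomial, and accumulates
the result; while the accumulator stays below `rel_l1_thresh`, the transformer
blocks are skipped and the cached block residual is reused instead. Samples are
generated for the baseline and for two cache thresholds (0.1 "slow", 0.2 "fast").
"""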

from utils import generate_func, read_prompt_list
from videosys import OpenSoraConfig, VideoSysEngine
import torch
from einops import rearrange
from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
import numpy as np
from videosys.utils.utils import batch_func
from functools import partial


def teacache_forward(
    self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
):
    # === Split batch ===
    if self.parallel_manager.cp_size > 1:
        x, timestep, y, x_mask, mask = batch_func(
            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
            x,
            timestep,
            y,
            x_mask,
            mask,
        )

    dtype = self.x_embedder.proj.weight.dtype
    B = x.size(0)
    x = x.to(dtype)
    timestep = timestep.to(dtype)
    y = y.to(dtype)

    # === get pos embed ===
    _, _, Tx, Hx, Wx = x.size()
    T, H, W = self.get_dynamic_size(x)
    S = H * W
    base_size = round(S**0.5)
    resolution_sq = (height[0].item() * width[0].item()) ** 0.5
    scale = resolution_sq / self.input_sq_size
    pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)

    # === get timestep embed ===
    t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
    fps = self.fps_embedder(fps.unsqueeze(1), B)
    t = t + fps
    t_mlp = self.t_block(t)
    t0 = t0_mlp = None
    if x_mask is not None:
        t0_timestep = torch.zeros_like(timestep)
        t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
        t0 = t0 + fps
        t0_mlp = self.t_block(t0)

    # === get y embed ===
    if self.config.skip_y_embedder:
        y_lens = mask
        if isinstance(y_lens, torch.Tensor):
            y_lens = y_lens.long().tolist()
    else:
        y, y_lens = self.encode_text(y, mask)

    # === get x embed ===
    x = self.x_embedder(x)  # [B, N, C]
    x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
    x = x + pos_emb
    # === TeaCache: decide whether this step can reuse the cached residual ===
    if self.enable_teacache:
        inp = x.clone()
        inp = rearrange(inp, "B T S C -> B (T S) C", T=T, S=S)
        B, N, C = inp.shape
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.spatial_blocks[0].scale_shift_table[None] + t_mlp.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        modulated_inp = t2i_modulate(self.spatial_blocks[0].norm1(inp), shift_msa, scale_msa)
        if timestep[0] == all_timesteps[0] or timestep[0] == all_timesteps[-1]:
            # always compute the first and last denoising steps
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # rescale the relative L1 change with a fitted polynomial and
            # accumulate it; skip the blocks while the total stays under threshold
            coefficients = [2.17546007e+02, -1.18329252e+02, 2.68662585e+01, -4.59364272e-02, 4.84426240e-02]
            rescale_func = np.poly1d(coefficients)
            rel_l1 = (
                (modulated_inp - self.previous_modulated_input).abs().mean()
                / self.previous_modulated_input.abs().mean()
            )
            self.accumulated_rel_l1_distance += rescale_func(rel_l1.cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
    # === blocks ===
    if self.enable_teacache:
        if not should_calc:
            # cache hit: skip every transformer block and reuse the residual
            x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
            x += self.previous_residual
        else:
            # shard over the sequence dim if sp is enabled
            if self.parallel_manager.sp_size > 1:
                set_pad("temporal", T, self.parallel_manager.sp_group)
                set_pad("spatial", S, self.parallel_manager.sp_group)
                x = split_sequence(
                    x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
                )
                T = x.shape[1]
                x_mask_org = x_mask
                x_mask = split_sequence(
                    x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
                )

            x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
            origin_x = x.clone().detach()
            for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
                x = auto_grad_checkpoint(
                    spatial_block,
                    x,
                    y,
                    t_mlp,
                    y_lens,
                    x_mask,
                    t0_mlp,
                    T,
                    S,
                    timestep,
                    all_timesteps=all_timesteps,
                )
                x = auto_grad_checkpoint(
                    temporal_block,
                    x,
                    y,
                    t_mlp,
                    y_lens,
                    x_mask,
                    t0_mlp,
                    T,
                    S,
                    timestep,
                    all_timesteps=all_timesteps,
                )
            # cache the residual accumulated across all blocks for reuse on skipped steps
            self.previous_residual = x - origin_x
    else:
        # shard over the sequence dim if sp is enabled
        if self.parallel_manager.sp_size > 1:
            set_pad("temporal", T, self.parallel_manager.sp_group)
            set_pad("spatial", S, self.parallel_manager.sp_group)
            x = split_sequence(
                x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
            )
            T = x.shape[1]
            x_mask_org = x_mask
            x_mask = split_sequence(
                x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
            )

        x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
        for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
            x = auto_grad_checkpoint(
                spatial_block,
                x,
                y,
                t_mlp,
                y_lens,
                x_mask,
                t0_mlp,
                T,
                S,
                timestep,
                all_timesteps=all_timesteps,
            )
            x = auto_grad_checkpoint(
                temporal_block,
                x,
                y,
                t_mlp,
                y_lens,
                x_mask,
                t0_mlp,
                T,
                S,
                timestep,
                all_timesteps=all_timesteps,
            )
    # === gather sequence-parallel shards back ===
    if self.parallel_manager.sp_size > 1:
        if self.enable_teacache:
            if should_calc:
                # gather both the activations and the cached residual so the
                # residual stays valid on later, unsharded skipped steps
                x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
                self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S)
                x = gather_sequence(
                    x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")
                )
                self.previous_residual = gather_sequence(
                    self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")
                )
                T, S = x.shape[1], x.shape[2]
                x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
                self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S)
                x_mask = x_mask_org
        else:
            x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
            x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
            T, S = x.shape[1], x.shape[2]
            x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
            x_mask = x_mask_org
    # === final layer ===
    x = self.final_layer(x, t, x_mask, t0, T, S)
    x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)

    # cast to float32 for better accuracy
    x = x.to(torch.float32)

    # === Gather Output ===
    if self.parallel_manager.cp_size > 1:
        x = gather_sequence(x, self.parallel_manager.cp_group, dim=0)

    return x
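

# A minimal, self-contained sketch of the caching rule above, isolated from the
# transformer for illustration. The class and method names here are hypothetical
# (not part of VideoSys); only the thresholding logic mirrors `teacache_forward`.
class _TeaCacheRuleSketch:
    def __init__(self, rel_l1_thresh):
        self.rel_l1_thresh = rel_l1_thresh
        self.accumulated = 0.0
        self.prev_inp = None

    def should_calc(self, modulated_inp, is_first_or_last_step, rescale_func):
        # First and last denoising steps are always computed, anchoring the cache.
        if is_first_or_last_step or self.prev_inp is None:
            self.accumulated = 0.0
            calc = True
        else:
            # Relative L1 change of the modulated input between steps,
            # rescaled by the fitted polynomial before accumulation.
            rel_l1 = ((modulated_inp - self.prev_inp).abs().mean() / self.prev_inp.abs().mean()).item()
            self.accumulated += rescale_func(rel_l1)
            calc = self.accumulated >= self.rel_l1_thresh
            if calc:
                self.accumulated = 0.0
        self.prev_inp = modulated_inp
        return calc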


def eval_base(prompt_list):
    config = OpenSoraConfig()
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/opensora_base", loop=5)


def eval_teacache_slow(prompt_list):
    config = OpenSoraConfig()
    engine = VideoSysEngine(config)
    # patch TeaCache state and the replacement forward onto the transformer class;
    # the lower threshold (0.1) skips fewer steps, favoring fidelity over speed
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/opensora_teacache_slow", loop=5)


def eval_teacache_fast(prompt_list):
    config = OpenSoraConfig()
    engine = VideoSysEngine(config)
    # same patching as above with a higher threshold (0.2): more steps reuse the
    # cached residual, trading some quality for a larger speedup
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5)
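

# Note: the eval_* helpers above attach TeaCache state to the transformer
# *class* so that `teacache_forward` can read it through `self`. A hypothetical
# per-instance alternative, equivalent in effect for a single engine, would
# bind the function to one object instead:
#
#     transformer = engine.driver_worker.transformer
#     transformer.forward = teacache_forward.__get__(transformer)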


if __name__ == "__main__":
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    eval_base(prompt_list)
    eval_teacache_slow(prompt_list)
    eval_teacache_fast(prompt_list)