import os
import time
from pathlib import Path
from loguru import logger
from datetime import datetime

from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.config import parse_args
from hyvideo.inference import HunyuanVideoSampler

from hyvideo.modules.modulate_layers import modulate
from hyvideo.modules.attenion import attention, parallel_attention, get_cu_seqlens

from typing import Any, List, Tuple, Optional, Union, Dict
import torch
import json
import numpy as np
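

# TeaCache in brief: across adjacent diffusion timesteps the transformer's
# output tends to change in proportion to how much its timestep-modulated
# input changes. teacache_forward() below measures that input change with a
# relative L1 distance, rescales it with an empirically fitted polynomial,
# and accumulates it; while the accumulated value stays below
# `rel_l1_thresh`, the expensive double-/single-stream blocks are skipped and
# the cached output residual from the last fully computed step is reused.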
def teacache_forward(
    self,
    x: torch.Tensor,
    t: torch.Tensor,  # Should be in range(0, 1000).
    text_states: torch.Tensor = None,
    text_mask: torch.Tensor = None,  # Currently unused.
    text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
    freqs_cos: Optional[torch.Tensor] = None,
    freqs_sin: Optional[torch.Tensor] = None,
    guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
    return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
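    """TeaCache-enabled replacement for the transformer's forward pass.

    The signature mirrors the forward method it is patched over (see
    `main()`, which installs this function as the transformer class's
    `forward`), so the sampling pipeline can call it unchanged.
    """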
    out = {}
    img = x
    txt = text_states
    _, _, ot, oh, ow = x.shape
    tt, th, tw = (
        ot // self.patch_size[0],
        oh // self.patch_size[1],
        ow // self.patch_size[2],
    )

    # Prepare modulation vectors.
    vec = self.time_in(t)

    # text modulation
    vec = vec + self.vector_in(text_states_2)

    # guidance modulation
    if self.guidance_embed:
        if guidance is None:
            raise ValueError(
                "Didn't get guidance strength for guidance distilled model."
            )

        # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
        vec = vec + self.guidance_in(guidance)

    # Embed image and text.
    img = self.img_in(img)
    if self.text_projection == "linear":
        txt = self.txt_in(txt)
    elif self.text_projection == "single_refiner":
        txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
    else:
        raise NotImplementedError(
            f"Unsupported text_projection: {self.text_projection}"
        )

    txt_seq_len = txt.shape[1]
    img_seq_len = img.shape[1]

    # Compute cu_seqlens and max_seqlen for flash attention
    cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
    cu_seqlens_kv = cu_seqlens_q
    max_seqlen_q = img_seq_len + txt_seq_len
    max_seqlen_kv = max_seqlen_q

    freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
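    # Caching decision. The modulated input of the first double-stream block
    # serves as a cheap proxy for the whole forward pass: its relative L1
    # change,
    #     rel_l1 = mean(|x_t - x_{t-1}|) / mean(|x_{t-1}|),
    # is rescaled by a fitted degree-4 polynomial and accumulated across
    # steps. The first and last denoising steps are always computed in full.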
    if self.enable_teacache:
        inp = img.clone()
        vec_ = vec.clone()
        txt_ = txt.clone()
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
        normed_inp = self.double_blocks[0].img_norm1(inp)
        modulated_inp = modulate(
            normed_inp, shift=img_mod1_shift, scale=img_mod1_scale
        )
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean()
                 / self.previous_modulated_input.abs().mean()).cpu().item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt = 0 if self.cnt == self.num_steps - 1 else self.cnt + 1
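    # Apply the decision: on a skipped step, adding the cached residual to the
    # embedded image tokens approximates the output of a full pass; on a
    # computed step, the full block stack runs and the new residual
    # (output minus input) is cached for subsequent steps.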
    if self.enable_teacache:
        if not should_calc:
            img += self.previous_residual
        else:
            ori_img = img.clone()
            # --------------------- Pass through DiT blocks ------------------------
            for _, block in enumerate(self.double_blocks):
                double_block_args = [
                    img,
                    txt,
                    vec,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    freqs_cis,
                ]

                img, txt = block(*double_block_args)

            # Merge txt and img to pass through single stream blocks.
            x = torch.cat((img, txt), 1)
            if len(self.single_blocks) > 0:
                for _, block in enumerate(self.single_blocks):
                    single_block_args = [
                        x,
                        vec,
                        txt_seq_len,
                        cu_seqlens_q,
                        cu_seqlens_kv,
                        max_seqlen_q,
                        max_seqlen_kv,
                        (freqs_cos, freqs_sin),
                    ]

                    x = block(*single_block_args)

            img = x[:, :img_seq_len, ...]
            self.previous_residual = img - ori_img
    else:
        # --------------------- Pass through DiT blocks ------------------------
        for _, block in enumerate(self.double_blocks):
            double_block_args = [
                img,
                txt,
                vec,
                cu_seqlens_q,
                cu_seqlens_kv,
                max_seqlen_q,
                max_seqlen_kv,
                freqs_cis,
            ]

            img, txt = block(*double_block_args)

        # Merge txt and img to pass through single stream blocks.
        x = torch.cat((img, txt), 1)
        if len(self.single_blocks) > 0:
            for _, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    txt_seq_len,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    (freqs_cos, freqs_sin),
                ]

                x = block(*single_block_args)

        img = x[:, :img_seq_len, ...]

    # ---------------------------- Final layer ------------------------------
    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

    img = self.unpatchify(img, tt, th, tw)
    if return_dict:
        out["x"] = img
        return out
    return img


def main():
    args = parse_args()
    print(args)
    models_root_path = Path(args.model_base)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` does not exist: {models_root_path}")

    # Create save folder to save the samples
    save_path = args.save_path if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    # Load models
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)

    # Get the updated args
    args = hunyuan_video_sampler.args

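    # TeaCache state is attached to the transformer class (rather than the
    # instance) so that teacache_forward(), a plain module-level function
    # installed below as the class's `forward`, can reach it through `self`.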
    # TeaCache
    hunyuan_video_sampler.pipeline.transformer.__class__.enable_teacache = True
    hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0
    hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps - 1
    hunyuan_video_sampler.pipeline.transformer.__class__.rel_l1_thresh = 0.15  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
    hunyuan_video_sampler.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
    hunyuan_video_sampler.pipeline.transformer.__class__.previous_modulated_input = None
    hunyuan_video_sampler.pipeline.transformer.__class__.previous_residual = None
    hunyuan_video_sampler.pipeline.transformer.__class__.forward = teacache_forward
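
    # The threshold trades quality for speed: per the comment above, roughly a
    # 1.6x end-to-end speedup at 0.1 and 2.1x at 0.15. Setting enable_teacache
    # to False makes the patched forward run every block on every step, which
    # is useful as a baseline for comparison.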

    # Start sampling
    # TODO: batch inference check
    outputs = hunyuan_video_sampler.predict(
        prompt=args.prompt,
        height=args.video_size[0],
        width=args.video_size[1],
        video_length=args.video_length,
        seed=args.seed,
        negative_prompt=args.neg_prompt,
        infer_steps=args.infer_steps,
        guidance_scale=args.cfg_scale,
        num_videos_per_prompt=args.num_videos,
        flow_shift=args.flow_shift,
        batch_size=args.batch_size,
        embedded_guidance_scale=args.embedded_cfg_scale,
    )
    samples = outputs["samples"]

    # Save samples
    if "LOCAL_RANK" not in os.environ or int(os.environ["LOCAL_RANK"]) == 0:
        for i, sample in enumerate(samples):
            sample = samples[i].unsqueeze(0)
            time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
            cur_save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
            save_videos_grid(sample, cur_save_path, fps=24)
            logger.info(f"Sample saved to: {cur_save_path}")


if __name__ == "__main__":
    main()
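

# Example invocation (illustrative only; the script name and flag spellings
# are assumptions -- the accepted arguments come from hyvideo.config.parse_args
# and may differ between HunyuanVideo releases):
#
#   python3 teacache_sample_video.py \
#       --video-size 720 1280 \
#       --video-length 129 \
#       --infer-steps 50 \
#       --prompt "A cat walks on the grass, realistic style." \
#       --save-path ./results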