From 2295d6a10ce17b11eb00e16f26bdf03c2b499252 Mon Sep 17 00:00:00 2001
From: zishen-ucap
Date: Thu, 13 Mar 2025 16:18:28 +0800
Subject: [PATCH 1/3] Support parameter - use_ret_steps

---
 TeaCache4Wan2.1/README.md            |  43 ++
 TeaCache4Wan2.1/teacache_generate.py | 688 ++++++++++++++-------------
 2 files changed, 412 insertions(+), 319 deletions(-)

diff --git a/TeaCache4Wan2.1/README.md b/TeaCache4Wan2.1/README.md
index 9ef917e..05d4912 100644
--- a/TeaCache4Wan2.1/README.md
+++ b/TeaCache4Wan2.1/README.md
@@ -57,6 +57,49 @@ For I2V with 720P resolution, you can use the following command:
 python teacache_generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --frame_num 61 --teacache_thresh 0.3
 ```
 
+## Faster Video Generation Using the `use_ret_steps` Parameter
+
+Using Retention Steps will result in faster generation speed and better generation quality (except for t2v-1.3B).
+
+https://github.com/user-attachments/assets/f241b5f5-1044-4223-b2a4-449dc6dc1ad7
+
+https://github.com/user-attachments/assets/01db60f9-4aaf-43c4-8f1b-6e050cfa1180
+
+https://github.com/user-attachments/assets/e03621f2-1085-4571-8eca-51889f47ce18
+
+https://github.com/user-attachments/assets/d1340197-20c1-4f9e-a780-31f789af0893
+
+
+| use_ret_steps | Wan2.1 t2v 1.3B (thresh) | Slow (thresh) | Fast (thresh) |
+|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
+| False | ~97 s (0.00) | ~64 s (0.05) | ~49 s (0.08) |
+| True | ~97 s (0.00) | ~61 s (0.05) | ~41 s (0.10) |
+
+| use_ret_steps | Wan2.1 t2v 14B (thresh) | Slow (thresh) | Fast (thresh) |
+|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
+| False | ~1829 s (0.00) | ~1234 s (0.14) | ~909 s (0.20) |
+| True | ~1829 s (0.00) | ~915 s (0.10) | ~578 s (0.20) |
+
+| use_ret_steps | Wan2.1 i2v 480p (thresh) | Slow (thresh) | Fast (thresh) |
+|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
+| False | ~385 s (0.00) | ~241 s (0.13) | ~156 s (0.26) |
+| True | ~385 s (0.00) | ~212 s (0.20) | ~164 s (0.30) |
+
+| use_ret_steps | Wan2.1 i2v 720p (thresh) | Slow (thresh) | Fast (thresh) |
+|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
+| False | ~903 s (0.00) | ~476 s (0.20) | ~363 s (0.30) |
+| True | ~903 s (0.00) | ~430 s (0.20) | ~340 s (0.30) |
+
+
+You can refer to the previous video generation instructions and use the `use_ret_steps` parameter to speed up the video generation process, achieving results closer to **Wan2.1**. Simply add the `--use_ret_steps` parameter to the original command and adjust the `--teacache_thresh` parameter to achieve more efficient video generation. The value of the `--teacache_thresh` parameter can be referenced from the table, allowing you to choose the appropriate value based on different models and settings.
+
+### Example Command:
+
+```bash
+python teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.3 --use_ret_steps
+```
+
+
 ## Acknowledgements
 
 We would like to thank the contributors to the [Wan2.1](https://github.com/Wan-Video/Wan2.1).
\ No newline at end of file
diff --git a/TeaCache4Wan2.1/teacache_generate.py b/TeaCache4Wan2.1/teacache_generate.py
index 649f1d5..c296f0b 100644
--- a/TeaCache4Wan2.1/teacache_generate.py
+++ b/TeaCache4Wan2.1/teacache_generate.py
@@ -29,6 +29,7 @@ from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
 from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from tqdm import tqdm
 
+
 EXAMPLE_PROMPT = {
     "t2v-1.3B": {
         "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
     },
@@ -48,183 +49,6 @@ EXAMPLE_PROMPT = {
 }
 
-
-def _validate_args(args):
-    # Basic check
-    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
-    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
-    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"
-
-    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
-    if args.sample_steps is None:
-        args.sample_steps = 40 if "i2v" in args.task else 50
-
-    if args.sample_shift is None:
-        args.sample_shift = 5.0
-        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
-            args.sample_shift = 3.0
-
-    # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
-    if args.frame_num is None:
-        args.frame_num = 1 if "t2i" in args.task else 81
-
-    # T2I frame_num check
-    if "t2i" in args.task:
-        assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"
-
-    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
-        0, sys.maxsize)
-    # Size check
-    assert args.size in SUPPORTED_SIZES[
-        args.
-        task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
-
-
-def _parse_args():
-    parser = argparse.ArgumentParser(
-        description="Generate a image or video from a text prompt or image using Wan"
-    )
-    parser.add_argument(
-        "--task",
-        type=str,
-        default="t2v-14B",
-        choices=list(WAN_CONFIGS.keys()),
-        help="The task to run.")
-    parser.add_argument(
-        "--size",
-        type=str,
-        default="1280*720",
-        choices=list(SIZE_CONFIGS.keys()),
-        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
-    )
-    parser.add_argument(
-        "--frame_num",
-        type=int,
-        default=None,
-        help="How many frames to sample from a image or video. The number should be 4n+1"
-    )
-    parser.add_argument(
-        "--ckpt_dir",
-        type=str,
-        default=None,
-        help="The path to the checkpoint directory.")
-    parser.add_argument(
-        "--offload_model",
-        type=str2bool,
-        default=None,
-        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
- ) - parser.add_argument( - "--ulysses_size", - type=int, - default=1, - help="The size of the ulysses parallelism in DiT.") - parser.add_argument( - "--ring_size", - type=int, - default=1, - help="The size of the ring attention parallelism in DiT.") - parser.add_argument( - "--t5_fsdp", - action="store_true", - default=False, - help="Whether to use FSDP for T5.") - parser.add_argument( - "--t5_cpu", - action="store_true", - default=False, - help="Whether to place T5 model on CPU.") - parser.add_argument( - "--dit_fsdp", - action="store_true", - default=False, - help="Whether to use FSDP for DiT.") - parser.add_argument( - "--save_file", - type=str, - default=None, - help="The file to save the generated image or video to.") - parser.add_argument( - "--prompt", - type=str, - default=None, - help="The prompt to generate the image or video from.") - parser.add_argument( - "--use_prompt_extend", - action="store_true", - default=False, - help="Whether to use prompt extend.") - parser.add_argument( - "--prompt_extend_method", - type=str, - default="local_qwen", - choices=["dashscope", "local_qwen"], - help="The prompt extend method to use.") - parser.add_argument( - "--prompt_extend_model", - type=str, - default=None, - help="The prompt extend model to use.") - parser.add_argument( - "--prompt_extend_target_lang", - type=str, - default="ch", - choices=["ch", "en"], - help="The target language of prompt extend.") - parser.add_argument( - "--base_seed", - type=int, - default=-1, - help="The seed to use for generating the image or video.") - parser.add_argument( - "--image", - type=str, - default=None, - help="The image to generate the video from.") - parser.add_argument( - "--sample_solver", - type=str, - default='unipc', - choices=['unipc', 'dpm++'], - help="The solver used to sample.") - parser.add_argument( - "--sample_steps", type=int, default=None, help="The sampling steps.") - parser.add_argument( - "--sample_shift", - type=float, - default=None, - help="Sampling shift factor for flow matching schedulers.") - parser.add_argument( - "--sample_guide_scale", - type=float, - default=5.0, - help="Classifier free guidance scale.") - parser.add_argument( - "--teacache_thresh", - type=float, - default=0.05, - help="The size of the ulysses parallelism in DiT.") - - args = parser.parse_args() - - _validate_args(args) - - return args - - -def _init_logging(rank): - # logging - if rank == 0: - # set format - logging.basicConfig( - level=logging.INFO, - format="[%(asctime)s] %(levelname)s: %(message)s", - handlers=[logging.StreamHandler(stream=sys.stdout)]) - else: - logging.basicConfig(level=logging.ERROR) - - -# add a cond_flag def t2v_generate(self, input_prompt, size=(1280, 720), @@ -341,8 +165,8 @@ def t2v_generate(self, # sample videos latents = noise - arg_c = {'context': context, 'seq_len': seq_len, 'cond_flag': True} - arg_null = {'context': context_null, 'seq_len': seq_len, 'cond_flag': False} + arg_c = {'context': context, 'seq_len': seq_len} + arg_null = {'context': context_null, 'seq_len': seq_len} for _, t in enumerate(tqdm(timesteps)): latent_model_input = latents @@ -385,7 +209,7 @@ def t2v_generate(self, return videos[0] if self.rank == 0 else None -# add a cond_flag + def i2v_generate(self, input_prompt, img, @@ -543,7 +367,7 @@ def i2v_generate(self, 'clip_fea': clip_context, 'seq_len': max_seq_len, 'y': [y], - 'cond_flag': True, + # 'cond_flag': True, } arg_null = { @@ -551,7 +375,7 @@ def i2v_generate(self, 'clip_fea': clip_context, 'seq_len': max_seq_len, 'y': [y], - 'cond_flag': 
False, + # 'cond_flag': False, } if offload_model: @@ -610,131 +434,332 @@ def i2v_generate(self, return videos[0] if self.rank == 0 else None + def teacache_forward( - self, - x, - t, - context, - seq_len, - clip_fea=None, - y=None, - cond_flag=False, - ): - r""" - Forward pass through the diffusion model + self, + x, + t, + context, + seq_len, + clip_fea=None, + y=None, +): + r""" + Forward pass through the diffusion model - Args: - x (List[Tensor]): - List of input video tensors, each with shape [C_in, F, H, W] - t (Tensor): - Diffusion timesteps tensor of shape [B] - context (List[Tensor]): - List of text embeddings each with shape [L, C] - seq_len (`int`): - Maximum sequence length for positional encoding - clip_fea (Tensor, *optional*): - CLIP image features for image-to-video mode - y (List[Tensor], *optional*): - Conditional video inputs for image-to-video mode, same shape as x + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x - Returns: - List[Tensor]: - List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] - """ - if self.model_type == 'i2v': - assert clip_fea is not None and y is not None - # params - device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) - if y is not None: - x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] - # embeddings - x = [self.patch_embedding(u.unsqueeze(0)) for u in x] - grid_sizes = torch.stack( - [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) - x = [u.flatten(2).transpose(1, 2) for u in x] - seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) - assert seq_lens.max() <= seq_len - x = torch.cat([ - torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], - dim=1) for u in x - ]) - # time embeddings - with amp.autocast(dtype=torch.float32): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).float()) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) - # context - context_lens = None - context = self.text_embedding( - torch.stack([ - torch.cat( - [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) - for u in context - ])) + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + 
sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 - if clip_fea is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) - # arguments - kwargs = dict( - e=e0, - seq_lens=seq_lens, - grid_sizes=grid_sizes, - freqs=self.freqs, - context=context, - context_lens=context_lens) + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens) - if self.enable_teacache: - if cond_flag: - modulated_inp = e - if self.cnt == 0 or self.cnt == self.num_steps-1: - should_calc = True - self.accumulated_rel_l1_distance = 0 - else: - rescale_func = np.poly1d(self.coefficients) - if cond_flag: - self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) - if self.accumulated_rel_l1_distance < self.rel_l1_thresh: - should_calc = False - else: - should_calc = True - self.accumulated_rel_l1_distance = 0 - self.previous_modulated_input = modulated_inp - self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1 - self.should_calc = should_calc + if self.enable_teacache: + modulated_inp = e0 if self.use_ref_steps else e + # teacache + if self.cnt%2==0: # even -> conditon + self.is_even = True + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc_even = True + self.accumulated_rel_l1_distance_even = 0 else: - should_calc = self.should_calc - # if not cond_flag: - # self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1 - - if self.enable_teacache: - if not should_calc: - x = x + self.previous_residual_cond if cond_flag else x + self.previous_residual_uncond + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_even < self.teacache_thresh: + should_calc_even = False + else: + should_calc_even = True + self.accumulated_rel_l1_distance_even = 0 + self.previous_e0_even = modulated_inp.clone() + + else: # odd -> unconditon + self.is_even = False + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc_odd = True + self.accumulated_rel_l1_distance_odd = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: + should_calc_odd = False + else: + should_calc_odd = True + self.accumulated_rel_l1_distance_odd = 0 + self.previous_e0_odd = modulated_inp.clone() + + if self.enable_teacache: + if self.is_even: + if not should_calc_even: + x += self.previous_residual_even else: ori_x = x.clone() for block in self.blocks: x = block(x, **kwargs) - if cond_flag: - self.previous_residual_cond = x - ori_x - else: - self.previous_residual_uncond = x - ori_x + 
self.previous_residual_even = x - ori_x else: - for block in self.blocks: - x = block(x, **kwargs) + if not should_calc_odd: + x += self.previous_residual_odd + else: + ori_x = x.clone() + for block in self.blocks: + x = block(x, **kwargs) + self.previous_residual_odd = x - ori_x + + else: + for block in self.blocks: + x = block(x, **kwargs) - # head - x = self.head(x, e) + # head + x = self.head(x, e) - # unpatchify - x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] + # unpatchify + x = self.unpatchify(x, grid_sizes) + self.cnt += 1 + if self.cnt >= self.num_steps: + self.cnt = 0 + return [u.float() for u in x] + +def _validate_args(args): + # Basic check + assert args.ckpt_dir is not None, "Please specify the checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. + if args.sample_steps is None: + args.sample_steps = 40 if "i2v" in args.task else 50 + + if args.sample_shift is None: + args.sample_shift = 5.0 + if "i2v" in args.task and args.size in ["832*480", "480*832"]: + args.sample_shift = 3.0 + + # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. + if args.frame_num is None: + args.frame_num = 1 if "t2i" in args.task else 81 + + # T2I frame_num check + if "t2i" in args.task: + assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}" + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( + 0, sys.maxsize) + # Size check + assert args.size in SUPPORTED_SIZES[ + args. + task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--task", + type=str, + default="t2v-14B", + choices=list(WAN_CONFIGS.keys()), + help="The task to run.") + parser.add_argument( + "--size", + type=str, + default="1280*720", + choices=list(SIZE_CONFIGS.keys()), + help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image." + ) + parser.add_argument( + "--frame_num", + type=int, + default=None, + help="How many frames to sample from a image or video. The number should be 4n+1" + ) + parser.add_argument( + "--ckpt_dir", + type=str, + default=None, + help="The path to the checkpoint directory.") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." 
+ ) + parser.add_argument( + "--ulysses_size", + type=int, + default=1, + help="The size of the ulysses parallelism in DiT.") + parser.add_argument( + "--ring_size", + type=int, + default=1, + help="The size of the ring attention parallelism in DiT.") + parser.add_argument( + "--t5_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for T5.") + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--dit_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for DiT.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated image or video to.") + parser.add_argument( + "--prompt", + type=str, + default=None, + help="The prompt to generate the image or video from.") + parser.add_argument( + "--use_prompt_extend", + action="store_true", + default=False, + help="Whether to use prompt extend.") + parser.add_argument( + "--prompt_extend_method", + type=str, + default="local_qwen", + choices=["dashscope", "local_qwen"], + help="The prompt extend method to use.") + parser.add_argument( + "--prompt_extend_model", + type=str, + default=None, + help="The prompt extend model to use.") + parser.add_argument( + "--prompt_extend_target_lang", + type=str, + default="ch", + choices=["ch", "en"], + help="The target language of prompt extend.") + parser.add_argument( + "--base_seed", + type=int, + default=-1, + help="The seed to use for generating the image or video.") + parser.add_argument( + "--image", + type=str, + default=None, + help="The image to generate the video from.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument( + "--sample_guide_scale", + type=float, + default=5.0, + help="Classifier free guidance scale.") + parser.add_argument( + "--teacache_thresh", + type=float, + default=0.2, + help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup") + parser.add_argument( + "--use_ret_steps", + action="store_true", + default=False, + help="Using Retention Steps will result in faster generation speed and better generation quality.") + + + args = parser.parse_args() + + _validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) def generate(args): @@ -841,21 +866,33 @@ def generate(args): # TeaCache wan_t2v.__class__.generate = t2v_generate - wan_t2v.model.__class__.cnt = 0 wan_t2v.model.__class__.enable_teacache = True - wan_t2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 50 - wan_t2v.model.__class__.rel_l1_thresh = args.teacache_thresh # 2min54s, 0.05: 1min 55s(1.5x), 0.1, 1min 24s(2.1x) 0.15, 1min 6s, 0.08: 1min 27s(2x), 0.07: 1min 48s(1.6x), 0.06: 1min 51s - wan_t2v.model.__class__.accumulated_rel_l1_distance = 0 - wan_t2v.model.__class__.previous_modulated_input = None - wan_t2v.model.__class__.previous_residual = None - 
wan_t2v.model.__class__.previous_residual_uncond = None - wan_t2v.model.__class__.should_calc = True - if '1.3B' in args.ckpt_dir: - wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] - if '14B' in args.ckpt_dir: - wan_t2v.model.__class__.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] - wan_t2v.model.__class__.forward = teacache_forward - + wan_t2v.model.__class__.forward = teacache_forward + wan_t2v.model.__class__.cnt = 0 + wan_t2v.model.__class__.num_steps = args.sample_steps*2 + wan_t2v.model.__class__.teacache_thresh = args.teacache_thresh + wan_t2v.model.__class__.accumulated_rel_l1_distance_even = 0 + wan_t2v.model.__class__.accumulated_rel_l1_distance_odd = 0 + wan_t2v.model.__class__.previous_e0_even = None + wan_t2v.model.__class__.previous_e0_odd = None + wan_t2v.model.__class__.previous_residual_even = None + wan_t2v.model.__class__.previous_residual_odd = None + wan_t2v.model.__class__.use_ref_steps = args.use_ret_steps + if args.use_ret_steps: + # wan_t2v.model.__class__.task = args.task + if '1.3B' in args.ckpt_dir: + wan_t2v.model.__class__.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02] + if '14B' in args.ckpt_dir: + wan_t2v.model.__class__.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01] + wan_t2v.model.__class__.ret_steps = 5*2 + wan_t2v.model.__class__.cutoff_steps = args.sample_steps*2 + else: + if '1.3B' in args.ckpt_dir: + wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] + if '14B' in args.ckpt_dir: + wan_t2v.model.__class__.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] + wan_t2v.model.__class__.ret_steps = 1*2 + wan_t2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2 logging.info( f"Generating {'image' if 't2i' in args.task else 'video'} ...") video = wan_t2v.generate( @@ -912,23 +949,35 @@ def generate(args): use_usp=(args.ulysses_size > 1 or args.ring_size > 1), t5_cpu=args.t5_cpu, ) - # TeaCache wan_i2v.__class__.generate = i2v_generate - wan_i2v.model.__class__.cnt = 0 wan_i2v.model.__class__.enable_teacache = True - wan_i2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 40 - wan_i2v.model.__class__.rel_l1_thresh = args.teacache_thresh # 12min 26s - wan_i2v.model.__class__.accumulated_rel_l1_distance = 0 - wan_i2v.model.__class__.previous_modulated_input = None - wan_i2v.model.__class__.previous_residual_cond = None - wan_i2v.model.__class__.previous_residual_uncond = None - wan_i2v.model.__class__.should_calc = True - if '480P' in args.ckpt_dir: - wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] - if '720P' in args.ckpt_dir: - wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] wan_i2v.model.__class__.forward = teacache_forward + wan_i2v.model.__class__.cnt = 0 + wan_i2v.model.__class__.num_steps = args.sample_steps*2 + wan_i2v.model.__class__.teacache_thresh = args.teacache_thresh + wan_i2v.model.__class__.accumulated_rel_l1_distance_even = 0 + wan_i2v.model.__class__.accumulated_rel_l1_distance_odd = 0 + wan_i2v.model.__class__.previous_e0_even = None + wan_i2v.model.__class__.previous_e0_odd = None + wan_i2v.model.__class__.previous_residual_even = None + 
wan_i2v.model.__class__.previous_residual_odd = None
+        wan_i2v.model.__class__.use_ref_steps = args.use_ret_steps
+        if args.use_ret_steps:
+            # wan_i2v.model.__class__.task = args.task
+            if '480P' in args.ckpt_dir:
+                wan_i2v.model.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
+            if '720P' in args.ckpt_dir:
+                wan_i2v.model.__class__.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
+            wan_i2v.model.__class__.ret_steps = 5*2
+            wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2
+        else:
+            if '480P' in args.ckpt_dir:
+                wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
+            if '720P' in args.ckpt_dir:
+                wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+            wan_i2v.model.__class__.ret_steps = 1*2
+            wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2
 
         logging.info("Generating video ...")
         video = wan_i2v.generate(
@@ -968,8 +1017,9 @@ def generate(args):
             nrow=1,
             normalize=True,
             value_range=(-1, 1))
-    logging.info("Finished.")
-
+    logging.info("Finished.")
+
+
 if __name__ == "__main__":
     args = _parse_args()

From cabe560cf6a2bc08b5f61aa8aa0aed5a82a8e20f Mon Sep 17 00:00:00 2001
From: zishen-ucap
Date: Thu, 13 Mar 2025 17:06:39 +0800
Subject: [PATCH 2/3] Support parameter -- use_ret_steps

---
 TeaCache4Wan2.1/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TeaCache4Wan2.1/README.md b/TeaCache4Wan2.1/README.md
index 05d4912..5896e85 100644
--- a/TeaCache4Wan2.1/README.md
+++ b/TeaCache4Wan2.1/README.md
@@ -91,7 +91,7 @@ https://github.com/user-attachments/assets/d1340197-20c1-4f9e-a780-31f789af0893
 | True | ~903 s (0.00) | ~430 s (0.20) | ~340 s (0.30) |
 
 
-You can refer to the previous video generation instructions and use the `use_ret_steps` parameter to speed up the video generation process, achieving results closer to **Wan2.1**. Simply add the `--use_ret_steps` parameter to the original command and adjust the `--teacache_thresh` parameter to achieve more efficient video generation. The value of the `--teacache_thresh` parameter can be referenced from the table, allowing you to choose the appropriate value based on different models and settings.
+You can refer to the previous video generation instructions and use the `use_ret_steps` parameter to speed up the video generation process, achieving results closer to [Wan2.1](https://github.com/Wan-Video/Wan2.1). Simply add the `--use_ret_steps` parameter to the original command and adjust the `--teacache_thresh` parameter to achieve more efficient video generation. The value of the `--teacache_thresh` parameter can be referenced from the table, allowing you to choose the appropriate value based on different models and settings.
### Example Command: From 8ae350344216bc9b20fddba6a89ad0be03f6efd1 Mon Sep 17 00:00:00 2001 From: zishen-ucap Date: Thu, 13 Mar 2025 17:16:51 +0800 Subject: [PATCH 3/3] Support parameter -- use_ret_stpes --- TeaCache4Wan2.1/teacache_generate.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/TeaCache4Wan2.1/teacache_generate.py b/TeaCache4Wan2.1/teacache_generate.py index c296f0b..bafbf3a 100644 --- a/TeaCache4Wan2.1/teacache_generate.py +++ b/TeaCache4Wan2.1/teacache_generate.py @@ -879,7 +879,6 @@ def generate(args): wan_t2v.model.__class__.previous_residual_odd = None wan_t2v.model.__class__.use_ref_steps = args.use_ret_steps if args.use_ret_steps: - # wan_t2v.model.__class__.task = args.task if '1.3B' in args.ckpt_dir: wan_t2v.model.__class__.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02] if '14B' in args.ckpt_dir: @@ -964,7 +963,6 @@ def generate(args): wan_i2v.model.__class__.previous_residual_odd = None wan_i2v.model.__class__.use_ref_steps = args.use_ret_steps if args.use_ret_steps: - # wan_i2v.model.__class__.task = args.task if '480P' in args.ckpt_dir: wan_i2v.model.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01] if '720P' in args.ckpt_dir:
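A note for readers skimming the series: the caching rule that the new `teacache_forward` applies can be read independently of the Wan2.1 model code. The sketch below is a minimal standalone paraphrase of that logic, with NumPy arrays standing in for the model's timestep embeddings; the names `TeaCacheState` and `should_compute` are illustrative and do not appear in the patch. Even-numbered calls are the conditional pass and odd-numbered calls the unconditional pass; each parity keeps its own accumulated, polynomial-rescaled relative L1 distance between successive modulated inputs, and the transformer blocks are skipped (the cached residual is reused) only while that accumulator stays below `teacache_thresh` and the call index lies between `ret_steps` and `cutoff_steps`.

```python
# Illustrative sketch only -- it mirrors the even/odd skip decision added to
# teacache_forward in this series, not the patch's actual class. Assumes
# ret_steps >= 2 so each parity has seen a previous input before comparing.
import numpy as np


class TeaCacheState:
    def __init__(self, coefficients, teacache_thresh, ret_steps, cutoff_steps):
        self.rescale = np.poly1d(coefficients)  # per-model polynomial fit
        self.thresh = teacache_thresh
        self.ret_steps = ret_steps              # always compute the first calls
        self.cutoff_steps = cutoff_steps        # always compute from here on
        self.cnt = 0                            # even = cond pass, odd = uncond pass
        self.acc = {0: 0.0, 1: 0.0}             # accumulated rescaled rel-L1 per parity
        self.prev = {0: None, 1: None}          # previous modulated input per parity

    def should_compute(self, modulated_inp):
        """True if the transformer blocks must run; False means reuse the cached residual."""
        parity = self.cnt % 2
        if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
            compute = True
            self.acc[parity] = 0.0
        else:
            prev = self.prev[parity]
            rel_l1 = np.abs(modulated_inp - prev).mean() / np.abs(prev).mean()
            self.acc[parity] += self.rescale(rel_l1)
            compute = self.acc[parity] >= self.thresh
            if compute:
                self.acc[parity] = 0.0  # reset after a full recomputation
        self.prev[parity] = modulated_inp.copy()
        self.cnt += 1
        return compute
```

For the patch's t2v-14B settings with `--use_ret_steps`, this would be instantiated roughly as `TeaCacheState(coefficients=[-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], teacache_thresh=0.2, ret_steps=5*2, cutoff_steps=sample_steps*2)`. The real `teacache_forward` additionally compares `e0` instead of `e` when `use_ret_steps` is set and wraps `cnt` back to zero once `num_steps` calls have been made.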