diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 682674a..8d269e3 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -990,6 +990,125 @@ def relative_l1_distance(last_tensor, current_tensor): relative_l1_distance = l1_distance / norm return relative_l1_distance.to(torch.float32) +@torch.compiler.disable() +def tea_cache(self, x, e0, e, transformer_options): + #teacache for cond and uncond separately + rel_l1_thresh = transformer_options["rel_l1_thresh"] + + is_cond = True if transformer_options["cond_or_uncond"] == [0] else False + + should_calc = True + suffix = "cond" if is_cond else "uncond" + + # Init cache dict if not exists + if not hasattr(self, 'teacache_state'): + self.teacache_state = { + 'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None, + 'teacache_skipped_steps': 0, 'previous_residual': None}, + 'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None, + 'teacache_skipped_steps': 0, 'previous_residual': None} + } + logging.info("\nTeaCache: Initialized") + + cache = self.teacache_state[suffix] + + if cache['prev_input'] is not None: + if transformer_options["coefficients"] == []: + temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0) + curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1 + else: + rescale_func = np.poly1d(transformer_options["coefficients"]) + curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item()) + try: + if curr_acc_dist < rel_l1_thresh: + should_calc = False + cache['accumulated_rel_l1_distance'] = curr_acc_dist + else: + should_calc = True + cache['accumulated_rel_l1_distance'] = 0 + except: + should_calc = True + cache['accumulated_rel_l1_distance'] = 0 + + if transformer_options["coefficients"] == []: + cache['prev_input'] = e0.clone().detach() + else: + cache['prev_input'] = e.clone().detach() + + if not should_calc: + x += cache['previous_residual'].to(x.device) + cache['teacache_skipped_steps'] += 1 + #print(f"TeaCache: Skipping {suffix} step") + return should_calc, cache + +def teacache_wanvideo_vace_forward_orig(self, x, t, context, vace_context, vace_strength, clip_fea=None, freqs=None, transformer_options={}, **kwargs): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + orig_shape = list(vace_context.shape) + vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:]) + c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype) + c = c.flatten(2).transpose(1, 2) + c = list(c.split(orig_shape[0], dim=0)) + + if not transformer_options: + raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later") + + teacache_enabled = transformer_options.get("teacache_enabled", False) + if not teacache_enabled: + should_calc = True + else: + should_calc, cache = tea_cache(self, x, e0, e, transformer_options) + + if should_calc: + original_x = x.clone().detach() + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + + ii = self.vace_layers_mapping.get(i, None) + if ii is not None: + for iii in range(len(c)): + c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=original_x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x += c_skip * vace_strength[iii] + del c_skip + + if teacache_enabled: + cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"]) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x + def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) @@ -1003,69 +1122,20 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non # context context = self.text_embedding(context) - if clip_fea is not None and self.img_emb is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) - @torch.compiler.disable() - def tea_cache(x, e0, e, kwargs): - #teacache for cond and uncond separately - rel_l1_thresh = transformer_options["rel_l1_thresh"] - - is_cond = True if transformer_options["cond_or_uncond"] == [0] else False + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] - should_calc = True - suffix = "cond" if is_cond else "uncond" - - # Init cache dict if not exists - if not hasattr(self, 'teacache_state'): - self.teacache_state = { - 'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None, - 'teacache_skipped_steps': 0, 'previous_residual': None}, - 'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None, - 'teacache_skipped_steps': 0, 'previous_residual': None} - } - logging.info("\nTeaCache: Initialized") - - cache = self.teacache_state[suffix] - - if cache['prev_input'] is not None: - if transformer_options["coefficients"] == []: - temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0) - curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1 - else: - rescale_func = np.poly1d(transformer_options["coefficients"]) - curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item()) - try: - if curr_acc_dist < rel_l1_thresh: - should_calc = False - cache['accumulated_rel_l1_distance'] = curr_acc_dist - else: - should_calc = True - cache['accumulated_rel_l1_distance'] = 0 - except: - should_calc = True - cache['accumulated_rel_l1_distance'] = 0 - - if transformer_options["coefficients"] == []: - cache['prev_input'] = e0.clone().detach() - else: - cache['prev_input'] = e.clone().detach() - - if not should_calc: - x += cache['previous_residual'].to(x.device) - cache['teacache_skipped_steps'] += 1 - #print(f"TeaCache: Skipping {suffix} step") - return should_calc, cache - - if not transformer_options: - raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later") teacache_enabled = transformer_options.get("teacache_enabled", False) if not teacache_enabled: should_calc = True else: - should_calc, cache = tea_cache(x, e0, e, kwargs) + should_calc, cache = tea_cache(self, x, e0, e, transformer_options) if should_calc: original_x = x.clone().detach() @@ -1075,12 +1145,12 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) return out out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) if teacache_enabled: cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"]) @@ -1206,9 +1276,10 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC if start_percent <= current_percent <= end_percent: c["transformer_options"]["teacache_enabled"] = True + forward_function = teacache_wanvideo_vace_forward_orig if hasattr(diffusion_model, "vace_layers") else teacache_wanvideo_forward_orig context = patch.multiple( diffusion_model, - forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__) + forward_orig=forward_function.__get__(diffusion_model, diffusion_model.__class__) ) with context: