diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index beeb91b..801639a 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -689,6 +689,7 @@ except: from einops import repeat from unittest.mock import patch from contextlib import nullcontext +import numpy as np def relative_l1_distance(last_tensor, current_tensor): l1_distance = torch.abs(last_tensor - current_tensor).mean() @@ -751,8 +752,12 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non cache = self.teacache_state[suffix] if cache['prev_input'] is not None: - temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0) - curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1 + if kwargs["transformer_options"]["coefficients"] == []: + temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0) + curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1 + else: + rescale_func = np.poly1d(kwargs["transformer_options"]["coefficients"]) + curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item()) try: if curr_acc_dist < rel_l1_thresh: should_calc = False @@ -764,7 +769,10 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non should_calc = True cache['accumulated_rel_l1_distance'] = 0 - cache['prev_input'] = e0.clone().detach() + if kwargs["transformer_options"]["coefficients"] == []: + cache['prev_input'] = e0.clone().detach() + else: + cache['prev_input'] = e.clone().detach() if not should_calc: x += cache['previous_residual'].to(x.device) @@ -801,6 +809,7 @@ class WanVideoTeaCacheKJ: "start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}), "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}), + "coefficients": (["disabled", "1.3B", "14B", "i2v_480", "i2v_720"],), } } @@ -811,10 +820,31 @@ class WanVideoTeaCacheKJ: DESCRIPTION = "Patch WanVideo model to use TeaCache. Speeds up inference by caching the output of the model and applying it based on the input/output difference. Currently doesn't use coefficients for caching, will be imporoved in the future" EXPERIMENTAL = True - def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device): + def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device, coefficients): if rel_l1_thresh == 0: return (model,) + # type_str = str(type(model.model.model_config).__name__) + if model.model.diffusion_model.dim == 1536: + model_type ="1.3B" + # else: + # if "WAN21_T2V" in type_str: + # model_type = "14B" + # elif "WAN21_I2V" in type_str: + # model_type = "i2v_480" + # else: + # model_type = "i2v_720" #how to detect this? + + + teacache_coefficients_map = { + "disabled": [], + "1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01], + "14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404], + "i2v_480": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01], + "i2v_720": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], + } + coefficients = teacache_coefficients_map[coefficients] + teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device() model_clone = model.clone() @@ -822,6 +852,7 @@ class WanVideoTeaCacheKJ: model_clone.model_options['transformer_options'] = {} model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device + model_clone.model_options["transformer_options"]["coefficients"] = coefficients diffusion_model = model_clone.get_model_object("diffusion_model") def outer_wrapper(start_percent, end_percent):