add WanVideo TeaCache coefficients

kijai 2025-03-05 14:03:05 +02:00
parent 4d8cd3daa4
commit fa6d20eeb3


@@ -689,6 +689,7 @@ except:
from einops import repeat
from unittest.mock import patch
from contextlib import nullcontext
import numpy as np
def relative_l1_distance(last_tensor, current_tensor):
l1_distance = torch.abs(last_tensor - current_tensor).mean()
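For context, only the first line of relative_l1_distance is visible in this hunk. A minimal standalone sketch of the metric, assuming the same normalization that the new polynomial branch below applies inline (division by the mean absolute value of the cached tensor):

import torch

def relative_l1_distance_sketch(last_tensor, current_tensor):
    # Mean absolute difference between the cached and the current embedding
    l1_distance = torch.abs(last_tensor - current_tensor).mean()
    # Normalize by the magnitude of the cached tensor (assumed; it mirrors the
    # inline expression (e - prev).abs().mean() / prev.abs().mean() used further down)
    return (l1_distance / torch.abs(last_tensor).mean()).item()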
@@ -751,8 +752,12 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
cache = self.teacache_state[suffix]
if cache['prev_input'] is not None:
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
if kwargs["transformer_options"]["coefficients"] == []:
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
else:
rescale_func = np.poly1d(kwargs["transformer_options"]["coefficients"])
curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
try:
if curr_acc_dist < rel_l1_thresh:
should_calc = False
@@ -764,7 +769,10 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
should_calc = True
cache['accumulated_rel_l1_distance'] = 0
cache['prev_input'] = e0.clone().detach()
if kwargs["transformer_options"]["coefficients"] == []:
cache['prev_input'] = e0.clone().detach()
else:
cache['prev_input'] = e.clone().detach()
if not should_calc:
x += cache['previous_residual'].to(x.device)
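The new branch feeds the raw relative L1 distance through a fitted polynomial (np.poly1d over the per-model coefficients added below) before accumulating it against rel_l1_thresh. A minimal sketch of that accumulate-and-compare step, using the 1.3B coefficients from this commit and placeholder tensors:

import numpy as np
import torch

# Coefficients added for the 1.3B model later in this commit
coeffs = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
rescale_func = np.poly1d(coeffs)

prev, curr = torch.randn(1, 1536), torch.randn(1, 1536)  # stand-ins for cached/current embeddings
raw = ((curr - prev).abs().mean() / prev.abs().mean()).cpu().item()

accumulated = 0.0
accumulated += rescale_func(raw)

rel_l1_thresh = 0.25  # example threshold from the node input
should_calc = accumulated >= rel_l1_thresh  # below threshold -> reuse the cached residual
if should_calc:
    accumulated = 0.0  # reset after a full forward pass, as in the hunk above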
@@ -801,6 +809,7 @@ class WanVideoTeaCacheKJ:
"start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}),
"cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
"coefficients": (["disabled", "1.3B", "14B", "i2v_480", "i2v_720"],),
}
}
@@ -811,10 +820,31 @@ class WanVideoTeaCacheKJ:
DESCRIPTION = "Patch WanVideo model to use TeaCache. Speeds up inference by caching the output of the model and applying it based on the input/output difference. Currently doesn't use coefficients for caching, will be improved in the future"
EXPERIMENTAL = True
def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device):
def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device, coefficients):
if rel_l1_thresh == 0:
return (model,)
# type_str = str(type(model.model.model_config).__name__)
if model.model.diffusion_model.dim == 1536:
model_type = "1.3B"
# else:
# if "WAN21_T2V" in type_str:
# model_type = "14B"
# elif "WAN21_I2V" in type_str:
# model_type = "i2v_480"
# else:
# model_type = "i2v_720" #how to detect this?
teacache_coefficients_map = {
"disabled": [],
"1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01],
"14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404],
"i2v_480": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01],
"i2v_720": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683],
}
coefficients = teacache_coefficients_map[coefficients]
teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()
model_clone = model.clone()
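The commented-out block above drafts auto-detection of the coefficient set from the model config name; a hedged sketch of that idea (the WAN21_T2V / WAN21_I2V class-name checks and the i2v_720 fallback come from the commented draft, not something this commit actually enables):

def guess_wan_coefficient_key(model):
    # Only the 1.3B check is enabled in this commit: its transformer uses dim == 1536
    if model.model.diffusion_model.dim == 1536:
        return "1.3B"
    # The remaining branches mirror the commented-out draft in this hunk
    type_str = str(type(model.model.model_config).__name__)
    if "WAN21_T2V" in type_str:
        return "14B"
    if "WAN21_I2V" in type_str:
        return "i2v_480"
    return "i2v_720"  # the original comment notes there is no reliable check for this yet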
@@ -822,6 +852,7 @@ class WanVideoTeaCacheKJ:
model_clone.model_options['transformer_options'] = {}
model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device
model_clone.model_options["transformer_options"]["coefficients"] = coefficients
diffusion_model = model_clone.get_model_object("diffusion_model")
def outer_wrapper(start_percent, end_percent):
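For reference, the options written onto the cloned model above are what teacache_wanvideo_forward_orig later reads through kwargs["transformer_options"]; a minimal sketch of that hand-off, with example values standing in for the node inputs:

import torch

transformer_options = {
    "rel_l1_thresh": 0.25,                    # example value of the node's threshold input
    "teacache_device": torch.device("cpu"),   # offload_device in the default configuration
    "coefficients": [],                       # [] == "disabled" -> raw relative-L1 branch
}

# The patched forward picks the polynomial branch only when coefficients are present
use_polynomial = transformer_options["coefficients"] != []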