mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-16 08:14:32 +08:00
add WanVideo TeaCache coefficients
This commit is contained in:
parent
4d8cd3daa4
commit
fa6d20eeb3
@ -689,6 +689,7 @@ except:
|
||||
from einops import repeat
|
||||
from unittest.mock import patch
|
||||
from contextlib import nullcontext
|
||||
import numpy as np
|
||||
|
||||
def relative_l1_distance(last_tensor, current_tensor):
|
||||
l1_distance = torch.abs(last_tensor - current_tensor).mean()
|
||||
@ -751,8 +752,12 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
|
||||
cache = self.teacache_state[suffix]
|
||||
|
||||
if cache['prev_input'] is not None:
|
||||
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
|
||||
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
|
||||
if kwargs["transformer_options"]["coefficients"] == []:
|
||||
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
|
||||
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
|
||||
else:
|
||||
rescale_func = np.poly1d(kwargs["transformer_options"]["coefficients"])
|
||||
curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
|
||||
try:
|
||||
if curr_acc_dist < rel_l1_thresh:
|
||||
should_calc = False
|
||||
@ -764,7 +769,10 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
|
||||
should_calc = True
|
||||
cache['accumulated_rel_l1_distance'] = 0
|
||||
|
||||
cache['prev_input'] = e0.clone().detach()
|
||||
if kwargs["transformer_options"]["coefficients"] == []:
|
||||
cache['prev_input'] = e0.clone().detach()
|
||||
else:
|
||||
cache['prev_input'] = e.clone().detach()
|
||||
|
||||
if not should_calc:
|
||||
x += cache['previous_residual'].to(x.device)
|
||||
@ -801,6 +809,7 @@ class WanVideoTeaCacheKJ:
|
||||
"start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}),
|
||||
"cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
|
||||
"coefficients": (["disabled", "1.3B", "14B", "i2v_480", "i2v_720"],),
|
||||
}
|
||||
}
|
||||
|
||||
@ -811,10 +820,31 @@ class WanVideoTeaCacheKJ:
|
||||
DESCRIPTION = "Patch WanVideo model to use TeaCache. Speeds up inference by caching the output of the model and applying it based on the input/output difference. Currently doesn't use coefficients for caching, will be imporoved in the future"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device):
|
||||
def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device, coefficients):
|
||||
if rel_l1_thresh == 0:
|
||||
return (model,)
|
||||
|
||||
# type_str = str(type(model.model.model_config).__name__)
|
||||
if model.model.diffusion_model.dim == 1536:
|
||||
model_type ="1.3B"
|
||||
# else:
|
||||
# if "WAN21_T2V" in type_str:
|
||||
# model_type = "14B"
|
||||
# elif "WAN21_I2V" in type_str:
|
||||
# model_type = "i2v_480"
|
||||
# else:
|
||||
# model_type = "i2v_720" #how to detect this?
|
||||
|
||||
|
||||
teacache_coefficients_map = {
|
||||
"disabled": [],
|
||||
"1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01],
|
||||
"14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404],
|
||||
"i2v_480": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01],
|
||||
"i2v_720": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683],
|
||||
}
|
||||
coefficients = teacache_coefficients_map[coefficients]
|
||||
|
||||
teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()
|
||||
|
||||
model_clone = model.clone()
|
||||
@ -822,6 +852,7 @@ class WanVideoTeaCacheKJ:
|
||||
model_clone.model_options['transformer_options'] = {}
|
||||
model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
|
||||
model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device
|
||||
model_clone.model_options["transformer_options"]["coefficients"] = coefficients
|
||||
diffusion_model = model_clone.get_model_object("diffusion_model")
|
||||
|
||||
def outer_wrapper(start_percent, end_percent):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user