mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-19 01:34:31 +08:00
add WanVideo TeaCache coefficients
This commit is contained in:
parent
4d8cd3daa4
commit
fa6d20eeb3
@ -689,6 +689,7 @@ except:
|
|||||||
from einops import repeat
|
from einops import repeat
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def relative_l1_distance(last_tensor, current_tensor):
|
def relative_l1_distance(last_tensor, current_tensor):
|
||||||
l1_distance = torch.abs(last_tensor - current_tensor).mean()
|
l1_distance = torch.abs(last_tensor - current_tensor).mean()
|
||||||
@ -751,8 +752,12 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
|
|||||||
cache = self.teacache_state[suffix]
|
cache = self.teacache_state[suffix]
|
||||||
|
|
||||||
if cache['prev_input'] is not None:
|
if cache['prev_input'] is not None:
|
||||||
|
if kwargs["transformer_options"]["coefficients"] == []:
|
||||||
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
|
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
|
||||||
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
|
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
|
||||||
|
else:
|
||||||
|
rescale_func = np.poly1d(kwargs["transformer_options"]["coefficients"])
|
||||||
|
curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
|
||||||
try:
|
try:
|
||||||
if curr_acc_dist < rel_l1_thresh:
|
if curr_acc_dist < rel_l1_thresh:
|
||||||
should_calc = False
|
should_calc = False
|
||||||
@ -764,7 +769,10 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
|
|||||||
should_calc = True
|
should_calc = True
|
||||||
cache['accumulated_rel_l1_distance'] = 0
|
cache['accumulated_rel_l1_distance'] = 0
|
||||||
|
|
||||||
|
if kwargs["transformer_options"]["coefficients"] == []:
|
||||||
cache['prev_input'] = e0.clone().detach()
|
cache['prev_input'] = e0.clone().detach()
|
||||||
|
else:
|
||||||
|
cache['prev_input'] = e.clone().detach()
|
||||||
|
|
||||||
if not should_calc:
|
if not should_calc:
|
||||||
x += cache['previous_residual'].to(x.device)
|
x += cache['previous_residual'].to(x.device)
|
||||||
@ -801,6 +809,7 @@ class WanVideoTeaCacheKJ:
|
|||||||
"start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}),
|
"start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}),
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}),
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}),
|
||||||
"cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
|
"cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
|
||||||
|
"coefficients": (["disabled", "1.3B", "14B", "i2v_480", "i2v_720"],),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -811,10 +820,31 @@ class WanVideoTeaCacheKJ:
|
|||||||
DESCRIPTION = "Patch WanVideo model to use TeaCache. Speeds up inference by caching the output of the model and applying it based on the input/output difference. Currently doesn't use coefficients for caching, will be imporoved in the future"
|
DESCRIPTION = "Patch WanVideo model to use TeaCache. Speeds up inference by caching the output of the model and applying it based on the input/output difference. Currently doesn't use coefficients for caching, will be imporoved in the future"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device):
|
def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device, coefficients):
|
||||||
if rel_l1_thresh == 0:
|
if rel_l1_thresh == 0:
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
|
# type_str = str(type(model.model.model_config).__name__)
|
||||||
|
if model.model.diffusion_model.dim == 1536:
|
||||||
|
model_type ="1.3B"
|
||||||
|
# else:
|
||||||
|
# if "WAN21_T2V" in type_str:
|
||||||
|
# model_type = "14B"
|
||||||
|
# elif "WAN21_I2V" in type_str:
|
||||||
|
# model_type = "i2v_480"
|
||||||
|
# else:
|
||||||
|
# model_type = "i2v_720" #how to detect this?
|
||||||
|
|
||||||
|
|
||||||
|
teacache_coefficients_map = {
|
||||||
|
"disabled": [],
|
||||||
|
"1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01],
|
||||||
|
"14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404],
|
||||||
|
"i2v_480": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01],
|
||||||
|
"i2v_720": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683],
|
||||||
|
}
|
||||||
|
coefficients = teacache_coefficients_map[coefficients]
|
||||||
|
|
||||||
teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()
|
teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()
|
||||||
|
|
||||||
model_clone = model.clone()
|
model_clone = model.clone()
|
||||||
@ -822,6 +852,7 @@ class WanVideoTeaCacheKJ:
|
|||||||
model_clone.model_options['transformer_options'] = {}
|
model_clone.model_options['transformer_options'] = {}
|
||||||
model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
|
model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
|
||||||
model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device
|
model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device
|
||||||
|
model_clone.model_options["transformer_options"]["coefficients"] = coefficients
|
||||||
diffusion_model = model_clone.get_model_object("diffusion_model")
|
diffusion_model = model_clone.get_model_object("diffusion_model")
|
||||||
|
|
||||||
def outer_wrapper(start_percent, end_percent):
|
def outer_wrapper(start_percent, end_percent):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user