mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-27 17:39:08 +08:00
Add WanVideoTeaCache
This commit is contained in:
parent
9a15e22f5e
commit
d00082f648
@ -183,6 +183,7 @@ NODE_CONFIG = {
|
|||||||
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
||||||
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
|
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
|
||||||
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
|
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
|
||||||
|
"WanVideoTeaCache": {"class": WanVideoTeaCache, "name": "WanVideo Tea Cache"},
|
||||||
|
|
||||||
#instance diffusion
|
#instance diffusion
|
||||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from comfy.ldm.modules import attention as comfy_attention
|
from comfy.ldm.modules import attention as comfy_attention
|
||||||
|
import logging
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
@ -192,6 +192,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
|||||||
model_options = {}
|
model_options = {}
|
||||||
if dtype := DTYPE_MAP.get(weight_dtype):
|
if dtype := DTYPE_MAP.get(weight_dtype):
|
||||||
model_options["dtype"] = dtype
|
model_options["dtype"] = dtype
|
||||||
|
print(f"Setting {model_name} weight dtype to {dtype}")
|
||||||
|
|
||||||
if weight_dtype == "fp8_e4m3fn_fast":
|
if weight_dtype == "fp8_e4m3fn_fast":
|
||||||
model_options["dtype"] = torch.float8_e4m3fn
|
model_options["dtype"] = torch.float8_e4m3fn
|
||||||
@ -211,6 +212,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
|||||||
if dtype := DTYPE_MAP.get(compute_dtype):
|
if dtype := DTYPE_MAP.get(compute_dtype):
|
||||||
model.set_model_compute_dtype(dtype)
|
model.set_model_compute_dtype(dtype)
|
||||||
model.force_cast_weights = False
|
model.force_cast_weights = False
|
||||||
|
print(f"Setting {model_name} compute dtype to {dtype}")
|
||||||
self._patch_modules(patch_cublaslinear, sage_attention)
|
self._patch_modules(patch_cublaslinear, sage_attention)
|
||||||
|
|
||||||
return (model,)
|
return (model,)
|
||||||
@ -676,3 +678,202 @@ class TorchCompileCosmosModel:
|
|||||||
raise RuntimeError("Failed to compile model")
|
raise RuntimeError("Failed to compile model")
|
||||||
|
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
|
#teacache
|
||||||
|
|
||||||
|
from comfy.ldm.wan.model import sinusoidal_embedding_1d
|
||||||
|
from einops import repeat
|
||||||
|
from unittest.mock import patch
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
def relative_l1_distance(last_tensor, current_tensor):
    """Return the mean absolute change between two tensors, relative to the first.

    Computes ``mean(|last - current|) / mean(|last|)`` and returns it as a
    float32 tensor.  No guard is applied for a zero denominator, so an
    all-zero ``last_tensor`` produces ``inf`` or ``nan``.
    """
    mean_abs_change = (last_tensor - current_tensor).abs().mean()
    mean_abs_reference = last_tensor.abs().mean()
    return (mean_abs_change / mean_abs_reference).to(torch.float32)
|
||||||
|
|
||||||
|
#for now as there doesn't seem to be a way to pass transformer_options to the forward_orig currently
|
||||||
|
def teacache_wanvideo_forward(self, x, timestep, context, clip_fea=None, **kwargs):
    """Build rotary position frequencies for ``x`` and delegate to ``self.forward_orig``.

    The input is padded to a multiple of ``self.patch_size``, a (t, h, w)
    position-id grid is generated and fed through ``self.rope_embedder``, and
    ``forward_orig`` is called with every extra keyword argument forwarded
    unchanged.  The result is cropped back to the un-padded temporal and
    spatial extent of the input.
    """
    batch, _, frames, height, width = x.shape
    x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

    # Patch count per axis: half a patch is added before the floor division.
    grid = [
        (extent + (patch // 2)) // patch
        for extent, patch in zip((frames, height, width), self.patch_size)
    ]

    # Channel k of the id grid holds the patch index along axis k.
    position_ids = torch.zeros((*grid, 3), device=x.device, dtype=x.dtype)
    for axis, length in enumerate(grid):
        broadcast_shape = [1, 1, 1]
        broadcast_shape[axis] = -1
        ramp = torch.linspace(0, length - 1, steps=length, device=x.device, dtype=x.dtype)
        position_ids[..., axis] += ramp.reshape(broadcast_shape)
    position_ids = repeat(position_ids, "t h w c -> b (t h w) c", b=batch)

    freqs = self.rope_embedder(position_ids).movedim(1, 2)
    out = self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, **kwargs)
    return out[:, :, :frames, :height, :width]
|
||||||
|
|
||||||
|
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
    """Run the transformer, or reuse the cached block residual when the timestep
    embedding has barely changed since the last full computation.

    Reads from ``kwargs["transformer_options"]``:
        rel_l1_thresh:    accumulated relative-L1 distance below which the
                          transformer blocks are skipped.
        teacache_device:  device the cached residual is stored on.
        cond_or_uncond:   ``[0]`` selects the "cond" cache slot; anything else
                          selects the "uncond" slot.

    Cache state lives on ``self.teacache_state`` and is created lazily here.

    Returns:
        The output of ``self.unpatchify`` applied to the head output.
    """
    # embeddings
    x = self.patch_embedding(x.float()).to(x.dtype)
    grid_sizes = x.shape[2:]
    x = x.flatten(2).transpose(1, 2)

    # time embeddings
    e = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
    e0 = self.time_projection(e).unflatten(1, (6, self.dim))

    # context
    context = self.text_embedding(context)
    if clip_fea is not None and self.img_emb is not None:
        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
        context = torch.concat([context_clip, context], dim=1)

    # teacache for cond and uncond separately
    transformer_options = kwargs["transformer_options"]
    rel_l1_thresh = transformer_options["rel_l1_thresh"]
    cache_device = transformer_options["teacache_device"]
    suffix = "cond" if transformer_options["cond_or_uncond"] == [0] else "uncond"

    # Init cache dict if not exists
    if not hasattr(self, 'teacache_state'):
        self.teacache_state = {
            'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
                     'teacache_skipped_steps': 0, 'previous_residual': None},
            'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
                       'teacache_skipped_steps': 0, 'previous_residual': None}
        }
        logging.info("TeaCache: Initialized")

    cache = self.teacache_state[suffix]

    should_calc = True
    if cache['prev_input'] is not None:
        try:
            # The distance computation is inside the try so that a shape or
            # device mismatch against the cached embedding falls back to a
            # full computation instead of aborting sampling.
            temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
            curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
            if curr_acc_dist < rel_l1_thresh:
                should_calc = False
                cache['accumulated_rel_l1_distance'] = curr_acc_dist
            else:
                should_calc = True
                cache['accumulated_rel_l1_distance'] = 0
        except Exception:
            # Best effort: any failure while deciding means "recompute".
            # Unlike a bare except, KeyboardInterrupt/SystemExit still propagate.
            logging.warning("TeaCache: could not evaluate skip condition, recomputing", exc_info=True)
            should_calc = True
            cache['accumulated_rel_l1_distance'] = 0

    cache['prev_input'] = e0.clone().detach()

    # A skip is only possible when a residual of the right shape is cached.
    previous_residual = cache['previous_residual']
    if not should_calc and (previous_residual is None or previous_residual.shape != x.shape):
        should_calc = True
        cache['accumulated_rel_l1_distance'] = 0

    if should_calc:
        original_x = x.clone().detach()
        # arguments
        block_wargs = dict(
            e=e0,
            freqs=freqs,
            context=context)

        for block in self.blocks:
            x = block(x, **block_wargs)

        # Store what the blocks added so a later step can replay it.
        cache['previous_residual'] = (x - original_x).to(cache_device)
    else:
        x += previous_residual.to(x.device)
        cache['teacache_skipped_steps'] += 1
        print(f"TeaCache: Skipping {suffix} step")

    # head
    x = self.head(x, e)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return x
|
||||||
|
|
||||||
|
class WanVideoTeaCache:
    """Node that returns a model clone whose diffusion model uses the TeaCache
    forward functions for a configurable range of sampling steps."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "rel_l1_thresh": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Threshold for to determine when to apply the cache, compromise between speed and accuracy"}),
                "start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}),
                "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "patch_teacache"
    CATEGORY = "KJNodes/teacache"
    DESCRIPTION = "Patch WanVideo model to use TeaCache. Speeds up inference by caching the output of the model and applying it based on the input/output difference. Currently doesn't use coefficients for caching, will be imporoved in the future"
    EXPERIMENTAL = True

    @staticmethod
    def _find_step_index(sigmas, timestep):
        """Return the index in ``sigmas`` that corresponds to ``timestep[0]``.

        An exact match wins; if the same value occurs more than once, the first
        occurrence is used.  Otherwise the schedule is walked from the start
        and the index of the first interval that brackets the timestep is
        returned.  When nothing brackets it, 0 is returned.
        """
        # Only the first batch element is used, both for the exact match and
        # for the interval walk, so batched timesteps do not break the `if`.
        current = timestep[0]
        matches = (sigmas == current).nonzero()
        if len(matches) > 0:
            return matches[0].item()
        for i in range(len(sigmas) - 1):
            # walk from beginning of steps until crossing the timestep
            if (sigmas[i] - current) * (sigmas[i + 1] - current) <= 0:
                return i
        return 0

    def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device):
        """Clone ``model`` and install a unet function wrapper that enables TeaCache.

        Args:
            model: object providing ``clone()``, ``model_options``,
                ``get_model_object`` and ``set_model_unet_function_wrapper``.
            rel_l1_thresh: skip threshold; 0 returns the model unchanged.
            start_percent, end_percent: inclusive fraction-of-steps window in
                which the TeaCache forward functions are patched in.
            cache_device: "main_device" or "offload_device".

        Returns:
            A one-element tuple containing the (possibly patched) model.
        """
        if rel_l1_thresh == 0:
            return (model,)

        teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()

        model_clone = model.clone()
        transformer_options = model_clone.model_options.setdefault('transformer_options', {})
        transformer_options["rel_l1_thresh"] = rel_l1_thresh
        transformer_options["teacache_device"] = teacache_device
        diffusion_model = model_clone.get_model_object("diffusion_model")
        find_step_index = self._find_step_index

        def unet_wrapper_function(model_function, kwargs):
            latent = kwargs["input"]
            timestep = kwargs["timestep"]
            c = kwargs["c"]
            sigmas = c["transformer_options"]["sample_sigmas"]
            cond_or_uncond = kwargs["cond_or_uncond"]
            last_step = len(sigmas) - 1

            current_step_index = find_step_index(sigmas, timestep)

            # A new sampling run starts at step 0: drop state from the previous run.
            if current_step_index == 0 and hasattr(diffusion_model, "teacache_state"):
                delattr(diffusion_model, "teacache_state")
                logging.info("Resetting TeaCache state")

            # max() guards against a single-entry sigma schedule.
            current_percent = current_step_index / max(last_step, 1)
            if start_percent <= current_percent <= end_percent:
                c["transformer_options"]["teacache_enabled"] = True
                # Swap in the TeaCache forwards only for the duration of this call.
                context = patch.multiple(
                    diffusion_model,
                    forward=teacache_wanvideo_forward.__get__(diffusion_model, diffusion_model.__class__),
                    forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
                )
            else:
                context = nullcontext()

            with context:
                out = model_function(latent, timestep, **c)

            # Summary is reported once, when current_step_index + 1 equals last_step.
            if current_step_index + 1 == last_step and hasattr(diffusion_model, "teacache_state"):
                if cond_or_uncond[0] == 0:
                    skipped_steps_cond = diffusion_model.teacache_state["cond"]["teacache_skipped_steps"]
                    skipped_steps_uncond = diffusion_model.teacache_state["uncond"]["teacache_skipped_steps"]

                    logging.info("-----------------------------------")
                    logging.info("TeaCache skipped:")
                    logging.info(f"{skipped_steps_cond} cond steps")
                    logging.info(f"{skipped_steps_uncond} uncond step")
                    logging.info(f"out of {last_step} steps")
                    logging.info("-----------------------------------")
            return out

        model_clone.set_model_unet_function_wrapper(unet_wrapper_function)

        return (model_clone,)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user