Add WanVideoTeaCache

This commit is contained in:
kijai 2025-03-03 15:23:53 +02:00
parent 9a15e22f5e
commit d00082f648
2 changed files with 203 additions and 1 deletions

View File

@@ -183,6 +183,7 @@ NODE_CONFIG = {
    "ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
    "ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
    "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
    "WanVideoTeaCache": {"class": WanVideoTeaCache, "name": "WanVideo Tea Cache"},
    #instance diffusion
    "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@@ -1,5 +1,5 @@
from comfy.ldm.modules import attention as comfy_attention
import logging
import comfy.model_patcher
import comfy.utils
import comfy.sd
@@ -192,6 +192,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
        model_options = {}
        if dtype := DTYPE_MAP.get(weight_dtype):
            model_options["dtype"] = dtype
            print(f"Setting {model_name} weight dtype to {dtype}")
        if weight_dtype == "fp8_e4m3fn_fast":
            model_options["dtype"] = torch.float8_e4m3fn
@@ -211,6 +212,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
        if dtype := DTYPE_MAP.get(compute_dtype):
            model.set_model_compute_dtype(dtype)
            model.force_cast_weights = False
            print(f"Setting {model_name} compute dtype to {dtype}")
        self._patch_modules(patch_cublaslinear, sage_attention)
        return (model,)
@@ -676,3 +678,202 @@ class TorchCompileCosmosModel:
            raise RuntimeError("Failed to compile model")
        return (m, )
#teacache
from comfy.ldm.wan.model import sinusoidal_embedding_1d
from einops import repeat
from unittest.mock import patch
from contextlib import nullcontext
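# Relative L1 distance between two tensors:
#   rel_l1 = mean(|last - current|) / mean(|last|)
# TeaCache applies this to the timestep modulation input to estimate how much
# the transformer output is likely to change between consecutive steps.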
def relative_l1_distance(last_tensor, current_tensor):
    l1_distance = torch.abs(last_tensor - current_tensor).mean()
    norm = torch.abs(last_tensor).mean()
    relative_l1_distance = l1_distance / norm
    return relative_l1_distance.to(torch.float32)
# The forward is patched as well for now, since there doesn't currently seem to be
# a way to pass transformer_options through to forward_orig otherwise.
def teacache_wanvideo_forward(self, x, timestep, context, clip_fea=None, **kwargs):
    bs, c, t, h, w = x.shape
    x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
    patch_size = self.patch_size
    t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
    h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
    w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
    img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
    img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
    img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
    img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
    img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

    freqs = self.rope_embedder(img_ids).movedim(1, 2)
    return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, **kwargs)[:, :, :t, :h, :w]
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
    # embeddings
    x = self.patch_embedding(x.float()).to(x.dtype)
    grid_sizes = x.shape[2:]
    x = x.flatten(2).transpose(1, 2)

    # time embeddings
    e = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
    e0 = self.time_projection(e).unflatten(1, (6, self.dim))

    # context
    context = self.text_embedding(context)

    if clip_fea is not None and self.img_emb is not None:
        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
        context = torch.concat([context_clip, context], dim=1)

    # TeaCache keeps separate state for the cond and uncond passes
    rel_l1_thresh = kwargs["transformer_options"]["rel_l1_thresh"]
    cache_device = kwargs["transformer_options"]["teacache_device"]
    is_cond = kwargs["transformer_options"]["cond_or_uncond"] == [0]

    should_calc = True
    suffix = "cond" if is_cond else "uncond"

    # Init cache dict if it doesn't exist yet
    if not hasattr(self, 'teacache_state'):
        self.teacache_state = {
            'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
                     'teacache_skipped_steps': 0, 'previous_residual': None},
            'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
                       'teacache_skipped_steps': 0, 'previous_residual': None}
        }
        logging.info("TeaCache: Initialized")

    cache = self.teacache_state[suffix]
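    # Accumulate the relative change of e0 across steps: while the running total
    # stays below rel_l1_thresh the transformer blocks are skipped and the cached
    # residual is reused, otherwise the blocks run and the accumulator resets.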
    if cache['prev_input'] is not None:
        temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
        curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
        try:
            if curr_acc_dist < rel_l1_thresh:
                should_calc = False
                cache['accumulated_rel_l1_distance'] = curr_acc_dist
            else:
                should_calc = True
                cache['accumulated_rel_l1_distance'] = 0
        except Exception:
            should_calc = True
            cache['accumulated_rel_l1_distance'] = 0

    cache['prev_input'] = e0.clone().detach()

    if not should_calc:
        x += cache['previous_residual'].to(x.device)
        cache['teacache_skipped_steps'] += 1
        print(f"TeaCache: Skipping {suffix} step")
    if should_calc:
        original_x = x.clone().detach()
        # arguments
        block_kwargs = dict(
            e=e0,
            freqs=freqs,
            context=context)

        for block in self.blocks:
            x = block(x, **block_kwargs)

        cache['previous_residual'] = (x - original_x).to(cache_device)

    # head
    x = self.head(x, e)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return x
class WanVideoTeaCache:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "rel_l1_thresh": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Threshold used to determine when to apply the cache; a compromise between speed and accuracy"}),
                "start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}),
                "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "patch_teacache"
    CATEGORY = "KJNodes/teacache"
    DESCRIPTION = "Patches the WanVideo model to use TeaCache. Speeds up inference by caching the model output and reusing it based on the input/output difference. Currently doesn't use coefficients for caching; this will be improved in the future."
    EXPERIMENTAL = True
    def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device):
        if rel_l1_thresh == 0:
            return (model,)

        teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()

        model_clone = model.clone()
        if 'transformer_options' not in model_clone.model_options:
            model_clone.model_options['transformer_options'] = {}
        model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
        model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device
        diffusion_model = model_clone.get_model_object("diffusion_model")
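        # The unet wrapper runs for every cond/uncond model call: it works out the
        # current sampling step from the sigma schedule and patches the TeaCache
        # forwards only while that step is inside the start/end window.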
        def outer_wrapper(start_percent, end_percent):
            def unet_wrapper_function(model_function, kwargs):
                input = kwargs["input"]
                timestep = kwargs["timestep"]
                c = kwargs["c"]
                sigmas = c["transformer_options"]["sample_sigmas"]
                cond_or_uncond = kwargs["cond_or_uncond"]
                last_step = (len(sigmas) - 1)
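                # Match the timestep against the sigma schedule to find the current step index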
                matched_step_index = (sigmas == timestep[0]).nonzero()
                if len(matched_step_index) > 0:
                    current_step_index = matched_step_index.item()
                else:
                    for i in range(len(sigmas) - 1):
                        # walk from the beginning of the steps until crossing the timestep
                        if (sigmas[i] - timestep) * (sigmas[i + 1] - timestep) <= 0:
                            current_step_index = i
                            break
                    else:
                        current_step_index = 0

                if current_step_index == 0:
                    if hasattr(diffusion_model, "teacache_state"):
                        delattr(diffusion_model, "teacache_state")
                        logging.info("Resetting TeaCache state")

                current_percent = current_step_index / (len(sigmas) - 1)
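                # Only patch the TeaCache forwards while the current step is inside the user-set window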
                if start_percent <= current_percent <= end_percent:
                    c["transformer_options"]["teacache_enabled"] = True
                    context = patch.multiple(
                        diffusion_model,
                        forward=teacache_wanvideo_forward.__get__(diffusion_model, diffusion_model.__class__),
                        forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
                    )
                else:
                    context = nullcontext()

                with context:
                    out = model_function(input, timestep, **c)
                if current_step_index + 1 == last_step and hasattr(diffusion_model, "teacache_state"):
                    if cond_or_uncond[0] == 0:
                        skipped_steps_cond = diffusion_model.teacache_state["cond"]["teacache_skipped_steps"]
                        skipped_steps_uncond = diffusion_model.teacache_state["uncond"]["teacache_skipped_steps"]
                        logging.info("-----------------------------------")
                        logging.info("TeaCache skipped:")
                        logging.info(f"{skipped_steps_cond} cond steps")
                        logging.info(f"{skipped_steps_uncond} uncond steps")
                        logging.info(f"out of {last_step} steps")
                        logging.info("-----------------------------------")

                return out
            return unet_wrapper_function

        model_clone.set_model_unet_function_wrapper(outer_wrapper(start_percent=start_percent, end_percent=end_percent))
        return (model_clone,)