From d00082f648e2a16d0a850e3c492fdf0cdac07a91 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 3 Mar 2025 15:23:53 +0200
Subject: [PATCH] Add WanVideoTeaCache

---
 __init__.py                       |   1 +
 nodes/model_optimization_nodes.py | 203 +++++++++++++++++++++++++++++-
 2 files changed, 203 insertions(+), 1 deletion(-)

diff --git a/__init__.py b/__init__.py
index b9a875f..af20c95 100644
--- a/__init__.py
+++ b/__init__.py
@@ -183,6 +183,7 @@ NODE_CONFIG = {
     "ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
     "ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
     "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
+    "WanVideoTeaCache": {"class": WanVideoTeaCache, "name": "WanVideo Tea Cache"},
 
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 5f6f3df..e566d4c 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -1,5 +1,5 @@
 from comfy.ldm.modules import attention as comfy_attention
-
+import logging
 import comfy.model_patcher
 import comfy.utils
 import comfy.sd
@@ -192,6 +192,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
         model_options = {}
         if dtype := DTYPE_MAP.get(weight_dtype):
             model_options["dtype"] = dtype
+            print(f"Setting {model_name} weight dtype to {dtype}")
 
         if weight_dtype == "fp8_e4m3fn_fast":
             model_options["dtype"] = torch.float8_e4m3fn
@@ -211,6 +212,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
         if dtype := DTYPE_MAP.get(compute_dtype):
             model.set_model_compute_dtype(dtype)
             model.force_cast_weights = False
+            print(f"Setting {model_name} compute dtype to {dtype}")
 
         self._patch_modules(patch_cublaslinear, sage_attention)
         return (model,)
@@ -676,3 +678,202 @@ class TorchCompileCosmosModel:
             raise RuntimeError("Failed to compile model")
 
         return (m, )
+
+
+#teacache
+
+from comfy.ldm.wan.model import sinusoidal_embedding_1d
+from einops import repeat
+from unittest.mock import patch
+from contextlib import nullcontext
+
+def relative_l1_distance(last_tensor, current_tensor):
+    l1_distance = torch.abs(last_tensor - current_tensor).mean()
+    norm = torch.abs(last_tensor).mean()
+    relative_l1_distance = l1_distance / norm
+    return relative_l1_distance.to(torch.float32)
+
+#wrapper needed for now, as there doesn't seem to be a way to pass transformer_options through to forward_orig otherwise
+def teacache_wanvideo_forward(self, x, timestep, context, clip_fea=None, **kwargs):
+    bs, c, t, h, w = x.shape
+    x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+    patch_size = self.patch_size
+    t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
+    h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
+    w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+    img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
+    img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
+    img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
+    img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
+    img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+
+    freqs = self.rope_embedder(img_ids).movedim(1, 2)
+    return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, **kwargs)[:, :, :t, :h, :w]
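+
+# Skip heuristic used by teacache_wanvideo_forward_orig below: the relative L1
+# change of the time-modulation input e0 is accumulated across steps, the
+# transformer blocks are skipped (and the cached residual reused) while the
+# running total stays under rel_l1_thresh, and the accumulator resets to 0
+# whenever the blocks actually run.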
+
+def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
+    # embeddings
+    x = self.patch_embedding(x.float()).to(x.dtype)
+    grid_sizes = x.shape[2:]
+    x = x.flatten(2).transpose(1, 2)
+
+    # time embeddings
+    e = self.time_embedding(
+        sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+    e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+    # context
+    context = self.text_embedding(context)
+    if clip_fea is not None and self.img_emb is not None:
+        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+        context = torch.concat([context_clip, context], dim=1)
+
+    #teacache, tracked for cond and uncond separately
+    rel_l1_thresh = kwargs["transformer_options"]["rel_l1_thresh"]
+    cache_device = kwargs["transformer_options"]["teacache_device"]
+    is_cond = kwargs["transformer_options"]["cond_or_uncond"] == [0]
+
+    should_calc = True
+    suffix = "cond" if is_cond else "uncond"
+
+    # init the cache dict if it doesn't exist yet
+    if not hasattr(self, 'teacache_state'):
+        self.teacache_state = {
+            'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
+                     'teacache_skipped_steps': 0, 'previous_residual': None},
+            'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
+                       'teacache_skipped_steps': 0, 'previous_residual': None}
+        }
+        logging.info("TeaCache: Initialized")
+
+    cache = self.teacache_state[suffix]
+
+    if cache['prev_input'] is not None:
+        temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
+        curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
+        if curr_acc_dist < rel_l1_thresh:
+            should_calc = False
+            cache['accumulated_rel_l1_distance'] = curr_acc_dist
+        else:
+            should_calc = True
+            cache['accumulated_rel_l1_distance'] = 0
+
+    cache['prev_input'] = e0.clone().detach()
+
+    if not should_calc:
+        # reuse the cached residual instead of running the transformer blocks
+        x += cache['previous_residual'].to(x.device)
+        cache['teacache_skipped_steps'] += 1
+        print(f"TeaCache: Skipping {suffix} step")
+    else:
+        original_x = x.clone().detach()
+        # arguments
+        block_kwargs = dict(
+            e=e0,
+            freqs=freqs,
+            context=context)
+
+        for block in self.blocks:
+            x = block(x, **block_kwargs)
+
+        cache['previous_residual'] = (x - original_x).to(cache_device)
+
+    # head
+    x = self.head(x, e)
+
+    # unpatchify
+    x = self.unpatchify(x, grid_sizes)
+    return x
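+
+# The node below wraps ComfyUI's sampling function: it derives the current
+# step index by matching the incoming timestep against sample_sigmas, patches
+# forward/forward_orig only while the step falls inside the
+# start_percent..end_percent window, and clears teacache_state at step 0 so
+# each sampling run starts fresh.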
+
+class WanVideoTeaCache:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "rel_l1_thresh": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Threshold used to determine when to apply the cache, a compromise between speed and accuracy"}),
+                "start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use TeaCache with."}),
+                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use TeaCache with."}),
+                "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    RETURN_NAMES = ("model",)
+    FUNCTION = "patch_teacache"
+    CATEGORY = "KJNodes/teacache"
+    DESCRIPTION = "Patch the WanVideo model to use TeaCache. Speeds up inference by caching the output of the transformer blocks and reusing it when the input/output difference is small. Currently doesn't use coefficients for caching; this will be improved in the future"
+    EXPERIMENTAL = True
+
+    def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device):
+        if rel_l1_thresh == 0:
+            return (model,)
+
+        teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()
+
+        model_clone = model.clone()
+        if 'transformer_options' not in model_clone.model_options:
+            model_clone.model_options['transformer_options'] = {}
+        model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
+        model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device
+        diffusion_model = model_clone.get_model_object("diffusion_model")
+
+        def outer_wrapper(start_percent, end_percent):
+            def unet_wrapper_function(model_function, kwargs):
+                input = kwargs["input"]
+                timestep = kwargs["timestep"]
+                c = kwargs["c"]
+                sigmas = c["transformer_options"]["sample_sigmas"]
+                cond_or_uncond = kwargs["cond_or_uncond"]
+                last_step = (len(sigmas) - 1)
+
+                matched_step_index = (sigmas == timestep[0]).nonzero()
+                if len(matched_step_index) > 0:
+                    current_step_index = matched_step_index.item()
+                else:
+                    for i in range(len(sigmas) - 1):
+                        # walk from the beginning of the steps until crossing the timestep
+                        if (sigmas[i] - timestep) * (sigmas[i + 1] - timestep) <= 0:
+                            current_step_index = i
+                            break
+                    else:
+                        current_step_index = 0
+
+                if current_step_index == 0:
+                    if hasattr(diffusion_model, "teacache_state"):
+                        delattr(diffusion_model, "teacache_state")
+                        logging.info("Resetting TeaCache state")
+
+                current_percent = current_step_index / (len(sigmas) - 1)
+                if start_percent <= current_percent <= end_percent:
+                    c["transformer_options"]["teacache_enabled"] = True
+                    context = patch.multiple(
+                        diffusion_model,
+                        forward=teacache_wanvideo_forward.__get__(diffusion_model, diffusion_model.__class__),
+                        forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
+                    )
+                else:
+                    context = nullcontext()
+
+                with context:
+                    out = model_function(input, timestep, **c)
+
+                if current_step_index + 1 == last_step and hasattr(diffusion_model, "teacache_state"):
+                    if cond_or_uncond[0] == 0:
+                        skipped_steps_cond = diffusion_model.teacache_state["cond"]["teacache_skipped_steps"]
+                        skipped_steps_uncond = diffusion_model.teacache_state["uncond"]["teacache_skipped_steps"]
+
+                        logging.info("-----------------------------------")
+                        logging.info("TeaCache skipped:")
+                        logging.info(f"{skipped_steps_cond} cond steps")
+                        logging.info(f"{skipped_steps_uncond} uncond steps")
+                        logging.info(f"out of {last_step} steps")
+                        logging.info("-----------------------------------")
+                return out
+            return unet_wrapper_function
+
+        model_clone.set_model_unet_function_wrapper(outer_wrapper(start_percent=start_percent, end_percent=end_percent))
+
+        return (model_clone,)
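
As a quick illustration of the skip heuristic this patch implements, here is a
self-contained sketch (not part of the commit): the 0.03 threshold matches the
node's default, while the toy e0 tensors and step count are made up purely for
demonstration.

    import torch

    def relative_l1_distance(last_tensor, current_tensor):
        l1 = torch.abs(last_tensor - current_tensor).mean()
        return (l1 / torch.abs(last_tensor).mean()).to(torch.float32)

    rel_l1_thresh = 0.03           # the node's default threshold
    acc, prev, skipped = 0.0, None, 0
    torch.manual_seed(0)
    for step in range(10):
        # stand-in for the time-modulation input e0; shape is arbitrary here
        e0 = torch.ones(2, 6, 16) + 0.01 * torch.randn(2, 6, 16)
        should_calc = True
        if prev is not None:
            acc = acc + relative_l1_distance(prev, e0)
            if acc < rel_l1_thresh:
                should_calc = False   # would reuse the cached residual
            else:
                acc = 0.0             # blocks would run; reset the accumulator
        prev = e0.clone()
        if not should_calc:
            skipped += 1
    print(f"skipped {skipped} of 10 steps")

With slowly drifting inputs like these, the per-step relative L1 distance is
around 0.011, so the accumulator crosses the threshold roughly every third
step and about two thirds of the steps end up skipped.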