mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-23 13:14:37 +08:00
Add WanVideoTeaCache
This commit is contained in:
parent
9a15e22f5e
commit
d00082f648
@ -183,6 +183,7 @@ NODE_CONFIG = {
|
||||
"ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
|
||||
"ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
|
||||
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
|
||||
"WanVideoTeaCache": {"class": WanVideoTeaCache, "name": "WanVideo Tea Cache"},
|
||||
|
||||
#instance diffusion
|
||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from comfy.ldm.modules import attention as comfy_attention
|
||||
|
||||
import logging
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
@ -192,6 +192,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
model_options = {}
|
||||
if dtype := DTYPE_MAP.get(weight_dtype):
|
||||
model_options["dtype"] = dtype
|
||||
print(f"Setting {model_name} weight dtype to {dtype}")
|
||||
|
||||
if weight_dtype == "fp8_e4m3fn_fast":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
@ -211,6 +212,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
if dtype := DTYPE_MAP.get(compute_dtype):
|
||||
model.set_model_compute_dtype(dtype)
|
||||
model.force_cast_weights = False
|
||||
print(f"Setting {model_name} compute dtype to {dtype}")
|
||||
self._patch_modules(patch_cublaslinear, sage_attention)
|
||||
|
||||
return (model,)
|
||||
@ -676,3 +678,202 @@ class TorchCompileCosmosModel:
|
||||
raise RuntimeError("Failed to compile model")
|
||||
|
||||
return (m, )
|
||||
|
||||
|
||||
#teacache
|
||||
|
||||
from comfy.ldm.wan.model import sinusoidal_embedding_1d
|
||||
from einops import repeat
|
||||
from unittest.mock import patch
|
||||
from contextlib import nullcontext
|
||||
|
||||
def relative_l1_distance(last_tensor, current_tensor):
|
||||
l1_distance = torch.abs(last_tensor - current_tensor).mean()
|
||||
norm = torch.abs(last_tensor).mean()
|
||||
relative_l1_distance = l1_distance / norm
|
||||
return relative_l1_distance.to(torch.float32)
|
||||
|
||||
#for now as there doesn't seem to be a way to pass transformer_options to the forward_orig currently
|
||||
def teacache_wanvideo_forward(self, x, timestep, context, clip_fea=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
if clip_fea is not None and self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
|
||||
#teacache for cond and uncond separately
|
||||
rel_l1_thresh = kwargs["transformer_options"]["rel_l1_thresh"]
|
||||
cache_device = kwargs["transformer_options"]["teacache_device"]
|
||||
is_cond = True if kwargs["transformer_options"]["cond_or_uncond"] == [0] else False
|
||||
|
||||
should_calc = True
|
||||
suffix = "cond" if is_cond else "uncond"
|
||||
|
||||
# Init cache dict if not exists
|
||||
if not hasattr(self, 'teacache_state'):
|
||||
self.teacache_state = {
|
||||
'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
|
||||
'teacache_skipped_steps': 0, 'previous_residual': None},
|
||||
'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
|
||||
'teacache_skipped_steps': 0, 'previous_residual': None}
|
||||
}
|
||||
logging.info("TeaCache: Initialized")
|
||||
|
||||
cache = self.teacache_state[suffix]
|
||||
|
||||
if cache['prev_input'] is not None:
|
||||
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
|
||||
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
|
||||
try:
|
||||
if curr_acc_dist < rel_l1_thresh:
|
||||
should_calc = False
|
||||
cache['accumulated_rel_l1_distance'] = curr_acc_dist
|
||||
else:
|
||||
should_calc = True
|
||||
cache['accumulated_rel_l1_distance'] = 0
|
||||
except:
|
||||
should_calc = True
|
||||
cache['accumulated_rel_l1_distance'] = 0
|
||||
|
||||
cache['prev_input'] = e0.clone().detach()
|
||||
|
||||
if not should_calc:
|
||||
x += cache['previous_residual'].to(x.device)
|
||||
cache['teacache_skipped_steps'] += 1
|
||||
print(f"TeaCache: Skipping {suffix} step")
|
||||
|
||||
if should_calc:
|
||||
original_x = x.clone().detach()
|
||||
# arguments
|
||||
block_wargs = dict(
|
||||
e=e0,
|
||||
freqs=freqs,
|
||||
context=context)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **block_wargs)
|
||||
|
||||
cache['previous_residual'] = (x - original_x).to(cache_device)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
class WanVideoTeaCache:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"rel_l1_thresh": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Threshold for to determine when to apply the cache, compromise between speed and accuracy"}),
|
||||
"start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps to use with TeaCache."}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps to use with TeaCache."}),
|
||||
"cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "patch_teacache"
|
||||
CATEGORY = "KJNodes/teacache"
|
||||
DESCRIPTION = "Patch WanVideo model to use TeaCache. Speeds up inference by caching the output of the model and applying it based on the input/output difference. Currently doesn't use coefficients for caching, will be imporoved in the future"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch_teacache(self, model, rel_l1_thresh, start_percent, end_percent, cache_device):
|
||||
if rel_l1_thresh == 0:
|
||||
return (model,)
|
||||
|
||||
teacache_device = mm.get_torch_device() if cache_device == "main_device" else mm.unet_offload_device()
|
||||
|
||||
model_clone = model.clone()
|
||||
if 'transformer_options' not in model_clone.model_options:
|
||||
model_clone.model_options['transformer_options'] = {}
|
||||
model_clone.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh
|
||||
model_clone.model_options["transformer_options"]["teacache_device"] = teacache_device
|
||||
diffusion_model = model_clone.get_model_object("diffusion_model")
|
||||
|
||||
def outer_wrapper(start_percent, end_percent):
|
||||
def unet_wrapper_function(model_function, kwargs):
|
||||
input = kwargs["input"]
|
||||
timestep = kwargs["timestep"]
|
||||
c = kwargs["c"]
|
||||
sigmas = c["transformer_options"]["sample_sigmas"]
|
||||
cond_or_uncond = kwargs["cond_or_uncond"]
|
||||
last_step = (len(sigmas) - 1)
|
||||
|
||||
matched_step_index = (sigmas == timestep[0] ).nonzero()
|
||||
if len(matched_step_index) > 0:
|
||||
current_step_index = matched_step_index.item()
|
||||
else:
|
||||
for i in range(len(sigmas) - 1):
|
||||
# walk from beginning of steps until crossing the timestep
|
||||
if (sigmas[i] - timestep) * (sigmas[i + 1] - timestep) <= 0:
|
||||
current_step_index = i
|
||||
break
|
||||
else:
|
||||
current_step_index = 0
|
||||
|
||||
if current_step_index == 0:
|
||||
if hasattr(diffusion_model, "teacache_state"):
|
||||
delattr(diffusion_model, "teacache_state")
|
||||
logging.info("Resetting TeaCache state")
|
||||
|
||||
current_percent = current_step_index / (len(sigmas) - 1)
|
||||
if start_percent <= current_percent <= end_percent:
|
||||
c["transformer_options"]["teacache_enabled"] = True
|
||||
|
||||
context = patch.multiple(
|
||||
diffusion_model,
|
||||
forward=teacache_wanvideo_forward.__get__(diffusion_model, diffusion_model.__class__),
|
||||
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
|
||||
)
|
||||
else:
|
||||
context = nullcontext()
|
||||
with context:
|
||||
out = model_function(input, timestep, **c)
|
||||
if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"):
|
||||
if cond_or_uncond[0] == 0:
|
||||
skipped_steps_cond = diffusion_model.teacache_state["cond"]["teacache_skipped_steps"]
|
||||
skipped_steps_uncond = diffusion_model.teacache_state["uncond"]["teacache_skipped_steps"]
|
||||
|
||||
logging.info("-----------------------------------")
|
||||
logging.info(f"TeaCache skipped:")
|
||||
logging.info(f"{skipped_steps_cond} cond steps")
|
||||
logging.info(f"{skipped_steps_uncond} uncond step")
|
||||
logging.info(f"out of {last_step} steps")
|
||||
logging.info("-----------------------------------")
|
||||
return out
|
||||
return unet_wrapper_function
|
||||
|
||||
model_clone.set_model_unet_function_wrapper(outer_wrapper(start_percent=start_percent, end_percent=end_percent))
|
||||
|
||||
return (model_clone,)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user