Add SkipLayerGuidanceWanVideo

This commit is contained in:
kijai 2025-03-16 17:24:23 +02:00
parent 6b7eeebe44
commit 7c488a16ef
2 changed files with 69 additions and 10 deletions

View File

@ -186,6 +186,7 @@ NODE_CONFIG = {
"ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
"WanVideoTeaCacheKJ": {"class": WanVideoTeaCacheKJ, "name": "WanVideo Tea Cache (native)"},
"WanVideoEnhanceAVideoKJ": {"class": WanVideoEnhanceAVideoKJ, "name": "WanVideo Enhance A Video (native)"},
"SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},

View File

@ -791,7 +791,7 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
@ -853,7 +853,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
if coefficients == "disabled" and rel_l1_thresh > 0.1:
logging.warning("Threshold value is too high for TeaCache without coefficients, consider using coefficients for better results.")
if coefficients != "disabled" and rel_l1_thresh < 0.1:
if coefficients != "disabled" and rel_l1_thresh < 0.1 and "1.3B" not in coefficients:
logging.warning("Threshold value is too low for TeaCache with coefficients, consider using higher threshold value for better results.")
# type_str = str(type(model.model.model_config).__name__)
@ -914,15 +914,15 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
logging.info("\nResetting TeaCache state")
current_percent = current_step_index / (len(sigmas) - 1)
c["transformer_options"]["current_percent"] = current_percent
if start_percent <= current_percent <= end_percent:
c["transformer_options"]["teacache_enabled"] = True
context = patch.multiple(
diffusion_model,
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
)
else:
context = nullcontext()
context = patch.multiple(
diffusion_model,
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
)
with context:
out = model_function(input, timestep, **c)
if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"):
@ -1072,4 +1072,62 @@ class WanVideoEnhanceAVideoKJ:
self_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn)
return (model_clone,)
return (model_clone,)
class SkipLayerGuidanceWanVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"blocks": ("STRING", {"default": "10", "multiline": False}),
"start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "slg"
EXPERIMENTAL = True
DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
CATEGORY = "advanced/guidance"
def slg(self, model, start_percent, end_percent, blocks):
def skip(args, extra_args):
transformer_options = extra_args.get("transformer_options", {})
if not transformer_options:
raise ValueError("transformer_options not found in extra_args, currrently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
if start_percent <= transformer_options["current_percent"] <= end_percent:
if args["img"].shape[0] == 2:
prev_img_uncond = args["img"][0].unsqueeze(0)
new_args = {
"img": args["img"][1],
"txt": args["txt"][1],
"vec": args["vec"][1],
"pe": args["pe"][1]
}
block_out = extra_args["original_block"](new_args)
out = {
"img": torch.cat([prev_img_uncond, block_out["img"]], dim=0),
"txt": args["txt"],
"vec": args["vec"],
"pe": args["pe"]
}
else:
if transformer_options.get("cond_or_uncond") == [0]:
out = extra_args["original_block"](args)
else:
out = args
else:
out = extra_args["original_block"](args)
return out
block_list = [int(x.strip()) for x in blocks.split(",")]
double_layers = [int(i) for i in block_list]
logging.info(f"Selected blocks to skip uncond on: {double_layers}")
m = model.clone()
for layer in double_layers:
m.set_model_patch_replace(skip, "dit", "double_block", layer)
return (m, )