Add SkipLayerGuidanceWanVideo

kijai 2025-03-16 17:24:23 +02:00
parent 6b7eeebe44
commit 7c488a16ef
2 changed files with 69 additions and 10 deletions


@@ -186,6 +186,7 @@ NODE_CONFIG = {
     "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
     "WanVideoTeaCacheKJ": {"class": WanVideoTeaCacheKJ, "name": "WanVideo Tea Cache (native)"},
     "WanVideoEnhanceAVideoKJ": {"class": WanVideoEnhanceAVideoKJ, "name": "WanVideo Enhance A Video (native)"},
+    "SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
     "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
     "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},


@@ -791,7 +791,7 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None
                     out = {}
                     out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
                     return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
                 x = out["img"]
             else:
                 x = block(x, e=e0, freqs=freqs, context=context)
@@ -853,7 +853,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
         if coefficients == "disabled" and rel_l1_thresh > 0.1:
             logging.warning("Threshold value is too high for TeaCache without coefficients, consider using coefficients for better results.")
-        if coefficients != "disabled" and rel_l1_thresh < 0.1:
+        if coefficients != "disabled" and rel_l1_thresh < 0.1 and "1.3B" not in coefficients:
             logging.warning("Threshold value is too low for TeaCache with coefficients, consider using higher threshold value for better results.")
         # type_str = str(type(model.model.model_config).__name__)
@@ -914,6 +914,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
                             logging.info("\nResetting TeaCache state")
                 current_percent = current_step_index / (len(sigmas) - 1)
+                c["transformer_options"]["current_percent"] = current_percent
                 if start_percent <= current_percent <= end_percent:
                     c["transformer_options"]["teacache_enabled"] = True
@@ -921,8 +922,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
                         diffusion_model,
                         forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
                     )
                 else:
                     context = nullcontext()
                 with context:
                     out = model_function(input, timestep, **c)
                     if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"):
@@ -1073,3 +1073,61 @@ class WanVideoEnhanceAVideoKJ:
             model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn)
         return (model_clone,)
+
+class SkipLayerGuidanceWanVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL", ),
+                             "blocks": ("STRING", {"default": "10", "multiline": False}),
+                             "start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "slg"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
+
+    CATEGORY = "advanced/guidance"
+
+    def slg(self, model, start_percent, end_percent, blocks):
+        def skip(args, extra_args):
+            transformer_options = extra_args.get("transformer_options", {})
+            if not transformer_options:
+                raise ValueError("transformer_options not found in extra_args, currently SkipLayerGuidanceWanVideo only works with TeaCacheKJ")
+            if start_percent <= transformer_options["current_percent"] <= end_percent:
+                if args["img"].shape[0] == 2:
+                    prev_img_uncond = args["img"][0].unsqueeze(0)
+
+                    new_args = {
+                        "img": args["img"][1],
+                        "txt": args["txt"][1],
+                        "vec": args["vec"][1],
+                        "pe": args["pe"][1]
+                    }
+
+                    block_out = extra_args["original_block"](new_args)
+
+                    out = {
+                        "img": torch.cat([prev_img_uncond, block_out["img"]], dim=0),
+                        "txt": args["txt"],
+                        "vec": args["vec"],
+                        "pe": args["pe"]
+                    }
+                else:
+                    if transformer_options.get("cond_or_uncond") == [0]:
+                        out = extra_args["original_block"](args)
+                    else:
+                        out = args
+            else:
+                out = extra_args["original_block"](args)
+            return out
+
+        block_list = [int(x.strip()) for x in blocks.split(",")]
+        double_layers = [int(i) for i in block_list]
+
+        logging.info(f"Selected blocks to skip uncond on: {double_layers}")
+
+        m = model.clone()
+
+        for layer in double_layers:
+            m.set_model_patch_replace(skip, "dit", "double_block", layer)
+
+        return (m, )
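
For context, a minimal usage sketch (not part of this commit) of calling the new node's slg method directly. It assumes model is a ComfyUI MODEL (ModelPatcher) that has already been patched with WanVideoTeaCacheKJ, so that "current_percent" is present in transformer_options during sampling; the names slg_node and patched_model are illustrative.

# Hypothetical sketch, not committed code: assumes `model` is a loaded WanVideo
# MODEL already run through WanVideoTeaCacheKJ so the skip() callback can read
# transformer_options["current_percent"].
slg_node = SkipLayerGuidanceWanVideo()
(patched_model,) = slg_node.slg(
    model,
    start_percent=0.2,  # only skip the uncond after 20% of the sampling steps
    end_percent=1.0,
    blocks="9,10",      # comma-separated WanVideo block indices to patch
)
# patched_model now has the skip() replacement registered via
# set_model_patch_replace on ("dit", "double_block", 9) and ("dit", "double_block", 10);
# inside the start/end window the uncond half of the batch bypasses those blocks.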