diff --git a/__init__.py b/__init__.py index a323a64..07fac28 100644 --- a/__init__.py +++ b/__init__.py @@ -186,6 +186,7 @@ NODE_CONFIG = { "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"}, "WanVideoTeaCacheKJ": {"class": WanVideoTeaCacheKJ, "name": "WanVideo Tea Cache (native)"}, "WanVideoEnhanceAVideoKJ": {"class": WanVideoEnhanceAVideoKJ, "name": "WanVideo Enhance A Video (native)"}, + "SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"}, "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"}, "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"}, diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index bc13353..79ebee1 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -791,7 +791,7 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non out = {} out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) x = out["img"] else: x = block(x, e=e0, freqs=freqs, context=context) @@ -853,7 +853,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC if coefficients == "disabled" and rel_l1_thresh > 0.1: logging.warning("Threshold value is too high for TeaCache without coefficients, consider using coefficients for better results.") - if coefficients != "disabled" and rel_l1_thresh < 0.1: + if coefficients != "disabled" and rel_l1_thresh < 0.1 and "1.3B" not in coefficients: logging.warning("Threshold value is too low for TeaCache with coefficients, consider using higher threshold value for better results.") # type_str = str(type(model.model.model_config).__name__) @@ -914,15 +914,15 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC logging.info("\nResetting TeaCache state") current_percent = current_step_index / (len(sigmas) - 1) + c["transformer_options"]["current_percent"] = current_percent if start_percent <= current_percent <= end_percent: c["transformer_options"]["teacache_enabled"] = True - - context = patch.multiple( - diffusion_model, - forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__) - ) - else: - context = nullcontext() + + context = patch.multiple( + diffusion_model, + forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__) + ) + with context: out = model_function(input, timestep, **c) if current_step_index+1 == last_step and hasattr(diffusion_model, "teacache_state"): @@ -1072,4 +1072,62 @@ class WanVideoEnhanceAVideoKJ: self_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__) model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn) - return (model_clone,) \ No newline at end of file + return (model_clone,) + +class SkipLayerGuidanceWanVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "blocks": ("STRING", {"default": "10", "multiline": False}), + "start_percent": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "slg" + EXPERIMENTAL = True + DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks" + + CATEGORY = "advanced/guidance" + + def slg(self, model, start_percent, end_percent, blocks): + def skip(args, extra_args): + transformer_options = extra_args.get("transformer_options", {}) + if not transformer_options: + raise ValueError("transformer_options not found in extra_args, currrently SkipLayerGuidanceWanVideo only works with TeaCacheKJ") + if start_percent <= transformer_options["current_percent"] <= end_percent: + if args["img"].shape[0] == 2: + prev_img_uncond = args["img"][0].unsqueeze(0) + + new_args = { + "img": args["img"][1], + "txt": args["txt"][1], + "vec": args["vec"][1], + "pe": args["pe"][1] + } + block_out = extra_args["original_block"](new_args) + + out = { + "img": torch.cat([prev_img_uncond, block_out["img"]], dim=0), + "txt": args["txt"], + "vec": args["vec"], + "pe": args["pe"] + } + else: + if transformer_options.get("cond_or_uncond") == [0]: + out = extra_args["original_block"](args) + else: + out = args + else: + out = extra_args["original_block"](args) + return out + + block_list = [int(x.strip()) for x in blocks.split(",")] + double_layers = [int(i) for i in block_list] + logging.info(f"Selected blocks to skip uncond on: {double_layers}") + + m = model.clone() + + for layer in double_layers: + m.set_model_patch_replace(skip, "dit", "double_block", layer) + + return (m, ) \ No newline at end of file