From 393ec896f75616098a8c8580a79e4e48d714b309 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 17 Mar 2025 14:11:30 +0200 Subject: [PATCH] Better compile compatibility with various patches Shouldn't drop compile when changing slg or enhance-a-video settings anymore --- nodes/model_optimization_nodes.py | 86 +++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 28 deletions(-) diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 2a9d0ad..c7b150d 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -486,26 +486,32 @@ class TorchCompileModelWanVideo: m = model.clone() diffusion_model = m.get_model_object("diffusion_model") torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit - if not self._compiled: - try: - if compile_transformer_blocks_only: - for i, block in enumerate(diffusion_model.blocks): + is_compiled = hasattr(model.model.diffusion_model.blocks[0], "_orig_mod") + if is_compiled: + logging.info(f"Already compiled, not reapplying") + else: + logging.info(f"Not compiled, applying") + try: + if compile_transformer_blocks_only: + for i, block in enumerate(diffusion_model.blocks): + if is_compiled: + compiled_block = torch.compile(block._orig_mod, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) + else: compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) - m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block) - else: - compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) - m.add_object_patch("diffusion_model", compiled_model) + m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block) + else: + compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) + m.add_object_patch("diffusion_model", compiled_model) - self._compiled = True - compile_settings = { - "backend": backend, - "mode": mode, - "fullgraph": fullgraph, - "dynamic": dynamic, - } - setattr(m.model, "compile_settings", compile_settings) - except: - raise RuntimeError("Failed to compile model") + compile_settings = { + "backend": backend, + "mode": mode, + "fullgraph": fullgraph, + "dynamic": dynamic, + } + setattr(m.model, "compile_settings", compile_settings) + except: + raise RuntimeError("Failed to compile model") return (m, ) class TorchCompileVAE: @@ -1074,9 +1080,14 @@ class WanVideoEnhanceAVideoKJ: model_clone.model_options['transformer_options'] = {} model_clone.model_options["transformer_options"]["enhance_weight"] = weight diffusion_model = model_clone.get_model_object("diffusion_model") + + compile_settings = getattr(model.model, "compile_settings", None) for idx, block in enumerate(diffusion_model.blocks): - self_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__) - model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn) + patched_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__) + if compile_settings is not None: + patched_attn = torch.compile(patched_attn, mode=compile_settings["mode"], dynamic=compile_settings["dynamic"], fullgraph=compile_settings["fullgraph"], backend=compile_settings["backend"]) + + model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn) return (model_clone,) @@ -1098,8 +1109,10 @@ class SkipLayerGuidanceWanVideo: def slg(self, model, start_percent, end_percent, blocks): def skip(args, extra_args): transformer_options = extra_args.get("transformer_options", {}) + original_block = extra_args["original_block"] + if not transformer_options: - raise ValueError("transformer_options not found in extra_args, currrently SkipLayerGuidanceWanVideo only works with TeaCacheKJ") + raise ValueError("transformer_options not found in extra_args, currently SkipLayerGuidanceWanVideo only works with TeaCacheKJ") if start_percent <= transformer_options["current_percent"] <= end_percent: if args["img"].shape[0] == 2: prev_img_uncond = args["img"][0].unsqueeze(0) @@ -1110,7 +1123,8 @@ class SkipLayerGuidanceWanVideo: "vec": args["vec"][1], "pe": args["pe"][1] } - block_out = extra_args["original_block"](new_args) + + block_out = original_block(new_args) out = { "img": torch.cat([prev_img_uncond, block_out["img"]], dim=0), @@ -1120,20 +1134,36 @@ class SkipLayerGuidanceWanVideo: } else: if transformer_options.get("cond_or_uncond") == [0]: - out = extra_args["original_block"](args) + out = original_block(args) else: out = args else: - out = extra_args["original_block"](args) + out = original_block(args) return out block_list = [int(x.strip()) for x in blocks.split(",")] - double_layers = [int(i) for i in block_list] - logging.info(f"Selected blocks to skip uncond on: {double_layers}") + blocks = [int(i) for i in block_list] + logging.info(f"Selected blocks to skip uncond on: {blocks}") m = model.clone() - for layer in double_layers: - m.set_model_patch_replace(skip, "dit", "double_block", layer) + for b in blocks: + #m.set_model_patch_replace(skip, "dit", "double_block", b) + model_options = m.model_options["transformer_options"].copy() + if "patches_replace" not in model_options: + model_options["patches_replace"] = {} + else: + model_options["patches_replace"] = model_options["patches_replace"].copy() + + if "dit" not in model_options["patches_replace"]: + model_options["patches_replace"]["dit"] = {} + else: + model_options["patches_replace"]["dit"] = model_options["patches_replace"]["dit"].copy() + + block = ("double_block", b) + + model_options["patches_replace"]["dit"][block] = skip + m.model_options["transformer_options"] = model_options + return (m, ) \ No newline at end of file