From 68db110554d5f1d9bef8d027a111a49fd7f85e1b Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 7 Mar 2025 16:35:30 +0200
Subject: [PATCH] Exclude TeaCache from compile to avoid possible compile
 errors, make compiling whole model default for WanVideo

---
 nodes/model_optimization_nodes.py | 92 +++++++++++++++++--------------
 1 file changed, 50 insertions(+), 42 deletions(-)

diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index d523232..80ea967 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -462,7 +462,7 @@ class TorchCompileModelWanVideo:
                 "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
                 "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
                 "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
-                "compile_transformer_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile all transformer blocks"}),
+                "compile_transformer_blocks_only": ("BOOLEAN", {"default": False, "tooltip": "Compile only transformer blocks"}),
             },
         }
     RETURN_TYPES = ("MODEL",)
@@ -471,16 +471,20 @@ class TorchCompileModelWanVideo:
     CATEGORY = "KJNodes/torchcompile"
     EXPERIMENTAL = True
 
-    def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks):
+    def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only):
         m = model.clone()
         diffusion_model = m.get_model_object("diffusion_model")
         torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
         if not self._compiled:
             try:
-                if compile_transformer_blocks:
+                if compile_transformer_blocks_only:
                     for i, block in enumerate(diffusion_model.blocks):
                         compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
                         m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
+                else:
+                    compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+                    m.add_object_patch("diffusion_model", compiled_model)
+
                 self._compiled = True
                 compile_settings = {
                     "backend": backend,
@@ -731,54 +735,58 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
         context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
         context = torch.concat([context_clip, context], dim=1)
 
-    #teacache for cond and uncond separately
-    rel_l1_thresh = kwargs["transformer_options"]["rel_l1_thresh"]
-    cache_device = kwargs["transformer_options"]["teacache_device"]
-    is_cond = True if kwargs["transformer_options"]["cond_or_uncond"] == [0] else False
+    @torch.compiler.disable()
+    def tea_cache(x, e0, e, kwargs):
+        #teacache for cond and uncond separately
+        rel_l1_thresh = kwargs["transformer_options"]["rel_l1_thresh"]
+
+        is_cond = True if kwargs["transformer_options"]["cond_or_uncond"] == [0] else False
 
-    should_calc = True
-    suffix = "cond" if is_cond else "uncond"
+        should_calc = True
+        suffix = "cond" if is_cond else "uncond"
 
-    # Init cache dict if not exists
-    if not hasattr(self, 'teacache_state'):
-        self.teacache_state = {
-            'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
-                     'teacache_skipped_steps': 0, 'previous_residual': None},
-            'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
-                       'teacache_skipped_steps': 0, 'previous_residual': None}
-        }
-        logging.info("\nTeaCache: Initialized")
+        # Init cache dict if not exists
+        if not hasattr(self, 'teacache_state'):
+            self.teacache_state = {
+                'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
+                         'teacache_skipped_steps': 0, 'previous_residual': None},
+                'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
+                           'teacache_skipped_steps': 0, 'previous_residual': None}
+            }
+            logging.info("\nTeaCache: Initialized")
 
-    cache = self.teacache_state[suffix]
+        cache = self.teacache_state[suffix]
 
-    if cache['prev_input'] is not None:
-        if kwargs["transformer_options"]["coefficients"] == []:
-            temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
-            curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
-        else:
-            rescale_func = np.poly1d(kwargs["transformer_options"]["coefficients"])
-            curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
-        try:
-            if curr_acc_dist < rel_l1_thresh:
-                should_calc = False
-                cache['accumulated_rel_l1_distance'] = curr_acc_dist
-            else:
-                should_calc = True
-                cache['accumulated_rel_l1_distance'] = 0
-        except:
-            should_calc = True
-            cache['accumulated_rel_l1_distance'] = 0
+        if cache['prev_input'] is not None:
+            if kwargs["transformer_options"]["coefficients"] == []:
+                temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
+                curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
+            else:
+                rescale_func = np.poly1d(kwargs["transformer_options"]["coefficients"])
+                curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
+            try:
+                if curr_acc_dist < rel_l1_thresh:
+                    should_calc = False
+                    cache['accumulated_rel_l1_distance'] = curr_acc_dist
+                else:
+                    should_calc = True
+                    cache['accumulated_rel_l1_distance'] = 0
+            except:
+                should_calc = True
+                cache['accumulated_rel_l1_distance'] = 0
 
-    if kwargs["transformer_options"]["coefficients"] == []:
-        cache['prev_input'] = e0.clone().detach()
-    else:
-        cache['prev_input'] = e.clone().detach()
+        if kwargs["transformer_options"]["coefficients"] == []:
+            cache['prev_input'] = e0.clone().detach()
+        else:
+            cache['prev_input'] = e.clone().detach()
 
-    if not should_calc:
-        x += cache['previous_residual'].to(x.device)
-        cache['teacache_skipped_steps'] += 1
-        #print(f"TeaCache: Skipping {suffix} step")
+        if not should_calc:
+            x += cache['previous_residual'].to(x.device)
+            cache['teacache_skipped_steps'] += 1
+            #print(f"TeaCache: Skipping {suffix} step")
+        return should_calc, cache
+    should_calc, cache = tea_cache(x, e0, e, kwargs)
 
     if should_calc:
         original_x = x.clone().detach()
         # arguments
@@ -790,7 +798,7 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
         for block in self.blocks:
             x = block(x, **block_wargs)
 
-        cache['previous_residual'] = (x - original_x).to(cache_device)
+        cache['previous_residual'] = (x - original_x).to(kwargs["transformer_options"]["teacache_device"])
 
     # head
     x = self.head(x, e)
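
Note on the pattern this patch applies: wrapping the TeaCache bookkeeping in a nested function decorated with @torch.compiler.disable() keeps its Python-side state, .item() syncs, and data-dependent branching out of Dynamo's traced graph, while torch.compile still covers the transformer math. Below is a minimal, self-contained sketch of that idea, not code from this repository; the names ToyBlock, should_skip, state, and thresh are hypothetical.

    import torch

    class ToyBlock(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) * 2

    @torch.compiler.disable()  # runs eagerly; Dynamo never traces this function
    def should_skip(x, state, thresh=0.1):
        # Stateful, data-dependent logic like TeaCache's cache checks; traced
        # inside a compiled region it would cause graph breaks or recompiles.
        if state["prev"] is None:
            delta = float("inf")
        else:
            delta = ((x - state["prev"]).abs().mean() / state["prev"].abs().mean()).item()
        state["prev"] = x.detach().clone()
        return delta < thresh

    block = torch.compile(ToyBlock())  # the heavy math is still compiled
    state = {"prev": None}
    for _ in range(4):
        x = torch.randn(2, 8)
        if not should_skip(x, state):
            x = block(x)

With the cache logic excluded this way, compiling the whole diffusion_model no longer risks tracing it, which is presumably what lets the patch flip the node's default from per-block compilation to whole-model compilation (compile_transformer_blocks_only defaults to False).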