From c19ad3491691dde7219bc0dab17acefb81bd0be0 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 14 Mar 2025 18:42:31 +0200
Subject: [PATCH] update to match comfyui latest update

---
 nodes/model_optimization_nodes.py | 54 ++++++++++++----------------------
 1 file changed, 20 insertions(+), 34 deletions(-)

diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index ce3a3b6..8084575 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -712,24 +712,7 @@ def relative_l1_distance(last_tensor, current_tensor):
     relative_l1_distance = l1_distance / norm
     return relative_l1_distance.to(torch.float32)
 
-#for now as there doesn't seem to be a way to pass transformer_options to the forward_orig currently
-def teacache_wanvideo_forward(self, x, timestep, context, clip_fea=None, **kwargs):
-    bs, c, t, h, w = x.shape
-    x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
-    patch_size = self.patch_size
-    t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
-    h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
-    w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
-    img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
-    img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
-    img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
-    img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
-    img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
-
-    freqs = self.rope_embedder(img_ids).movedim(1, 2)
-    return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, **kwargs)[:, :, :t, :h, :w]
-
-def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, **kwargs):
+def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
     # embeddings
     x = self.patch_embedding(x.float()).to(x.dtype)
     grid_sizes = x.shape[2:]
@@ -749,9 +732,9 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
     @torch.compiler.disable()
     def tea_cache(x, e0, e, kwargs):
         #teacache for cond and uncond separately
-        rel_l1_thresh = kwargs["transformer_options"]["rel_l1_thresh"]
+        rel_l1_thresh = transformer_options["rel_l1_thresh"]
 
-        is_cond = True if kwargs["transformer_options"]["cond_or_uncond"] == [0] else False
+        is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
 
         should_calc = True
         suffix = "cond" if is_cond else "uncond"
@@ -769,11 +752,11 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
         cache = self.teacache_state[suffix]
 
         if cache['prev_input'] is not None:
-            if kwargs["transformer_options"]["coefficients"] == []:
+            if transformer_options["coefficients"] == []:
                 temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
                 curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
             else:
-                rescale_func = np.poly1d(kwargs["transformer_options"]["coefficients"])
+                rescale_func = np.poly1d(transformer_options["coefficients"])
                 curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
             try:
                 if curr_acc_dist < rel_l1_thresh:
@@ -786,7 +769,7 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
                 should_calc = True
                 cache['accumulated_rel_l1_distance'] = 0
 
-        if kwargs["transformer_options"]["coefficients"] == []:
+        if transformer_options["coefficients"] == []:
             cache['prev_input'] = e0.clone().detach()
         else:
             cache['prev_input'] = e.clone().detach()
@@ -800,16 +783,20 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
         should_calc, cache = tea_cache(x, e0, e, kwargs)
     if should_calc:
         original_x = x.clone().detach()
-        # arguments
-        block_wargs = dict(
-            e=e0,
-            freqs=freqs,
-            context=context)
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context)
 
-        for block in self.blocks:
-            x = block(x, **block_wargs)
-
-        cache['previous_residual'] = (x - original_x).to(kwargs["transformer_options"]["teacache_device"])
+        cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
 
     # head
     x = self.head(x, e)
@@ -932,7 +919,6 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
 
             context = patch.multiple(
                 diffusion_model,
-                forward=teacache_wanvideo_forward.__get__(diffusion_model, diffusion_model.__class__),
                 forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
             )
         else:
@@ -961,7 +947,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
 
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.math import apply_rope
-from comfy.ldm.wan.model import WanSelfAttention
+
 def modified_wan_self_attention_forward(self, x, freqs):
     r"""
     Args: