diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 7e7e769..e858248 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -1316,6 +1316,7 @@ class WanVideoTeaCacheKJ:
     RETURN_NAMES = ("model",)
     FUNCTION = "patch_teacache"
     CATEGORY = "KJNodes/teacache"
+    DEPRECATED = True
     DESCRIPTION = """
 Patch WanVideo model to use TeaCache. Speeds up inference by caching the output and
 applying it instead of doing the step. Best results are achieved by choosing the
@@ -1450,7 +1451,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
 
 from comfy.ldm.flux.math import apply_rope
 
-def modified_wan_self_attention_forward(self, x, freqs):
+def modified_wan_self_attention_forward(self, x, freqs, transformer_options={}):
     r"""
     Args:
         x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -1470,13 +1471,23 @@ def modified_wan_self_attention_forward(self, x, freqs):
     q, k = apply_rope(q, k, freqs)
 
     feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight)
-
-    x = comfy.ldm.modules.attention.optimized_attention(
-        q.view(b, s, n * d),
-        k.view(b, s, n * d),
-        v,
-        heads=self.num_heads,
-    )
+
+    try:
+        x = comfy.ldm.modules.attention.optimized_attention(
+            q.view(b, s, n * d),
+            k.view(b, s, n * d),
+            v,
+            heads=self.num_heads,
+            transformer_options=transformer_options,
+        )
+    except:
+        # backwards compatibility for now
+        x = comfy.ldm.modules.attention.optimized_attention(
+            q.view(b, s, n * d),
+            k.view(b, s, n * d),
+            v,
+            heads=self.num_heads,
+        )
 
     x = self.o(x)
 
@@ -1581,14 +1592,18 @@ class WanVideoEnhanceAVideoKJ:
 
         return (model_clone,)
 
-def normalized_attention_guidance(self, query, context_positive, context_negative):
+def normalized_attention_guidance(self, query, context_positive, context_negative, transformer_options={}):
     k_positive = self.norm_k(self.k(context_positive))
     v_positive = self.v(context_positive)
     k_negative = self.norm_k(self.k(context_negative))
     v_negative = self.v(context_negative)
 
-    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
-    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
+    try:
+        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
+        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
+    except: # backwards compatibility for now
+        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
+        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
 
     nag_guidance = x_positive * self.nag_scale - x_negative * (self.nag_scale - 1)
 
@@ -1607,7 +1622,7 @@ def normalized_attention_guidance(self, query, context_positive, context_negativ
     return x
 
 #region NAG
-def wan_crossattn_forward_nag(self, x, context, **kwargs):
+def wan_crossattn_forward_nag(self, x, context, transformer_options={}, **kwargs):
     r"""
     Args:
         x(Tensor): Shape [B, L1, C]
@@ -1632,14 +1647,20 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
         nag_context = self.nag_context
         if self.input_type == "batch":
             nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)
-        x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
+        try:
+            x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context, transformer_options=transformer_options)
+        except: # backwards compatibility for now
+            x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
 
     # Negative branch
     if x_neg is not None and context_neg is not None:
         q_neg = self.norm_q(self.q(x_neg))
         k_neg = self.norm_k(self.k(context_neg))
         v_neg = self.v(context_neg)
-        x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
+        try:
+            x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads, transformer_options=transformer_options)
+        except: # backwards compatibility for now
+            x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
         x = torch.cat([x_pos_out, x_neg_out], dim=0)
     else:
         x = x_pos_out
@@ -1647,7 +1668,7 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
 
     return self.o(x)
 
-def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
+def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len, transformer_options={}, **kwargs):
     r"""
     Args:
         x(Tensor): Shape [B, L1, C]
@@ -1659,7 +1680,10 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
     q_img = self.norm_q(self.q(x))
     k_img = self.norm_k_img(self.k_img(context_img))
     v_img = self.v_img(context_img)
-    img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
+    try:
+        img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
+    except: # backwards compatibility for now
+        img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
 
     if context.shape[0] == 2:
         x, x_real_negative = torch.chunk(x, 2, dim=0)
@@ -1670,13 +1694,16 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
 
     q = self.norm_q(self.q(x))
 
-    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
+    x = normalized_attention_guidance(self, q, context_positive, self.nag_context, transformer_options=transformer_options)
 
     if context_negative is not None:
         q_real_negative = self.norm_q(self.q(x_real_negative))
         k_real_negative = self.norm_k(self.k(context_negative))
         v_real_negative = self.v(context_negative)
-        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
+        try:
+            x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads, transformer_options=transformer_options)
+        except: # backwards compatibility for now
+            x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
         x = torch.cat([x, x_real_negative], dim=0)
 
     # output
@@ -1766,7 +1793,7 @@ class SkipLayerGuidanceWanVideo:
     FUNCTION = "slg"
     EXPERIMENTAL = True
     DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
-
+    DEPRECATED = True
     CATEGORY = "advanced/guidance"
 
     def slg(self, model, start_percent, end_percent, blocks):
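
Every patched forward above applies the same compatibility shim: call optimized_attention with the newer transformer_options keyword, and fall back to the older call signature when the installed ComfyUI build rejects it. A minimal sketch of that pattern, assuming a hypothetical helper name attention_with_fallback that is not part of this repository:

import comfy.ldm.modules.attention

def attention_with_fallback(q, k, v, heads, transformer_options=None):
    try:
        # Newer ComfyUI builds accept transformer_options on optimized_attention
        return comfy.ldm.modules.attention.optimized_attention(
            q, k, v, heads=heads, transformer_options=transformer_options or {}
        )
    except TypeError:
        # Older builds reject the keyword; retry without it, mirroring the
        # bare-except fallbacks used in the diff above
        return comfy.ldm.modules.attention.optimized_attention(q, k, v, heads=heads)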