From 7afb0f906a6894d1fe8c3a143e8fe02595e0ed10 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 16 Jun 2025 00:35:53 +0300
Subject: [PATCH] Update model_optimization_nodes.py

---
 nodes/model_optimization_nodes.py | 109 +++++++++++++++++++++---------
 1 file changed, 78 insertions(+), 31 deletions(-)

diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 18f0af4..9d7c6cb 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -1444,38 +1444,21 @@ class WanVideoEnhanceAVideoKJ:
             model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)
 
         return (model_clone,)
-
-#region NAG
-def wan_crossattn_forward_nag(self, x, context, **kwargs):
-    r"""
-    Args:
-        x(Tensor): Shape [B, L1, C]
-        context(Tensor): Shape [B, L2, C]
-    """
-
-    # NAG text attention
-    if context.shape[0] == 2:
-        x, x_real_negative = torch.chunk(x, 2, dim=0)
-        context_positive, context_negative = torch.chunk(context, 2, dim=0)
-    else:
-        context_positive = context
-        context_negative = None
 
+def normalized_attention_guidance(self, query, context_positive, context_negative):
     nag_scale = self.nag_scale
     nag_alpha = self.nag_alpha
     nag_tau = self.nag_tau
 
-    q = self.norm_q(self.q(x))
-
     k_positive = self.norm_k(self.k(context_positive))
     v_positive = self.v(context_positive)
-    k_negative = self.norm_k(self.k(self.nag_context))
-    v_negative = self.v(self.nag_context)
+    k_negative = self.norm_k(self.k(context_negative))
+    v_negative = self.v(context_negative)
 
-    x_positive = comfy.ldm.modules.attention.optimized_attention(q, k_positive, v_positive, heads=self.num_heads)
+    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads)
     x_positive = x_positive.flatten(2)
-    x_negative = comfy.ldm.modules.attention.optimized_attention(q, k_negative, v_negative, heads=self.num_heads)
+    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads)
     x_negative = x_negative.flatten(2)
 
     nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
@@ -1491,24 +1474,84 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
         nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)
 
     x = nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)
+    del nag_guidance
+
+    return x
+
+#region NAG
+def wan_crossattn_forward_nag(self, x, context, **kwargs):
+    r"""
+    Args:
+        x(Tensor): Shape [B, L1, C]
+        context(Tensor): Shape [B, L2, C]
+    """
+
+    if context.shape[0] == 2:
+        x, x_real_negative = torch.chunk(x, 2, dim=0)
+        context_positive, context_negative = torch.chunk(context, 2, dim=0)
+    else:
+        context_positive = context
+        context_negative = None
+
+    q = self.norm_q(self.q(x))
+
+    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
 
     if context_negative is not None:
         q_real_negative = self.norm_q(self.q(x_real_negative))
         k_real_negative = self.norm_k(self.k(context_negative))
         v_real_negative = self.v(context_negative)
         x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
-        #x_real_negative = x_real_negative.flatten(2)
         x = torch.cat([x, x_real_negative], dim=0)
 
     x = self.o(x)
     return x
+
+def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
+    r"""
+    Args:
+        x(Tensor): Shape [B, L1, C]
+        context(Tensor): Shape [B, L2, C]
+        context_img_len(int): Number of image tokens at the start of context
+    """
+    context_img = context[:, :context_img_len]
+    context = context[:, context_img_len:]
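+    # context packs image tokens first, then text tokens; split them so the
+    # image branch and the NAG text branch attend separately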
+
+    q_img = self.norm_q(self.q(x))
+    k_img = self.norm_k_img(self.k_img(context_img))
+    v_img = self.v_img(context_img)
+    img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
+
+    if context.shape[0] == 2:
+        x, x_real_negative = torch.chunk(x, 2, dim=0)
+        context_positive, context_negative = torch.chunk(context, 2, dim=0)
+    else:
+        context_positive = context
+        context_negative = None
+
+    q = self.norm_q(self.q(x))
+
+    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
+
+    if context_negative is not None:
+        q_real_negative = self.norm_q(self.q(x_real_negative))
+        k_real_negative = self.norm_k(self.k(context_negative))
+        v_real_negative = self.v(context_negative)
+        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
+        x = torch.cat([x, x_real_negative], dim=0)
+
+    # combine image and text attention outputs
+    x = x + img_x
+    x = self.o(x)
+    return x
+
 class WanCrossAttentionPatch:
-    def __init__(self, context, nag_scale, nag_alpha, nag_tau):
+    def __init__(self, context, nag_scale, nag_alpha, nag_tau, i2v=False):
         self.nag_context = context
         self.nag_scale = nag_scale
         self.nag_alpha = nag_alpha
         self.nag_tau = nag_tau
+        self.i2v = i2v
     def __get__(self, obj, objtype=None):
         # Create bound method with stored parameters
         def wrapped_attention(self_module, *args, **kwargs):
@@ -1516,7 +1559,10 @@ class WanCrossAttentionPatch:
             self_module.nag_scale = self.nag_scale
             self_module.nag_alpha = self.nag_alpha
             self_module.nag_tau = self.nag_tau
-            return wan_crossattn_forward_nag(self_module, *args, **kwargs)
+            if self.i2v:
+                return wan_i2v_crossattn_forward_nag(self_module, *args, **kwargs)
+            else:
+                return wan_crossattn_forward_nag(self_module, *args, **kwargs)
         return types.MethodType(wrapped_attention, obj)
 
 class WanVideoNAG:
@@ -1526,9 +1572,9 @@
             "required": {
                 "model": ("MODEL",),
                 "conditioning": ("CONDITIONING",),
-                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of the effect"}),
-                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Alpha of the effect"}),
-                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Tau of the effect"}),
+                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of the negative guidance effect"}),
+                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Mixing coefficient that controls the balance between the normalized guided representation and the original positive representation."}),
+                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Clipping threshold that controls how much the guided attention can deviate from the positive attention."}),
             }
         }
@@ -1547,16 +1593,17 @@
         dtype = mm.unet_dtype()
 
         model_clone = model.clone()
-        # if 'transformer_options' not in model_clone.model_options:
-        #     model_clone.model_options['transformer_options'] = {}
 
         diffusion_model = model_clone.get_model_object("diffusion_model")
 
         diffusion_model.text_embedding.to(device)
         context = diffusion_model.text_embedding(conditioning[0][0].to(device, dtype))
+
+        type_str = type(model.model.model_config).__name__
+        i2v = "WAN21_I2V" in type_str
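+        # I2V blocks route image tokens through a separate k_img/v_img branch,
+        # so they are patched with the dedicated i2v forward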
 
         for idx, block in enumerate(diffusion_model.blocks):
-            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau).__get__(block.cross_attn, block.__class__)
+            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau, i2v).__get__(block.cross_attn, block.__class__)
             model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)
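
The norm-based clipping inside normalized_attention_guidance falls between the first two hunks, so it is not visible in this patch. Below is a minimal, self-contained sketch of the full NAG blend the helper implements, assuming the clipping follows the reference NAG formulation; nag_blend and its [B, L, C] tensor arguments are illustrative, not part of the codebase, though the defaults mirror the node's inputs.

import torch

def nag_blend(x_positive, x_negative, nag_scale=11.0, nag_tau=2.5, nag_alpha=0.25):
    # extrapolate the positive attention output away from the negative one
    nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)

    # clip: where the guided output's per-token L1 norm exceeds nag_tau times
    # the positive output's, rescale it back down to that bound
    norm_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True)
    norm_guidance = torch.norm(nag_guidance, p=1, dim=-1, keepdim=True)
    mask = norm_guidance / (norm_positive + 1e-7) > nag_tau
    adjustment = (norm_positive * nag_tau) / (norm_guidance + 1e-7)
    nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)

    # blend the guided output back toward the plain positive output
    return nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)

With nag_scale = 1 the guidance term reduces to x_positive and the blend is a no-op; larger scales push harder away from the negative context, with nag_tau re-normalizing tokens whose guided norm runs away.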