From d584c711a374e8267496dc5241ff879588212360 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 15 Jun 2025 22:05:24 +0300
Subject: [PATCH] Experimental NAG node

---
 __init__.py                       |   1 +
 nodes/model_optimization_nodes.py | 117 ++++++++++++++++++++++++++++++
 2 files changed, 118 insertions(+)

diff --git a/__init__.py b/__init__.py
index 81ef8c2..fce7a50 100644
--- a/__init__.py
+++ b/__init__.py
@@ -196,6 +196,7 @@ NODE_CONFIG = {
     "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
     "CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
     "ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
+    "WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
 
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 9f1c5ef..18f0af4 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -1444,6 +1444,123 @@ class WanVideoEnhanceAVideoKJ:
             model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)
 
         return (model_clone,)
+
+#region NAG
+def wan_crossattn_forward_nag(self, x, context, **kwargs):
+    r"""
+    Args:
+        x(Tensor): Shape [B, L1, C]
+        context(Tensor): Shape [B, L2, C]
+    """
+
+    # NAG text attention
+    if context.shape[0] == 2:
+        x, x_real_negative = torch.chunk(x, 2, dim=0)
+        context_positive, context_negative = torch.chunk(context, 2, dim=0)
+    else:
+        context_positive = context
+        context_negative = None
+
+    nag_scale = self.nag_scale
+    nag_alpha = self.nag_alpha
+    nag_tau = self.nag_tau
+
+    q = self.norm_q(self.q(x))
+
+    k_positive = self.norm_k(self.k(context_positive))
+    v_positive = self.v(context_positive)
+    k_negative = self.norm_k(self.k(self.nag_context))
+    v_negative = self.v(self.nag_context)
+
+    x_positive = comfy.ldm.modules.attention.optimized_attention(q, k_positive, v_positive, heads=self.num_heads)
+    x_positive = x_positive.flatten(2)
+
+    x_negative = comfy.ldm.modules.attention.optimized_attention(q, k_negative, v_negative, heads=self.num_heads)
+    x_negative = x_negative.flatten(2)
+
+    nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
+
+    norm_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True).expand_as(x_positive)
+    norm_guidance = torch.norm(nag_guidance, p=1, dim=-1, keepdim=True).expand_as(nag_guidance)
+
+    scale = norm_guidance / norm_positive
+    scale = torch.nan_to_num(scale, nan=10.0)
+
+    mask = scale > nag_tau
+    adjustment = (norm_positive * nag_tau) / (norm_guidance + 1e-7)
+    nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)
+
+    x = nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)
+
+    if context_negative is not None:
+        q_real_negative = self.norm_q(self.q(x_real_negative))
+        k_real_negative = self.norm_k(self.k(context_negative))
+        v_real_negative = self.v(context_negative)
+        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
+        #x_real_negative = x_real_negative.flatten(2)
+        x = torch.cat([x, x_real_negative], dim=0)
+
+    x = self.o(x)
+    return x
+
+class WanCrossAttentionPatch:
+    def __init__(self, context, nag_scale, nag_alpha, nag_tau):
+        self.nag_context = context
+        self.nag_scale = nag_scale
+        self.nag_alpha = nag_alpha
+        self.nag_tau = nag_tau
+    def __get__(self, obj, objtype=None):
+        # Create bound method with stored parameters
+        def wrapped_attention(self_module, *args, **kwargs):
+            self_module.nag_context = self.nag_context
+            self_module.nag_scale = self.nag_scale
+            self_module.nag_alpha = self.nag_alpha
+            self_module.nag_tau = self.nag_tau
+            return wan_crossattn_forward_nag(self_module, *args, **kwargs)
+        return types.MethodType(wrapped_attention, obj)
+
+class WanVideoNAG:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "conditioning": ("CONDITIONING",),
+                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of the effect"}),
+                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Blend factor between the normalized guidance and the positive attention output"}),
+                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Norm-ratio threshold above which the guidance is rescaled"}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    RETURN_NAMES = ("model",)
+    FUNCTION = "patch"
+    CATEGORY = "KJNodes/experimental"
+    DESCRIPTION = "https://github.com/ChenDarYen/Normalized-Attention-Guidance"
+    EXPERIMENTAL = True
+
+    def patch(self, model, conditioning, nag_scale, nag_alpha, nag_tau):
+        if nag_scale == 0:
+            return (model,)
+
+        device = mm.get_torch_device()
+        dtype = mm.unet_dtype()
+
+        model_clone = model.clone()
+        # if 'transformer_options' not in model_clone.model_options:
+        #     model_clone.model_options['transformer_options'] = {}
+
+        diffusion_model = model_clone.get_model_object("diffusion_model")
+
+        diffusion_model.text_embedding.to(device)
+        context = diffusion_model.text_embedding(conditioning[0][0].to(device, dtype))
+
+        for idx, block in enumerate(diffusion_model.blocks):
+            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau).__get__(block.cross_attn, block.__class__)
+
+            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)
+
+        return (model_clone,)
 
 class SkipLayerGuidanceWanVideo:
     @classmethod
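The guidance arithmetic in wan_crossattn_forward_nag reads in isolation as: extrapolate the positive cross-attention output away from the negative one, rescale the result when its L1 norm grows more than nag_tau times the positive norm, then blend back toward the positive branch with nag_alpha. A minimal standalone sketch in plain PyTorch, assuming only the formulas in the patch above (the nag_combine name and tensor shapes are illustrative; the defaults mirror the node's defaults):

import torch

def nag_combine(x_positive, x_negative, nag_scale=11.0, nag_tau=2.5, nag_alpha=0.25):
    # Extrapolate the positive attention output away from the negative one.
    guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)

    # Compare L1 norms along the feature dim; where the guided output grew
    # by more than a factor of nag_tau, rescale it back to that bound.
    norm_pos = torch.norm(x_positive, p=1, dim=-1, keepdim=True)
    norm_gui = torch.norm(guidance, p=1, dim=-1, keepdim=True)
    ratio = torch.nan_to_num(norm_gui / norm_pos, nan=10.0)
    guidance = torch.where(
        ratio > nag_tau,
        guidance * (norm_pos * nag_tau) / (norm_gui + 1e-7),
        guidance,
    )

    # Blend the norm-clipped guidance back toward the positive branch.
    return guidance * nag_alpha + x_positive * (1 - nag_alpha)

x_pos, x_neg = torch.randn(1, 16, 128), torch.randn(1, 16, 128)
print(nag_combine(x_pos, x_neg).shape)  # torch.Size([1, 16, 128])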
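WanCrossAttentionPatch drives the descriptor protocol by hand: calling __get__(block.cross_attn, ...) returns a types.MethodType bound to the cross-attention module, so the replacement forward receives the module as its first argument while the NAG parameters travel on the patch object. The same pattern in isolation, with made-up names (Greeter, ForwardPatch) purely for illustration:

import types

class Greeter:
    def forward(self, name):
        return f"hello {name}"

class ForwardPatch:
    # Descriptor-style wrapper: __get__ yields a method bound to `obj`
    # that closes over state stored on the patch instance.
    def __init__(self, suffix):
        self.suffix = suffix

    def __get__(self, obj, objtype=None):
        def wrapped(instance, name):
            # `instance` plays the role of `self_module` in the patch above.
            return instance.forward(name) + self.suffix
        return types.MethodType(wrapped, obj)

g = Greeter()
patched = ForwardPatch("!").__get__(g, type(g))
print(patched("world"))  # hello world!

Binding through __get__ keeps the replacement a genuine bound method on the module, mirroring the forward attribute it stands in for.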