Experimental NAG node

kijai 2025-06-15 22:05:24 +03:00
parent aeab1a7de5
commit d584c711a3
2 changed files with 118 additions and 0 deletions

@@ -196,6 +196,7 @@ NODE_CONFIG = {
    "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
    "CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
    "ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
    "WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
    #instance diffusion
    "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

@@ -1444,6 +1444,123 @@ class WanVideoEnhanceAVideoKJ:
            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)
        return (model_clone,)
#region NAG
def wan_crossattn_forward_nag(self, x, context, **kwargs):
    r"""
    Args:
        x (Tensor): Shape [B, L1, C]
        context (Tensor): Shape [B, L2, C]
    """
    # NAG text attention
    if context.shape[0] == 2:
        x, x_real_negative = torch.chunk(x, 2, dim=0)
        context_positive, context_negative = torch.chunk(context, 2, dim=0)
    else:
        context_positive = context
        context_negative = None

    nag_scale = self.nag_scale
    nag_alpha = self.nag_alpha
    nag_tau = self.nag_tau

    q = self.norm_q(self.q(x))
    k_positive = self.norm_k(self.k(context_positive))
    v_positive = self.v(context_positive)
    k_negative = self.norm_k(self.k(self.nag_context))
    v_negative = self.v(self.nag_context)

    x_positive = comfy.ldm.modules.attention.optimized_attention(q, k_positive, v_positive, heads=self.num_heads)
    x_positive = x_positive.flatten(2)
    x_negative = comfy.ldm.modules.attention.optimized_attention(q, k_negative, v_negative, heads=self.num_heads)
    x_negative = x_negative.flatten(2)

    nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)

    norm_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True).expand_as(x_positive)
    norm_guidance = torch.norm(nag_guidance, p=1, dim=-1, keepdim=True).expand_as(nag_guidance)

    scale = norm_guidance / norm_positive
    scale = torch.nan_to_num(scale, nan=10.0)
    mask = scale > nag_tau
    adjustment = (norm_positive * nag_tau) / (norm_guidance + 1e-7)
    nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)

    x = nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)

    if context_negative is not None:
        q_real_negative = self.norm_q(self.q(x_real_negative))
        k_real_negative = self.norm_k(self.k(context_negative))
        v_real_negative = self.v(context_negative)
        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
        #x_real_negative = x_real_negative.flatten(2)
        x = torch.cat([x, x_real_negative], dim=0)

    x = self.o(x)
    return x
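The extrapolation and renormalization above is the core of NAG: the positive attention features are pushed away from the negative branch, clipped wherever their L1 norm exceeds nag_tau times the positive norm, and blended back in by nag_alpha. A minimal standalone sketch of just that arithmetic, on dummy tensors with illustrative names (not tied to this module):

import torch

def nag_blend(x_pos, x_neg, scale=11.0, tau=2.5, alpha=0.25):
    # Extrapolate the positive features away from the negative branch.
    g = x_pos * scale - x_neg * (scale - 1)
    # L1 norms over the channel dimension, broadcast to the full shape.
    n_pos = torch.norm(x_pos, p=1, dim=-1, keepdim=True).expand_as(x_pos)
    n_g = torch.norm(g, p=1, dim=-1, keepdim=True).expand_as(g)
    ratio = torch.nan_to_num(n_g / n_pos, nan=10.0)
    # Rescale wherever the guidance norm overshoots tau * positive norm.
    g = torch.where(ratio > tau, g * (n_pos * tau) / (n_g + 1e-7), g)
    # Blend the clipped guidance with the untouched positive features.
    return g * alpha + x_pos * (1 - alpha)

x_pos, x_neg = torch.randn(2, 16, 64), torch.randn(2, 16, 64)
print(nag_blend(x_pos, x_neg).shape)  # torch.Size([2, 16, 64])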
class WanCrossAttentionPatch:
    def __init__(self, context, nag_scale, nag_alpha, nag_tau):
        self.nag_context = context
        self.nag_scale = nag_scale
        self.nag_alpha = nag_alpha
        self.nag_tau = nag_tau

    def __get__(self, obj, objtype=None):
        # Create bound method with stored parameters
        def wrapped_attention(self_module, *args, **kwargs):
            self_module.nag_context = self.nag_context
            self_module.nag_scale = self.nag_scale
            self_module.nag_alpha = self.nag_alpha
            self_module.nag_tau = self.nag_tau
            return wan_crossattn_forward_nag(self_module, *args, **kwargs)
        return types.MethodType(wrapped_attention, obj)
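WanCrossAttentionPatch works as a descriptor: __get__ returns a function bound to the attention module via types.MethodType, so the replacement forward carries both the module (self_module) and the NAG parameters stored on the patch. A toy sketch of the same binding trick, using hypothetical names:

import types

class GreetPatch:
    def __init__(self, greeting):
        self.greeting = greeting
    def __get__(self, obj, objtype=None):
        def wrapped(self_module):
            # State stored on the patch travels with the bound method.
            return f"{self.greeting}, {self_module.name}"
        return types.MethodType(wrapped, obj)

class Module:
    def __init__(self, name):
        self.name = name

m = Module("cross_attn")
bound = GreetPatch("hello").__get__(m, Module)
print(bound())  # hello, cross_attn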
class WanVideoNAG:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "conditioning": ("CONDITIONING",),
                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of the effect"}),
                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Blend factor between the normalized guidance and the original positive attention output"}),
                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Clipping threshold for the ratio of guidance norm to positive norm"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = "https://github.com/ChenDarYen/Normalized-Attention-Guidance"
    EXPERIMENTAL = True

    def patch(self, model, conditioning, nag_scale, nag_alpha, nag_tau):
        if nag_scale == 0:
            return (model,)
        device = mm.get_torch_device()
        dtype = mm.unet_dtype()

        model_clone = model.clone()
        # if 'transformer_options' not in model_clone.model_options:
        #     model_clone.model_options['transformer_options'] = {}

        diffusion_model = model_clone.get_model_object("diffusion_model")
        diffusion_model.text_embedding.to(device)
        context = diffusion_model.text_embedding(conditioning[0][0].to(device, dtype))

        for idx, block in enumerate(diffusion_model.blocks):
            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau).__get__(block.cross_attn, block.__class__)
            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)

        return (model_clone,)
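patch() clones the model, runs the conditioning through the transformer's text_embedding once, and installs the descriptor on every block's cross-attention via add_object_patch, leaving the original model object untouched. A usage sketch under assumed inputs (model and conditioning stand in for outputs of upstream loader and text-encoder nodes):

# model: a ModelPatcher from a loader node; conditioning: text-encoder output.
node = WanVideoNAG()
(patched,) = node.patch(model, conditioning, nag_scale=11.0, nag_alpha=0.25, nag_tau=2.5)
# nag_scale == 0 short-circuits: the unpatched model is returned as-is.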
class SkipLayerGuidanceWanVideo:
    @classmethod