NAG for batches

kijai 2025-08-02 21:02:48 +03:00
parent a6b867b63a
commit fbdb08f9d6


@@ -1451,34 +1451,26 @@ class WanVideoEnhanceAVideoKJ:
         return (model_clone,)
 
 def normalized_attention_guidance(self, query, context_positive, context_negative):
-    nag_scale = self.nag_scale
-    nag_alpha = self.nag_alpha
-    nag_tau = self.nag_tau
     k_positive = self.norm_k(self.k(context_positive))
     v_positive = self.v(context_positive)
     k_negative = self.norm_k(self.k(context_negative))
     v_negative = self.v(context_negative)
-    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads)
-    x_positive = x_positive.flatten(2)
-    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads)
-    x_negative = x_negative.flatten(2)
+    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
+    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
-    nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
+    nag_guidance = x_positive * self.nag_scale - x_negative * (self.nag_scale - 1)
     norm_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True).expand_as(x_positive)
     norm_guidance = torch.norm(nag_guidance, p=1, dim=-1, keepdim=True).expand_as(nag_guidance)
-    scale = norm_guidance / norm_positive
-    scale = torch.nan_to_num(scale, nan=10.0)
-    mask = scale > nag_tau
-    adjustment = (norm_positive * nag_tau) / (norm_guidance + 1e-7)
+    scale = torch.nan_to_num(norm_guidance / norm_positive, nan=10.0)
+    mask = scale > self.nag_tau
+    adjustment = (norm_positive * self.nag_tau) / (norm_guidance + 1e-7)
     nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)
-    x = nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)
+    x = nag_guidance * self.nag_alpha + x_positive * (1 - self.nag_alpha)
     del nag_guidance
     return x
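
For reference, the normalization above can be exercised in isolation. The following sketch is illustrative only (the helper name nag_blend, the dummy shapes and the default values are not part of the commit); it applies the same extrapolate / L1-norm clip / blend sequence as the new function body to random tensors:

import torch

def nag_blend(x_positive, x_negative, nag_scale=11.0, nag_tau=2.5, nag_alpha=0.25):
    # Extrapolate away from the negative attention output
    guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
    # Compare L1 norms along the channel dimension
    norm_pos = torch.norm(x_positive, p=1, dim=-1, keepdim=True).expand_as(x_positive)
    norm_gui = torch.norm(guidance, p=1, dim=-1, keepdim=True).expand_as(guidance)
    scale = torch.nan_to_num(norm_gui / norm_pos, nan=10.0)
    # Where the guided norm exceeds tau times the positive norm, rescale it back to the threshold
    adjustment = (norm_pos * nag_tau) / (norm_gui + 1e-7)
    guidance = torch.where(scale > nag_tau, guidance * adjustment, guidance)
    # Blend the clipped guidance with the original positive attention output
    return guidance * nag_alpha + x_positive * (1 - nag_alpha)

print(nag_blend(torch.randn(1, 16, 64), torch.randn(1, 16, 64)).shape)  # torch.Size([1, 16, 64])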
@@ -1490,27 +1482,38 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
         x(Tensor): Shape [B, L1, C]
         context(Tensor): Shape [B, L2, C]
     """
-    if context.shape[0] == 2:
-        x, x_real_negative = torch.chunk(x, 2, dim=0)
-        context_positive, context_negative = torch.chunk(context, 2, dim=0)
-    else:
-        context_positive = context
-        context_negative = None
-    q = self.norm_q(self.q(x))
-    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
-    if context_negative is not None:
-        q_real_negative = self.norm_q(self.q(x_real_negative))
-        k_real_negative = self.norm_k(self.k(context_negative))
-        v_real_negative = self.v(context_negative)
-        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
-        x = torch.cat([x, x_real_negative], dim=0)
-    x = self.o(x)
-    return x
+    # Determine batch splitting and context handling
+    if self.input_type == "default":
+        # Single or [pos, neg] pair
+        if context.shape[0] == 1:
+            x_pos, context_pos = x, context
+            x_neg, context_neg = None, None
+        else:
+            x_pos, x_neg = torch.chunk(x, 2, dim=0)
+            context_pos, context_neg = torch.chunk(context, 2, dim=0)
+    elif self.input_type == "batch":
+        # Standard batch, no CFG
+        x_pos, context_pos = x, context
+        x_neg, context_neg = None, None
+
+    # Positive branch
+    q_pos = self.norm_q(self.q(x_pos))
+    nag_context = self.nag_context
+    if self.input_type == "batch":
+        nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)
+    x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
+
+    # Negative branch
+    if x_neg is not None and context_neg is not None:
+        q_neg = self.norm_q(self.q(x_neg))
+        k_neg = self.norm_k(self.k(context_neg))
+        v_neg = self.v(context_neg)
+        x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
+        x = torch.cat([x_pos_out, x_neg_out], dim=0)
+    else:
+        x = x_pos_out
+
+    return self.o(x)
 
 def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
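
As a standalone illustration of the new routing (the helper name split_for_nag and the tensor shapes are invented for this example; only the branching mirrors the diff): with input_type="default" a multi-sample batch is treated as a [positive, negative] CFG pair, while input_type="batch" treats every sample as positive and repeats the single NAG context across the batch.

import torch

def split_for_nag(x, context, nag_context, input_type="default"):
    if input_type == "default" and context.shape[0] > 1:
        # CFG pair travelling in one batch: split into positive and negative halves
        x_pos, x_neg = torch.chunk(x, 2, dim=0)
        ctx_pos, ctx_neg = torch.chunk(context, 2, dim=0)
    else:
        # single positive sample, or a plain batch with no CFG negative
        x_pos, ctx_pos, x_neg, ctx_neg = x, context, None, None
    if input_type == "batch":
        # one NAG context is shared by repeating it to the batch size
        nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)
    return x_pos, ctx_pos, x_neg, ctx_neg, nag_context

x, ctx, nag = torch.randn(4, 16, 64), torch.randn(4, 77, 64), torch.randn(1, 77, 64)
x_pos, _, _, _, nag_rep = split_for_nag(x, ctx, nag, input_type="batch")
print(x_pos.shape, nag_rep.shape)  # torch.Size([4, 16, 64]) torch.Size([4, 77, 64])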
@@ -1551,12 +1554,13 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
     return x
 
 class WanCrossAttentionPatch:
-    def __init__(self, context, nag_scale, nag_alpha, nag_tau, i2v=False):
+    def __init__(self, context, nag_scale, nag_alpha, nag_tau, i2v=False, input_type="default"):
         self.nag_context = context
         self.nag_scale = nag_scale
         self.nag_alpha = nag_alpha
         self.nag_tau = nag_tau
         self.i2v = i2v
+        self.input_type = input_type
 
     def __get__(self, obj, objtype=None):
         # Create bound method with stored parameters
         def wrapped_attention(self_module, *args, **kwargs):
@@ -1564,6 +1568,7 @@ class WanCrossAttentionPatch:
             self_module.nag_scale = self.nag_scale
             self_module.nag_alpha = self.nag_alpha
             self_module.nag_tau = self.nag_tau
+            self_module.input_type = self.input_type
             if self.i2v:
                 return wan_i2v_crossattn_forward_nag(self_module, *args, **kwargs)
             else:
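
The patch is installed by calling __get__ by hand, so the class behaves as a descriptor that manufactures a bound replacement for forward and attaches the stored parameters to the module it wraps. A toy version of the pattern (the names and the scaling behaviour are invented for illustration, not taken from the node):

class ForwardPatch:
    def __init__(self, scale):
        self.scale = scale

    def __get__(self, obj, objtype=None):
        # Return a callable bound to the patched object, carrying the stored parameter
        def wrapped_forward(module, *args, **kwargs):
            module.scale = self.scale  # stash the parameter on the module, as the node does with nag_*
            return module.original_forward(*args, **kwargs) * module.scale
        return wrapped_forward.__get__(obj, objtype)

class Toy:
    def original_forward(self, x):
        return x

toy = Toy()
patched = ForwardPatch(scale=2.0).__get__(toy, Toy)  # same call shape as in the diff
print(patched(3.0))  # 6.0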
@@ -1580,7 +1585,11 @@ class WanVideoNAG:
                 "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of negative guidance effect"}),
                 "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Mixing coefficient that controls the balance between the normalized guided representation and the original positive representation."}),
                 "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Clipping threshold that controls how much the guided attention can deviate from the positive attention."}),
-            }
+            },
+            "optional": {
+                "input_type": (["default", "batch"], {"tooltip": "Type of the model input"}),
+            },
         }
 
     RETURN_TYPES = ("MODEL",)
@@ -1590,7 +1599,7 @@ class WanVideoNAG:
     DESCRIPTION = "https://github.com/ChenDarYen/Normalized-Attention-Guidance"
     EXPERIMENTAL = True
 
-    def patch(self, model, conditioning, nag_scale, nag_alpha, nag_tau):
+    def patch(self, model, conditioning, nag_scale, nag_alpha, nag_tau, input_type="default"):
         if nag_scale == 0:
             return (model,)
@@ -1608,7 +1617,7 @@ class WanVideoNAG:
         i2v = True if "WAN21_I2V" in type_str else False
 
         for idx, block in enumerate(diffusion_model.blocks):
-            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau, i2v).__get__(block.cross_attn, block.__class__)
+            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau, i2v, input_type=input_type).__get__(block.cross_attn, block.__class__)
             model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)
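
The object patch addresses each block's cross-attention by a dotted path. As a rough sketch of what such a path resolves against (plain torch.nn attribute walking; this is not ComfyUI's add_object_patch implementation):

import torch.nn as nn

model = nn.ModuleDict({"blocks": nn.ModuleList(nn.Linear(4, 4) for _ in range(2))})
target = model
for part in "blocks.1.forward".split(".")[:-1]:  # walk everything except the attribute being replaced
    target = target[int(part)] if part.isdigit() else getattr(target, part)
print(target)  # Linear(in_features=4, out_features=4, bias=True)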