Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git

Commit fbdb08f9d6 (parent a6b867b63a): NAG for batches
@@ -1451,34 +1451,26 @@ class WanVideoEnhanceAVideoKJ:
         return (model_clone,)

 def normalized_attention_guidance(self, query, context_positive, context_negative):
-    nag_scale = self.nag_scale
-    nag_alpha = self.nag_alpha
-    nag_tau = self.nag_tau
-
     k_positive = self.norm_k(self.k(context_positive))
     v_positive = self.v(context_positive)
     k_negative = self.norm_k(self.k(context_negative))
     v_negative = self.v(context_negative)

-    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads)
-    x_positive = x_positive.flatten(2)
+    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)

-    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads)
-    x_negative = x_negative.flatten(2)
+    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)

-    nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
+    nag_guidance = x_positive * self.nag_scale - x_negative * (self.nag_scale - 1)

     norm_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True).expand_as(x_positive)
     norm_guidance = torch.norm(nag_guidance, p=1, dim=-1, keepdim=True).expand_as(nag_guidance)

-    scale = norm_guidance / norm_positive
-    scale = torch.nan_to_num(scale, nan=10.0)
-
-    mask = scale > nag_tau
-    adjustment = (norm_positive * nag_tau) / (norm_guidance + 1e-7)
+    scale = torch.nan_to_num(norm_guidance / norm_positive, nan=10.0)
+    mask = scale > self.nag_tau
+    adjustment = (norm_positive * self.nag_tau) / (norm_guidance + 1e-7)
     nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)

-    x = nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)
+    x = nag_guidance * self.nag_alpha + x_positive * (1 - self.nag_alpha)
     del nag_guidance

     return x
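The guidance math above is self-contained, so it can be checked outside the model. A minimal sketch of the same steps on dummy tensors (the attention calls are replaced with random inputs, and all shapes here are made up; this is an illustration of the technique, not the node's code):

    import torch

    def nag_blend(x_positive, x_negative, nag_scale=11.0, nag_tau=2.5, nag_alpha=0.25):
        # Extrapolate the positive attention output away from the negative one.
        nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)

        # Compare L1 norms along the feature dim; nan_to_num guards the
        # division when the positive branch is all zeros.
        norm_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True).expand_as(x_positive)
        norm_guidance = torch.norm(nag_guidance, p=1, dim=-1, keepdim=True).expand_as(nag_guidance)
        scale = torch.nan_to_num(norm_guidance / norm_positive, nan=10.0)

        # Rescale guidance whose norm grew more than tau times the positive norm.
        adjustment = (norm_positive * nag_tau) / (norm_guidance + 1e-7)
        nag_guidance = torch.where(scale > nag_tau, nag_guidance * adjustment, nag_guidance)

        # Blend back toward the positive branch.
        return nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)

    x_pos = torch.randn(2, 77, 128)   # [B, L, C], hypothetical shapes
    x_neg = torch.randn(2, 77, 128)
    out = nag_blend(x_pos, x_neg)
    assert out.shape == x_pos.shape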
@@ -1490,27 +1482,38 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
         x(Tensor): Shape [B, L1, C]
         context(Tensor): Shape [B, L2, C]
     """
-    if context.shape[0] == 2:
-        x, x_real_negative = torch.chunk(x, 2, dim=0)
-        context_positive, context_negative = torch.chunk(context, 2, dim=0)
+    # Determine batch splitting and context handling
+    if self.input_type == "default":
+        # Single or [pos, neg] pair
+        if context.shape[0] == 1:
+            x_pos, context_pos = x, context
+            x_neg, context_neg = None, None
+        else:
+            x_pos, x_neg = torch.chunk(x, 2, dim=0)
+            context_pos, context_neg = torch.chunk(context, 2, dim=0)
+    elif self.input_type == "batch":
+        # Standard batch, no CFG
+        x_pos, context_pos = x, context
+        x_neg, context_neg = None, None
+
+    # Positive branch
+    q_pos = self.norm_q(self.q(x_pos))
+    nag_context = self.nag_context
+    if self.input_type == "batch":
+        nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)
+    x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
+
+    # Negative branch
+    if x_neg is not None and context_neg is not None:
+        q_neg = self.norm_q(self.q(x_neg))
+        k_neg = self.norm_k(self.k(context_neg))
+        v_neg = self.v(context_neg)
+        x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
+        x = torch.cat([x_pos_out, x_neg_out], dim=0)
     else:
-        context_positive = context
-        context_negative = None
-
-    q = self.norm_q(self.q(x))
-
-    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
-
-    if context_negative is not None:
-        q_real_negative = self.norm_q(self.q(x_real_negative))
-        k_real_negative = self.norm_k(self.k(context_negative))
-        v_real_negative = self.v(context_negative)
-        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
-        x = torch.cat([x, x_real_negative], dim=0)
-
-    x = self.o(x)
-
-    return x
+        x = x_pos_out
+
+    return self.o(x)


 def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
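The split logic depends only on tensor shapes, so the two input_type paths can be sanity-checked in isolation. A rough sketch with made-up shapes (the repeat mirrors how a single NAG context is expanded to match a real batch):

    import torch

    B, L1, L2, C = 2, 16, 77, 128
    input_type = "batch"                  # or "default"

    x = torch.randn(B, L1, C)
    context = torch.randn(B, L2, C)
    nag_context = torch.randn(1, L2, C)   # one negative-prompt embedding

    if input_type == "default" and context.shape[0] == 2:
        # CFG-style [positive, negative] pairs travel in one batch.
        x_pos, x_neg = torch.chunk(x, 2, dim=0)
        context_pos, context_neg = torch.chunk(context, 2, dim=0)
    else:
        # Plain batch: every sample is a positive sample.
        x_pos, context_pos = x, context
        x_neg, context_neg = None, None

    if input_type == "batch":
        # The single NAG context must match the positive batch size.
        nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)

    assert nag_context.shape[0] == x_pos.shape[0]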
@@ -1551,12 +1554,13 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
     return x

 class WanCrossAttentionPatch:
-    def __init__(self, context, nag_scale, nag_alpha, nag_tau, i2v=False):
+    def __init__(self, context, nag_scale, nag_alpha, nag_tau, i2v=False, input_type="default"):
         self.nag_context = context
         self.nag_scale = nag_scale
         self.nag_alpha = nag_alpha
         self.nag_tau = nag_tau
         self.i2v = i2v
+        self.input_type = input_type
     def __get__(self, obj, objtype=None):
         # Create bound method with stored parameters
         def wrapped_attention(self_module, *args, **kwargs):
@@ -1564,6 +1568,7 @@ class WanCrossAttentionPatch:
             self_module.nag_scale = self.nag_scale
             self_module.nag_alpha = self.nag_alpha
             self_module.nag_tau = self.nag_tau
+            self_module.input_type = self.input_type
             if self.i2v:
                 return wan_i2v_crossattn_forward_nag(self_module, *args, **kwargs)
             else:
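WanCrossAttentionPatch leans on the descriptor protocol: calling __get__ by hand yields a function bound to a specific module instance, which is what gets installed as the replacement forward. A stripped-down sketch of the same trick with toy classes (none of these names come from the repo):

    class ForwardPatch:
        """Descriptor that produces a wrapped forward carrying extra state."""
        def __init__(self, scale):
            self.scale = scale

        def __get__(self, obj, objtype=None):
            def wrapped(self_module, x):
                # Stash configuration on the module, then delegate.
                self_module.scale = self.scale
                return self_module.base_forward(x) * self_module.scale
            # Plain functions are descriptors too, so this binds self_module.
            return wrapped.__get__(obj, objtype)

    class ToyAttention:
        def base_forward(self, x):
            return x + 1

    attn = ToyAttention()
    attn_forward = ForwardPatch(scale=2.0).__get__(attn, ToyAttention)
    print(attn_forward(3.0))  # (3 + 1) * 2 = 8.0

This mirrors how the node calls .__get__(block.cross_attn, block.__class__) and installs the result as the block's forward.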
@@ -1580,7 +1585,11 @@ class WanVideoNAG:
                 "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of negative guidance effect"}),
                 "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Mixing coefficient that controls the balance between the normalized guided representation and the original positive representation."}),
                 "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Clipping threshold that controls how much the guided attention can deviate from the positive attention."}),
-            }
+            },
+            "optional": {
+                "input_type": (["default", "batch"], {"tooltip": "Type of the model input"}),
+            },
+
         }

     RETURN_TYPES = ("MODEL",)
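Because input_type is declared under "optional", workflows saved before this commit should load unchanged; the patch() signature in the next hunk correspondingly defaults input_type to "default" so calls that omit it keep the old single/CFG-pair behavior.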
@@ -1590,7 +1599,7 @@ class WanVideoNAG:
     DESCRIPTION = "https://github.com/ChenDarYen/Normalized-Attention-Guidance"
     EXPERIMENTAL = True

-    def patch(self, model, conditioning, nag_scale, nag_alpha, nag_tau):
+    def patch(self, model, conditioning, nag_scale, nag_alpha, nag_tau, input_type="default"):
         if nag_scale == 0:
             return (model,)

@@ -1608,7 +1617,7 @@ class WanVideoNAG:
         i2v = True if "WAN21_I2V" in type_str else False

         for idx, block in enumerate(diffusion_model.blocks):
-            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau, i2v).__get__(block.cross_attn, block.__class__)
+            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau, i2v, input_type=input_type).__get__(block.cross_attn, block.__class__)
             model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)

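End to end, patch() clones the model and rebinds every transformer block's cross-attention forward. A hypothetical direct invocation for reference (normally the ComfyUI graph executor calls this with the node's widget values; model and conditioning stand in for the usual MODEL and CONDITIONING inputs):

    node = WanVideoNAG()
    # input_type="batch" treats every sample in the batch as positive and
    # repeats the NAG context to match the batch size.
    (patched_model,) = node.patch(
        model,
        conditioning,          # supplies the nag_context used in each patched block
        nag_scale=11.0,
        nag_alpha=0.25,
        nag_tau=2.5,
        input_type="batch",
    )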