mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00
Experimental NAG node
commit d584c711a3
parent aeab1a7de5
@@ -196,6 +196,7 @@ NODE_CONFIG = {
    "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
    "CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
    "ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
    "WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
    #instance diffusion
    "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
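
For orientation: ComfyUI discovers nodes through the NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS exports, and a NODE_CONFIG dict like the one above is typically expanded into both. A minimal sketch of that convention; the comprehensions are an assumption about how this repo's loader consumes NODE_CONFIG, and the stand-in class exists only to keep the snippet runnable:

# Sketch: expanding a NODE_CONFIG dict into ComfyUI's standard registration
# exports. The comprehensions are an assumption, not copied from the repo.
class WanVideoNAG:  # stand-in for the real node class
    pass

NODE_CONFIG = {
    "WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
}

NODE_CLASS_MAPPINGS = {k: v["class"] for k, v in NODE_CONFIG.items()}
# Fall back to the key when an entry has no "name", as with
# "CreateInstanceDiffusionTracking" above.
NODE_DISPLAY_NAME_MAPPINGS = {k: v.get("name", k) for k, v in NODE_CONFIG.items()}

print(NODE_DISPLAY_NAME_MAPPINGS)  # {'WanVideoNAG': 'WanVideoNAG'}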
@@ -1444,6 +1444,123 @@ class WanVideoEnhanceAVideoKJ:

            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)

        return (model_clone,)

#region NAG
def wan_crossattn_forward_nag(self, x, context, **kwargs):
    r"""
    Cross-attention forward patched with Normalized Attention Guidance (NAG).

    Args:
        x (Tensor): Shape [B, L1, C]
        context (Tensor): Shape [B, L2, C]
    """
    # NAG text attention. A context batch of 2 means CFG is active and the
    # batch is [positive, real negative]; NAG applies to the positive half only.
    if context.shape[0] == 2:
        x, x_real_negative = torch.chunk(x, 2, dim=0)
        context_positive, context_negative = torch.chunk(context, 2, dim=0)
    else:
        context_positive = context
        context_negative = None

    nag_scale = self.nag_scale
    nag_alpha = self.nag_alpha
    nag_tau = self.nag_tau

    q = self.norm_q(self.q(x))

    # Keys/values from the positive prompt and from the NAG negative context.
    k_positive = self.norm_k(self.k(context_positive))
    v_positive = self.v(context_positive)
    k_negative = self.norm_k(self.k(self.nag_context))
    v_negative = self.v(self.nag_context)

    x_positive = comfy.ldm.modules.attention.optimized_attention(q, k_positive, v_positive, heads=self.num_heads)
    x_positive = x_positive.flatten(2)

    x_negative = comfy.ldm.modules.attention.optimized_attention(q, k_negative, v_negative, heads=self.num_heads)
    x_negative = x_negative.flatten(2)

    # Extrapolate the positive attention output away from the negative one.
    nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)

    # L1 norms over the channel dimension, used to bound the guidance.
    norm_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True).expand_as(x_positive)
    norm_guidance = torch.norm(nag_guidance, p=1, dim=-1, keepdim=True).expand_as(nag_guidance)

    scale = norm_guidance / norm_positive
    scale = torch.nan_to_num(scale, nan=10.0)  # 0/0 -> NaN; replace so the mask below is well defined

    # Where the guidance norm exceeds nag_tau times the positive norm,
    # rescale it back down to exactly that bound.
    mask = scale > nag_tau
    adjustment = (norm_positive * nag_tau) / (norm_guidance + 1e-7)
    nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)

    # Blend the bounded guidance with the plain positive output.
    x = nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)

    # The real CFG negative branch passes through unmodified cross-attention.
    if context_negative is not None:
        q_real_negative = self.norm_q(self.q(x_real_negative))
        k_real_negative = self.norm_k(self.k(context_negative))
        v_real_negative = self.v(context_negative)
        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
        #x_real_negative = x_real_negative.flatten(2)
        x = torch.cat([x, x_real_negative], dim=0)

    x = self.o(x)
    return x
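
The guidance arithmetic in the forward above is easy to restate in isolation. A minimal sketch of the same rule on dummy tensors, assuming the shapes shown in the docstring; the function name and dimensions are illustrative, not from the repo:

import torch

def nag_combine(z_pos, z_neg, scale=11.0, tau=2.5, alpha=0.25, eps=1e-7):
    """Standalone restatement of the NAG combination rule used above."""
    z = z_pos * scale - z_neg * (scale - 1)          # extrapolate past the negative
    n_pos = z_pos.abs().sum(dim=-1, keepdim=True)    # L1 norm over channels
    n_z = z.abs().sum(dim=-1, keepdim=True)
    ratio = torch.nan_to_num(n_z / n_pos, nan=10.0)  # guard the 0/0 case
    # Rescale wherever the guidance norm exceeds tau times the positive norm.
    z = torch.where(ratio > tau, z * (n_pos * tau) / (n_z + eps), z)
    return z * alpha + z_pos * (1 - alpha)           # blend back toward positive

z_pos = torch.randn(2, 77, 128)
z_neg = torch.randn(2, 77, 128)
print(nag_combine(z_pos, z_neg).shape)  # torch.Size([2, 77, 128])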

class WanCrossAttentionPatch:
    def __init__(self, context, nag_scale, nag_alpha, nag_tau):
        self.nag_context = context
        self.nag_scale = nag_scale
        self.nag_alpha = nag_alpha
        self.nag_tau = nag_tau

    def __get__(self, obj, objtype=None):
        # Create a bound method with the stored parameters: stash them on the
        # attention module, then delegate to the patched forward above.
        def wrapped_attention(self_module, *args, **kwargs):
            self_module.nag_context = self.nag_context
            self_module.nag_scale = self.nag_scale
            self_module.nag_alpha = self.nag_alpha
            self_module.nag_tau = self.nag_tau
            return wan_crossattn_forward_nag(self_module, *args, **kwargs)
        return types.MethodType(wrapped_attention, obj)
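
The class above leans on Python's descriptor protocol: __get__ is invoked by hand (rather than via attribute access) to manufacture a method bound to one specific cross_attn module. A minimal sketch of that binding pattern, with illustrative names:

import types

class GreetingPatch:
    # Same pattern as WanCrossAttentionPatch: __get__ manufactures a method
    # bound to a specific object, usable as that object's replacement forward.
    def __init__(self, greeting):
        self.greeting = greeting

    def __get__(self, obj, objtype=None):
        def wrapped(self_module):
            return f"{self.greeting}, {self_module.name}"
        return types.MethodType(wrapped, obj)

class Attn:
    def __init__(self, name):
        self.name = name

block_attn = Attn("blocks.0.cross_attn")
bound = GreetingPatch("patched").__get__(block_attn, Attn)
print(bound())  # patched, blocks.0.cross_attn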

class WanVideoNAG:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "conditioning": ("CONDITIONING",),
                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of the effect"}),
                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Blend between the guided and the original positive attention output"}),
                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Norm threshold used to bound the guidance"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "patch"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = "https://github.com/ChenDarYen/Normalized-Attention-Guidance"
    EXPERIMENTAL = True

    def patch(self, model, conditioning, nag_scale, nag_alpha, nag_tau):
        # A scale of 0 disables NAG entirely; return the model unpatched.
        if nag_scale == 0:
            return (model,)

        device = mm.get_torch_device()
        dtype = mm.unet_dtype()

        model_clone = model.clone()
        # if 'transformer_options' not in model_clone.model_options:
        #     model_clone.model_options['transformer_options'] = {}

        diffusion_model = model_clone.get_model_object("diffusion_model")

        # Embed the NAG negative conditioning once; every block shares it.
        diffusion_model.text_embedding.to(device)
        context = diffusion_model.text_embedding(conditioning[0][0].to(device, dtype))

        for idx, block in enumerate(diffusion_model.blocks):
            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau).__get__(block.cross_attn, block.__class__)
            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)

        return (model_clone,)
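
patch() works on model.clone() and records replacements with add_object_patch, so the incoming model stays unpatched. A minimal sketch of that clone-then-patch contract, assuming ComfyUI's ModelPatcher behaves as described; TinyPatcher is a stand-in, not the real API:

class TinyPatcher:
    # Stand-in for ComfyUI's ModelPatcher: clone() copies the patch record,
    # add_object_patch() registers a replacement without touching the original.
    def __init__(self):
        self.object_patches = {}

    def clone(self):
        c = TinyPatcher()
        c.object_patches = dict(self.object_patches)
        return c

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

m = TinyPatcher()
m2 = m.clone()
m2.add_object_patch("diffusion_model.blocks.0.cross_attn.forward", lambda *a, **kw: None)
print(len(m.object_patches), len(m2.object_patches))  # 0 1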

class SkipLayerGuidanceWanVideo:
    @classmethod