Update model_optimization_nodes.py

kijai 2025-06-16 00:35:53 +03:00
parent d584c711a3
commit 7afb0f906a

model_optimization_nodes.py

@@ -1444,38 +1444,21 @@ class WanVideoEnhanceAVideoKJ:
         model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)
         return (model_clone,)
 
-#region NAG
-def wan_crossattn_forward_nag(self, x, context, **kwargs):
-    r"""
-    Args:
-        x(Tensor): Shape [B, L1, C]
-        context(Tensor): Shape [B, L2, C]
-    """
-    # NAG text attention
-    if context.shape[0] == 2:
-        x, x_real_negative = torch.chunk(x, 2, dim=0)
-        context_positive, context_negative = torch.chunk(context, 2, dim=0)
-    else:
-        context_positive = context
-        context_negative = None
+def normalized_attention_guidance(self, query, context_positive, context_negative):
     nag_scale = self.nag_scale
     nag_alpha = self.nag_alpha
     nag_tau = self.nag_tau
 
-    q = self.norm_q(self.q(x))
     k_positive = self.norm_k(self.k(context_positive))
     v_positive = self.v(context_positive)
-    k_negative = self.norm_k(self.k(self.nag_context))
-    v_negative = self.v(self.nag_context)
+    k_negative = self.norm_k(self.k(context_negative))
+    v_negative = self.v(context_negative)
 
-    x_positive = comfy.ldm.modules.attention.optimized_attention(q, k_positive, v_positive, heads=self.num_heads)
+    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads)
     x_positive = x_positive.flatten(2)
-    x_negative = comfy.ldm.modules.attention.optimized_attention(q, k_negative, v_negative, heads=self.num_heads)
+    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads)
     x_negative = x_negative.flatten(2)
 
     nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
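Editor's note: the hunk above introduces the guidance extrapolation, and the top of the next hunk shows the tau clipping and alpha blend being applied. The core arithmetic that `normalized_attention_guidance` now isolates can be sketched standalone. This is a hedged reconstruction: the `mask`/`adjustment` computation is not visible in this diff, so the norm-based clipping below is an assumption taken from the NAG formulation rather than a verbatim copy of this file.

import torch

def nag_combine(x_positive, x_negative, nag_scale=11.0, nag_tau=2.5, nag_alpha=0.25):
    # Extrapolate away from the negative attention output; at nag_scale == 1
    # this reduces to x_positive exactly.
    nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
    # Clip when the guided features drift more than nag_tau times the positive
    # norm (assumed L1 norm over the channel dim; the diff only shows the
    # torch.where that applies the result).
    norm_positive = torch.norm(x_positive, p=1, dim=-1, keepdim=True)
    norm_guidance = torch.norm(nag_guidance, p=1, dim=-1, keepdim=True)
    adjustment = norm_positive * nag_tau / (norm_guidance + 1e-7)
    mask = norm_guidance / norm_positive > nag_tau
    nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)
    # Blend back toward the positive branch.
    return nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)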
@@ -1491,24 +1474,84 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
         nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)
     x = nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)
     del nag_guidance
+    return x
 
+#region NAG
+def wan_crossattn_forward_nag(self, x, context, **kwargs):
+    r"""
+    Args:
+        x(Tensor): Shape [B, L1, C]
+        context(Tensor): Shape [B, L2, C]
+    """
+    if context.shape[0] == 2:
+        x, x_real_negative = torch.chunk(x, 2, dim=0)
+        context_positive, context_negative = torch.chunk(context, 2, dim=0)
+    else:
+        context_positive = context
+        context_negative = None
+
+    q = self.norm_q(self.q(x))
+    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
+
     if context_negative is not None:
         q_real_negative = self.norm_q(self.q(x_real_negative))
         k_real_negative = self.norm_k(self.k(context_negative))
         v_real_negative = self.v(context_negative)
         x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
         #x_real_negative = x_real_negative.flatten(2)
         x = torch.cat([x, x_real_negative], dim=0)
 
     x = self.o(x)
     return x
 
+def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
+    r"""
+    Args:
+        x(Tensor): Shape [B, L1, C]
+        context(Tensor): Shape [B, L2, C]
+    """
+    context_img = context[:, :context_img_len]
+    context = context[:, context_img_len:]
+
+    q_img = self.norm_q(self.q(x))
+    k_img = self.norm_k_img(self.k_img(context_img))
+    v_img = self.v_img(context_img)
+    img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
+
+    if context.shape[0] == 2:
+        x, x_real_negative = torch.chunk(x, 2, dim=0)
+        context_positive, context_negative = torch.chunk(context, 2, dim=0)
+    else:
+        context_positive = context
+        context_negative = None
+
+    q = self.norm_q(self.q(x))
+    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
+
+    if context_negative is not None:
+        q_real_negative = self.norm_q(self.q(x_real_negative))
+        k_real_negative = self.norm_k(self.k(context_negative))
+        v_real_negative = self.v(context_negative)
+        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
+        x = torch.cat([x, x_real_negative], dim=0)
+
+    # output
+    x = x + img_x
+    x = self.o(x)
+    return x
 
 class WanCrossAttentionPatch:
-    def __init__(self, context, nag_scale, nag_alpha, nag_tau):
+    def __init__(self, context, nag_scale, nag_alpha, nag_tau, i2v=False):
         self.nag_context = context
         self.nag_scale = nag_scale
         self.nag_alpha = nag_alpha
         self.nag_tau = nag_tau
+        self.i2v = i2v
 
     def __get__(self, obj, objtype=None):
         # Create bound method with stored parameters
         def wrapped_attention(self_module, *args, **kwargs):
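Editor's note: both forward variants above assume the CFG batch layout: when a real negative prompt is active, positive and negative conditioning travel concatenated along the batch dimension (`context.shape[0] == 2`), NAG runs only on the positive half, and the real negative attention pass is re-attached afterwards. A toy illustration of that split/merge (shapes here are made up):

import torch

x = torch.randn(2, 16, 8)        # [B=2, L1, C]: positive and negative samples
context = torch.randn(2, 77, 8)  # [B=2, L2, C]: positive and negative prompts

x_pos, x_neg = torch.chunk(x, 2, dim=0)
ctx_pos, ctx_neg = torch.chunk(context, 2, dim=0)
# ... NAG is applied to the x_pos / ctx_pos half only ...
x_out = torch.cat([x_pos, x_neg], dim=0)  # restore the batch for later layers
assert x_out.shape == x.shape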
@@ -1516,7 +1559,10 @@ class WanCrossAttentionPatch:
             self_module.nag_scale = self.nag_scale
             self_module.nag_alpha = self.nag_alpha
             self_module.nag_tau = self.nag_tau
-            return wan_crossattn_forward_nag(self_module, *args, **kwargs)
+            if self.i2v:
+                return wan_i2v_crossattn_forward_nag(self_module, *args, **kwargs)
+            else:
+                return wan_crossattn_forward_nag(self_module, *args, **kwargs)
         return types.MethodType(wrapped_attention, obj)
 
 class WanVideoNAG:
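Editor's note: the `__get__` above is Python's descriptor protocol invoked by hand. Calling `WanCrossAttentionPatch(...).__get__(block.cross_attn, block.__class__)` yields a function bound to that specific attention module, with the NAG parameters closed over. A minimal self-contained demo of the same trick (all names below are hypothetical):

import types

class ScalePatch:
    def __init__(self, scale):
        self.scale = scale

    def __get__(self, obj, objtype=None):
        def wrapped(self_module, value):
            # self_module is the patched object; self.scale is the stored parameter
            return self_module.base * value * self.scale
        return types.MethodType(wrapped, obj)

class Block:
    base = 2

block = Block()
patched = ScalePatch(10).__get__(block, Block)
print(patched(3))  # 2 * 3 * 10 = 60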
@@ -1526,9 +1572,9 @@ class WanVideoNAG:
         return {
             "required": {
                 "model": ("MODEL",),
                 "conditioning": ("CONDITIONING",),
-                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of the effect"}),
-                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Alpha of the effect"}),
-                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Tau of the effect"}),
+                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of the negative guidance effect"}),
+                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Mixing coefficient that controls the balance between the normalized guided representation and the original positive representation."}),
+                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Clipping threshold that controls how much the guided attention can deviate from the positive attention."}),
             }
         }
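Editor's note: with these defaults the large `nag_scale` is tempered by the `nag_alpha` blend. A quick worked example with scalar stand-ins for the two attention outputs, using the formula from the first hunk:

x_pos, x_neg = 1.0, 0.8
nag_scale, nag_alpha = 11.0, 0.25
guidance = x_pos * nag_scale - x_neg * (nag_scale - 1)  # 11.0 - 8.0 = 3.0
out = guidance * nag_alpha + x_pos * (1 - nag_alpha)    # 0.75 + 0.75 = 1.5
# nag_tau = 2.5 would additionally cap guidance at 2.5 here, giving out = 1.375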
@@ -1547,16 +1593,17 @@ class WanVideoNAG:
         dtype = mm.unet_dtype()
 
         model_clone = model.clone()
-        # if 'transformer_options' not in model_clone.model_options:
-        #     model_clone.model_options['transformer_options'] = {}
 
         diffusion_model = model_clone.get_model_object("diffusion_model")
 
         diffusion_model.text_embedding.to(device)
         context = diffusion_model.text_embedding(conditioning[0][0].to(device, dtype))
 
+        type_str = str(type(model.model.model_config).__name__)
+        i2v = True if "WAN21_I2V" in type_str else False
+
         for idx, block in enumerate(diffusion_model.blocks):
-            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau).__get__(block.cross_attn, block.__class__)
+            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau, i2v).__get__(block.cross_attn, block.__class__)
             model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)
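Editor's note: the application step follows ComfyUI's clone-and-patch pattern, where patches registered via `add_object_patch` affect only the clone and the source model keeps its unpatched behavior. A standalone sketch of the same idea without ComfyUI (toy modules and a hypothetical scaling patch, not this node's code):

import copy
import types
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model_clone = copy.deepcopy(model)

def scaled_forward(self_module, x):
    # call the original class-level forward, then apply the patch's effect
    return nn.Linear.forward(self_module, x) * 0.5

for block in model_clone:
    # rebind forward on the clone's blocks only
    block.forward = types.MethodType(scaled_forward, block)

x = torch.randn(1, 4)
print(torch.allclose(model(x), model_clone(x)))  # False: only the clone is patched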