Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git (synced 2025-12-10 05:15:05 +08:00)

Commit 7afb0f906a: Update model_optimization_nodes.py
Parent: d584c711a3
@@ -1444,38 +1444,21 @@ class WanVideoEnhanceAVideoKJ:
             model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", patched_attn)
 
         return (model_clone,)
 
 #region NAG
-def wan_crossattn_forward_nag(self, x, context, **kwargs):
-    r"""
-    Args:
-        x(Tensor): Shape [B, L1, C]
-        context(Tensor): Shape [B, L2, C]
-    """
-    # NAG text attention
-    if context.shape[0] == 2:
-        x, x_real_negative = torch.chunk(x, 2, dim=0)
-        context_positive, context_negative = torch.chunk(context, 2, dim=0)
-    else:
-        context_positive = context
-        context_negative = None
-
+def normalized_attention_guidance(self, query, context_positive, context_negative):
     nag_scale = self.nag_scale
     nag_alpha = self.nag_alpha
     nag_tau = self.nag_tau
 
-    q = self.norm_q(self.q(x))
-
     k_positive = self.norm_k(self.k(context_positive))
     v_positive = self.v(context_positive)
-    k_negative = self.norm_k(self.k(self.nag_context))
-    v_negative = self.v(self.nag_context)
+    k_negative = self.norm_k(self.k(context_negative))
+    v_negative = self.v(context_negative)
 
-    x_positive = comfy.ldm.modules.attention.optimized_attention(q, k_positive, v_positive, heads=self.num_heads)
+    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads)
     x_positive = x_positive.flatten(2)
 
-    x_negative = comfy.ldm.modules.attention.optimized_attention(q, k_negative, v_negative, heads=self.num_heads)
+    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads)
     x_negative = x_negative.flatten(2)
 
     nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
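Note: the combination that normalized_attention_guidance builds up across this hunk and the next can be sketched in isolation. The extrapolation line and the final alpha blend are taken verbatim from the diff; the L1-norm clipping in the middle is an assumption (the mask/adjustment lines sit in the collapsed region between the two hunks), modeled on typical NAG implementations, and nag_combine is an illustrative name only.

    import torch

    def nag_combine(x_positive, x_negative, nag_scale=11.0, nag_tau=2.5, nag_alpha=0.25):
        # Extrapolate the positive attention output away from the negative one (as in the diff).
        guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)

        # Assumed clipping step: if the guided features grow more than nag_tau times
        # the positive features (L1 norm over the channel dim), rescale them back to that bound.
        norm_pos = x_positive.norm(p=1, dim=-1, keepdim=True)
        norm_guid = guidance.norm(p=1, dim=-1, keepdim=True)
        mask = (norm_guid / norm_pos) > nag_tau
        adjustment = nag_tau * norm_pos / norm_guid
        guidance = torch.where(mask, guidance * adjustment, guidance)

        # Blend back toward the plain positive branch, as in the diff.
        return guidance * nag_alpha + x_positive * (1 - nag_alpha)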
@@ -1491,24 +1474,84 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
     nag_guidance = torch.where(mask, nag_guidance * adjustment, nag_guidance)
 
     x = nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)
     del nag_guidance
 
+    return x
 
+#region NAG
+def wan_crossattn_forward_nag(self, x, context, **kwargs):
+    r"""
+    Args:
+        x(Tensor): Shape [B, L1, C]
+        context(Tensor): Shape [B, L2, C]
+    """
+    if context.shape[0] == 2:
+        x, x_real_negative = torch.chunk(x, 2, dim=0)
+        context_positive, context_negative = torch.chunk(context, 2, dim=0)
+    else:
+        context_positive = context
+        context_negative = None
+
+    q = self.norm_q(self.q(x))
+
+    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
+
+    if context_negative is not None:
+        q_real_negative = self.norm_q(self.q(x_real_negative))
+        k_real_negative = self.norm_k(self.k(context_negative))
+        v_real_negative = self.v(context_negative)
+        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
+        #x_real_negative = x_real_negative.flatten(2)
+        x = torch.cat([x, x_real_negative], dim=0)
 
     x = self.o(x)
     return x
 
+
+def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
+    r"""
+    Args:
+        x(Tensor): Shape [B, L1, C]
+        context(Tensor): Shape [B, L2, C]
+    """
+    context_img = context[:, :context_img_len]
+    context = context[:, context_img_len:]
+
+    q_img = self.norm_q(self.q(x))
+    k_img = self.norm_k_img(self.k_img(context_img))
+    v_img = self.v_img(context_img)
+    img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
+
+    if context.shape[0] == 2:
+        x, x_real_negative = torch.chunk(x, 2, dim=0)
+        context_positive, context_negative = torch.chunk(context, 2, dim=0)
+    else:
+        context_positive = context
+        context_negative = None
+
+    q = self.norm_q(self.q(x))
+
+    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
+
+    if context_negative is not None:
+        q_real_negative = self.norm_q(self.q(x_real_negative))
+        k_real_negative = self.norm_k(self.k(context_negative))
+        v_real_negative = self.v(context_negative)
+        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
+        x = torch.cat([x, x_real_negative], dim=0)
+
+    # output
+    x = x + img_x
+    x = self.o(x)
+    return x
 
 
 class WanCrossAttentionPatch:
-    def __init__(self, context, nag_scale, nag_alpha, nag_tau):
+    def __init__(self, context, nag_scale, nag_alpha, nag_tau, i2v=False):
         self.nag_context = context
         self.nag_scale = nag_scale
         self.nag_alpha = nag_alpha
         self.nag_tau = nag_tau
+        self.i2v = i2v
     def __get__(self, obj, objtype=None):
         # Create bound method with stored parameters
         def wrapped_attention(self_module, *args, **kwargs):
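Note: both forward variants above assume the same batching convention: when a real negative prompt is in play, the positive and negative samples arrive stacked in a single batch of two, which is what the context.shape[0] == 2 check detects. A toy shape walk-through of that split and rejoin, with made-up dimensions:

    import torch

    x = torch.randn(2, 16, 8)        # [B=2 (positive + real negative), L1, C]
    context = torch.randn(2, 32, 8)  # [B=2, L2, C]

    if context.shape[0] == 2:
        x, x_real_negative = torch.chunk(x, 2, dim=0)                     # each [1, 16, 8]
        context_positive, context_negative = torch.chunk(context, 2, dim=0)
    else:
        context_positive, context_negative = context, None

    # NAG is applied to the positive half against the stored nag_context, the real
    # negative half goes through plain cross-attention, and the halves are stacked
    # back into one batch before the output projection.
    x = torch.cat([x, x_real_negative], dim=0)                            # back to [2, 16, 8]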
@@ -1516,7 +1559,10 @@ class WanCrossAttentionPatch:
             self_module.nag_scale = self.nag_scale
             self_module.nag_alpha = self.nag_alpha
             self_module.nag_tau = self.nag_tau
-            return wan_crossattn_forward_nag(self_module, *args, **kwargs)
+            if self.i2v:
+                return wan_i2v_crossattn_forward_nag(self_module, *args, **kwargs)
+            else:
+                return wan_crossattn_forward_nag(self_module, *args, **kwargs)
         return types.MethodType(wrapped_attention, obj)
 
 class WanVideoNAG:
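Note: the routing above leans on the descriptor protocol. __get__ is invoked explicitly in the node further down to turn the patch object into a bound method that carries the NAG parameters with it, and that bound method is what add_object_patch installs as cross_attn.forward. A stripped-down sketch of the same pattern (all names here are illustrative, not from the repo):

    import types

    class ForwardPatch:
        def __init__(self, extra):
            self.extra = extra

        def __get__(self, obj, objtype=None):
            def wrapped(self_module, *args, **kwargs):
                # self_module is whatever __get__ was bound to (block.cross_attn in the
                # real code); the stored state rides along on every call.
                self_module.extra = self.extra
                return args, kwargs  # a real patch would dispatch to the NAG forward here
            return types.MethodType(wrapped, obj)

    class DummyAttention:
        pass

    attn = DummyAttention()
    patched = ForwardPatch(extra=123).__get__(attn, DummyAttention)  # explicit __get__, as in the node
    patched("x", key="context")  # attn.extra is now 123; args/kwargs pass straight through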
@@ -1526,9 +1572,9 @@ class WanVideoNAG:
             "required": {
                 "model": ("MODEL",),
                 "conditioning": ("CONDITIONING",),
-                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of the effect"}),
-                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Alpha of the effect"}),
-                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Tau of the effect"}),
+                "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.001, "tooltip": "Strength of negative guidance effect"}),
+                "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Mixing coefficient that controls the balance between the normalized guided representation and the original positive representation."}),
+                "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Clipping threshold that controls how much the guided attention can deviate from the positive attention."}),
             }
         }
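Note: one quick check on these defaults. With nag_scale set to 1.0 the extrapolation term cancels, the norm ratio stays at 1 so the tau clipping never fires, and the blend returns the plain positive attention, so the patch is effectively a no-op. A small verification using the formula from the first hunk:

    import torch

    x_positive = torch.randn(1, 16, 8)
    x_negative = torch.randn(1, 16, 8)

    nag_scale, nag_alpha = 1.0, 0.25
    nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)  # reduces to x_positive
    blended = nag_guidance * nag_alpha + x_positive * (1 - nag_alpha)     # still x_positive
    assert torch.allclose(blended, x_positive)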
@@ -1547,16 +1593,17 @@ class WanVideoNAG:
         dtype = mm.unet_dtype()
 
         model_clone = model.clone()
-        # if 'transformer_options' not in model_clone.model_options:
-        #     model_clone.model_options['transformer_options'] = {}
 
         diffusion_model = model_clone.get_model_object("diffusion_model")
 
         diffusion_model.text_embedding.to(device)
         context = diffusion_model.text_embedding(conditioning[0][0].to(device, dtype))
 
+        type_str = str(type(model.model.model_config).__name__)
+        i2v = True if "WAN21_I2V" in type_str else False
+
         for idx, block in enumerate(diffusion_model.blocks):
-            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau).__get__(block.cross_attn, block.__class__)
+            patched_attn = WanCrossAttentionPatch(context, nag_scale, nag_alpha, nag_tau, i2v).__get__(block.cross_attn, block.__class__)
 
             model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.cross_attn.forward", patched_attn)
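Note: two details worth calling out in this last hunk. The CONDITIONING input follows ComfyUI's usual structure, a list of [cond_tensor, options_dict] pairs, so conditioning[0][0] is the encoded text tensor that gets pushed through the model's own text_embedding to become nag_context; and I2V detection simply keys off the model config's class name. A rough sketch with placeholder shapes and a stand-in config class:

    import torch

    # ComfyUI CONDITIONING: list of [cond_tensor, options_dict] pairs.
    conditioning = [[torch.randn(1, 512, 4096), {}]]    # placeholder shape
    cond_tensor = conditioning[0][0]                     # fed to diffusion_model.text_embedding(...)

    class WAN21_I2V:   # stand-in for the comfy model config class
        pass

    type_str = str(type(WAN21_I2V()).__name__)
    i2v = True if "WAN21_I2V" in type_str else False     # True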