Add transformer_options for enhance-a-video and NAG patches

This commit is contained in:
kijai 2025-09-14 11:49:04 +03:00
parent e833a3f7df
commit 468fcc86f0

View File

@ -1316,6 +1316,7 @@ class WanVideoTeaCacheKJ:
RETURN_NAMES = ("model",)
FUNCTION = "patch_teacache"
CATEGORY = "KJNodes/teacache"
DEPRECATED = True
DESCRIPTION = """
Patch WanVideo model to use TeaCache. Speeds up inference by caching the output and
applying it instead of doing the step. Best results are achieved by choosing the
@ -1450,7 +1451,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
from comfy.ldm.flux.math import apply_rope
def modified_wan_self_attention_forward(self, x, freqs):
def modified_wan_self_attention_forward(self, x, freqs, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@ -1470,13 +1471,23 @@ def modified_wan_self_attention_forward(self, x, freqs):
q, k = apply_rope(q, k, freqs)
feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight)
x = comfy.ldm.modules.attention.optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v,
heads=self.num_heads,
)
try:
x = comfy.ldm.modules.attention.optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v,
heads=self.num_heads,
transformer_options=transformer_options,
)
except:
# backward compatibility for now
x = comfy.ldm.modules.attention.attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v,
heads=self.num_heads,
)
x = self.o(x)
@ -1581,14 +1592,18 @@ class WanVideoEnhanceAVideoKJ:
return (model_clone,)
def normalized_attention_guidance(self, query, context_positive, context_negative):
def normalized_attention_guidance(self, query, context_positive, context_negative, transformer_options={}):
k_positive = self.norm_k(self.k(context_positive))
v_positive = self.v(context_positive)
k_negative = self.norm_k(self.k(context_negative))
v_negative = self.v(context_negative)
x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
try:
x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
except: #backwards compatibility for now
x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
nag_guidance = x_positive * self.nag_scale - x_negative * (self.nag_scale - 1)
@ -1607,7 +1622,7 @@ def normalized_attention_guidance(self, query, context_positive, context_negativ
return x
#region NAG
def wan_crossattn_forward_nag(self, x, context, **kwargs):
def wan_crossattn_forward_nag(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -1632,14 +1647,20 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
nag_context = self.nag_context
if self.input_type == "batch":
nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)
x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
try:
x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context, transformer_options=transformer_options)
except: #backwards compatibility for now
x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
# Negative branch
if x_neg is not None and context_neg is not None:
q_neg = self.norm_q(self.q(x_neg))
k_neg = self.norm_k(self.k(context_neg))
v_neg = self.v(context_neg)
x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
try:
x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads, transformer_options=transformer_options)
except: #backwards compatibility for now
x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
x = torch.cat([x_pos_out, x_neg_out], dim=0)
else:
x = x_pos_out
@ -1647,7 +1668,7 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
return self.o(x)
def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -1659,7 +1680,10 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
q_img = self.norm_q(self.q(x))
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
try:
img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
except: #backwards compatibility for now
img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
if context.shape[0] == 2:
x, x_real_negative = torch.chunk(x, 2, dim=0)
@ -1670,13 +1694,16 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
q = self.norm_q(self.q(x))
x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
x = normalized_attention_guidance(self, q, context_positive, self.nag_context, transformer_options=transformer_options)
if context_negative is not None:
q_real_negative = self.norm_q(self.q(x_real_negative))
k_real_negative = self.norm_k(self.k(context_negative))
v_real_negative = self.v(context_negative)
x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
try:
x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads, transformer_options=transformer_options)
except: #backwards compatibility for now
x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
x = torch.cat([x, x_real_negative], dim=0)
# output
@ -1766,7 +1793,7 @@ class SkipLayerGuidanceWanVideo:
FUNCTION = "slg"
EXPERIMENTAL = True
DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
DEPRECATED = True
CATEGORY = "advanced/guidance"
def slg(self, model, start_percent, end_percent, blocks):