Add transformer_options for enhance-a-video and NAG patches

kijai 2025-09-14 11:49:04 +03:00
parent e833a3f7df
commit 468fcc86f0

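Every hunk below follows the same pattern: the patched forward gains a transformer_options dict and forwards it to comfy.ldm.modules.attention.optimized_attention, wrapped in a try/except so that older ComfyUI builds whose attention function does not accept the keyword keep working. A minimal sketch of that shim, with attention_fn and call_attention as illustrative names that do not appear in this commit:

# Sketch of the backwards-compatibility pattern used throughout this commit.
# Assumption: newer optimized_attention accepts a transformer_options kwarg,
# while older builds raise TypeError when it is passed.
def call_attention(attention_fn, q, k, v, heads, transformer_options=None):
    try:
        return attention_fn(q, k, v, heads=heads,
                            transformer_options=transformer_options or {})
    except TypeError:
        # Older ComfyUI: fall back to the original signature.
        return attention_fn(q, k, v, heads=heads)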

@@ -1316,6 +1316,7 @@ class WanVideoTeaCacheKJ:
     RETURN_NAMES = ("model",)
     FUNCTION = "patch_teacache"
     CATEGORY = "KJNodes/teacache"
+    DEPRECATED = True
     DESCRIPTION = """
 Patch WanVideo model to use TeaCache. Speeds up inference by caching the output and
 applying it instead of doing the step. Best results are achieved by choosing the
@@ -1450,7 +1451,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
 from comfy.ldm.flux.math import apply_rope

-def modified_wan_self_attention_forward(self, x, freqs):
+def modified_wan_self_attention_forward(self, x, freqs, transformer_options={}):
     r"""
     Args:
         x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -1471,11 +1472,21 @@ def modified_wan_self_attention_forward(self, x, freqs):
     feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight)

-    x = comfy.ldm.modules.attention.optimized_attention(
-        q.view(b, s, n * d),
-        k.view(b, s, n * d),
-        v,
-        heads=self.num_heads,
-    )
+    try:
+        x = comfy.ldm.modules.attention.optimized_attention(
+            q.view(b, s, n * d),
+            k.view(b, s, n * d),
+            v,
+            heads=self.num_heads,
+            transformer_options=transformer_options,
+        )
+    except:
+        # backward compatibility for now
+        x = comfy.ldm.modules.attention.optimized_attention(
+            q.view(b, s, n * d),
+            k.view(b, s, n * d),
+            v,
+            heads=self.num_heads,
+        )

     x = self.o(x)
@@ -1581,12 +1592,16 @@ class WanVideoEnhanceAVideoKJ:
         return (model_clone,)

-def normalized_attention_guidance(self, query, context_positive, context_negative):
+def normalized_attention_guidance(self, query, context_positive, context_negative, transformer_options={}):
     k_positive = self.norm_k(self.k(context_positive))
     v_positive = self.v(context_positive)
     k_negative = self.norm_k(self.k(context_negative))
     v_negative = self.v(context_negative)

-    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
-    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
+    try:
+        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
+        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
+    except: #backwards compatibility for now
+        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
+        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
@@ -1607,7 +1622,7 @@ def normalized_attention_guidance(self, query, context_positive, context_negative
     return x

 #region NAG
-def wan_crossattn_forward_nag(self, x, context, **kwargs):
+def wan_crossattn_forward_nag(self, x, context, transformer_options={}, **kwargs):
     r"""
     Args:
         x(Tensor): Shape [B, L1, C]
@@ -1632,6 +1647,9 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
         nag_context = self.nag_context
         if self.input_type == "batch":
             nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)
-        x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
+        try:
+            x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context, transformer_options=transformer_options)
+        except: #backwards compatibility for now
+            x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)

         # Negative branch
@@ -1639,6 +1657,9 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
         q_neg = self.norm_q(self.q(x_neg))
         k_neg = self.norm_k(self.k(context_neg))
         v_neg = self.v(context_neg)
-        x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
+        try:
+            x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads, transformer_options=transformer_options)
+        except: #backwards compatibility for now
+            x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
         x = torch.cat([x_pos_out, x_neg_out], dim=0)
     else:
@@ -1647,7 +1668,7 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
     return self.o(x)

-def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
+def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len, transformer_options={}, **kwargs):
     r"""
     Args:
         x(Tensor): Shape [B, L1, C]
@@ -1659,6 +1680,9 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
     q_img = self.norm_q(self.q(x))
     k_img = self.norm_k_img(self.k_img(context_img))
     v_img = self.v_img(context_img)
-    img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
+    try:
+        img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
+    except: #backwards compatibility for now
+        img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)

     if context.shape[0] == 2:
@@ -1670,12 +1694,15 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
         q = self.norm_q(self.q(x))
-        x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
+        x = normalized_attention_guidance(self, q, context_positive, self.nag_context, transformer_options=transformer_options)

         if context_negative is not None:
             q_real_negative = self.norm_q(self.q(x_real_negative))
             k_real_negative = self.norm_k(self.k(context_negative))
             v_real_negative = self.v(context_negative)
-            x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
+            try:
+                x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads, transformer_options=transformer_options)
+            except: #backwards compatibility for now
+                x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
             x = torch.cat([x, x_real_negative], dim=0)
@@ -1766,7 +1793,7 @@ class SkipLayerGuidanceWanVideo:
     FUNCTION = "slg"
     EXPERIMENTAL = True
     DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
+    DEPRECATED = True
     CATEGORY = "advanced/guidance"

     def slg(self, model, start_percent, end_percent, blocks):
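For context, a toy illustration of why threading the dict through matters: once transformer_options reaches the attention call, per-run options and patches carried in it can influence how attention is computed. The key names below ("patches", "attn_override") and the tensor layout are assumptions made for this sketch only, not the schema ComfyUI actually uses.

import torch
import torch.nn.functional as F

# Toy attention backend that consults transformer_options.
# "patches" / "attn_override" are illustrative key names, not ComfyUI's.
def attention_with_options(q, k, v, heads, transformer_options=None):
    opts = transformer_options or {}
    override = opts.get("patches", {}).get("attn_override")
    if override is not None:
        # A user-supplied patch replaces the default computation.
        return override(q, k, v, heads)
    # Default: plain scaled dot-product attention on [B, L, heads*dim] tensors.
    b, l, _ = q.shape
    q, k, v = (t.view(b, l, heads, -1).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, l, -1)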