Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git, synced 2025-12-09 12:54:40 +08:00
Add transformer_options for enhance-a-video and NAG patches
commit 468fcc86f0
parent e833a3f7df
@@ -1316,6 +1316,7 @@ class WanVideoTeaCacheKJ:
     RETURN_NAMES = ("model",)
     FUNCTION = "patch_teacache"
     CATEGORY = "KJNodes/teacache"
+    DEPRECATED = True
     DESCRIPTION = """
 Patch WanVideo model to use TeaCache. Speeds up inference by caching the output and
 applying it instead of doing the step. Best results are achieved by choosing the
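For orientation, the description above refers to TeaCache-style step caching: when the model input has barely changed since the previous step, the previously computed residual is reused instead of running the transformer blocks. The sketch below illustrates only that idea; the plain relative-L1 criterion and the fixed threshold are simplifying assumptions, not the node's actual logic (the real method rescales and accumulates the distance, see the linked TeaCache repository).

# Rough sketch of the cache-and-reuse idea described above, for illustration only.
# Assumptions: a plain relative-L1 skip criterion and a fixed threshold; the real
# TeaCache method rescales and accumulates this distance across steps.
import torch

class TinyStepCache:
    def __init__(self, rel_l1_thresh=0.15):
        self.rel_l1_thresh = rel_l1_thresh
        self.prev_input = None
        self.cached_residual = None

    def should_skip(self, x):
        # Skip the transformer step when the input barely changed since last time.
        if self.prev_input is None or self.cached_residual is None:
            return False
        rel_l1 = (x - self.prev_input).abs().mean() / self.prev_input.abs().mean()
        return rel_l1.item() < self.rel_l1_thresh

    def step(self, x, run_blocks):
        if self.should_skip(x):
            out = x + self.cached_residual       # reuse the cached residual
        else:
            out = run_blocks(x)                  # run the expensive blocks
            self.cached_residual = out - x       # cache their residual for reuse
        self.prev_input = x
        return out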
@@ -1450,7 +1451,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
 
 from comfy.ldm.flux.math import apply_rope
 
-def modified_wan_self_attention_forward(self, x, freqs):
+def modified_wan_self_attention_forward(self, x, freqs, transformer_options={}):
     r"""
     Args:
         x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -1470,13 +1471,23 @@ def modified_wan_self_attention_forward(self, x, freqs):
     q, k = apply_rope(q, k, freqs)
 
     feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight)
 
-    x = comfy.ldm.modules.attention.optimized_attention(
-        q.view(b, s, n * d),
-        k.view(b, s, n * d),
-        v,
-        heads=self.num_heads,
-    )
+    try:
+        x = comfy.ldm.modules.attention.optimized_attention(
+            q.view(b, s, n * d),
+            k.view(b, s, n * d),
+            v,
+            heads=self.num_heads,
+            transformer_options=transformer_options,
+        )
+    except:
+        # backward compatibility for now
+        x = comfy.ldm.modules.attention.optimized_attention(
+            q.view(b, s, n * d),
+            k.view(b, s, n * d),
+            v,
+            heads=self.num_heads,
+        )
 
     x = self.o(x)
 
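The try/except added above (and repeated in the NAG patches below) exists because older ComfyUI builds expose optimized_attention without the transformer_options keyword. A bare except also swallows unrelated failures (for example an out-of-memory error) and re-runs the attention call, so an alternative, sketched here under the assumption that only the keyword differs between versions, is a hypothetical attention_compat helper that feature-detects support once:

# Sketch of a version-tolerant wrapper around ComfyUI's optimized_attention.
# Assumption: older ComfyUI builds expose the same function without the
# transformer_options keyword, so support is detected once from the signature
# rather than with a bare except around every call. attention_compat is a
# hypothetical helper name, not part of the node code.
import inspect
import comfy.ldm.modules.attention as attention

_HAS_TRANSFORMER_OPTIONS = (
    "transformer_options"
    in inspect.signature(attention.optimized_attention).parameters
)

def attention_compat(q, k, v, heads, transformer_options=None):
    """Call optimized_attention, forwarding transformer_options only when supported."""
    if _HAS_TRANSFORMER_OPTIONS and transformer_options is not None:
        return attention.optimized_attention(
            q, k, v, heads=heads, transformer_options=transformer_options
        )
    return attention.optimized_attention(q, k, v, heads=heads)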
@@ -1581,14 +1592,18 @@ class WanVideoEnhanceAVideoKJ:
 
         return (model_clone,)
 
-def normalized_attention_guidance(self, query, context_positive, context_negative):
+def normalized_attention_guidance(self, query, context_positive, context_negative, transformer_options={}):
     k_positive = self.norm_k(self.k(context_positive))
     v_positive = self.v(context_positive)
     k_negative = self.norm_k(self.k(context_negative))
     v_negative = self.v(context_negative)
 
-    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
-    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
+    try:
+        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
+        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
+    except: #backwards compatibility for now
+        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
+        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
 
     nag_guidance = x_positive * self.nag_scale - x_negative * (self.nag_scale - 1)
 
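The last context line above is the core NAG blend: nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1). With nag_scale equal to 1 it reduces to the positive branch, and larger values extrapolate away from the negative-context result. A quick numeric illustration with arbitrary dummy values:

# Minimal numeric illustration of the NAG blend used above. Dummy tensors only.
import torch

x_positive = torch.tensor([1.0, 2.0])
x_negative = torch.tensor([0.5, 1.5])

for nag_scale in (1.0, 2.0, 4.0):
    nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
    print(nag_scale, nag_guidance.tolist())
# 1.0 -> [1.0, 2.0]   (identical to x_positive)
# 2.0 -> [1.5, 2.5]
# 4.0 -> [2.5, 3.5]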
@@ -1607,7 +1622,7 @@ def normalized_attention_guidance(self, query, context_positive, context_negativ
     return x
 
 #region NAG
-def wan_crossattn_forward_nag(self, x, context, **kwargs):
+def wan_crossattn_forward_nag(self, x, context, transformer_options={}, **kwargs):
     r"""
     Args:
         x(Tensor): Shape [B, L1, C]
@@ -1632,14 +1647,20 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
     nag_context = self.nag_context
     if self.input_type == "batch":
         nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)
-    x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
+    try:
+        x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context, transformer_options=transformer_options)
+    except: #backwards compatibility for now
+        x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
 
     # Negative branch
     if x_neg is not None and context_neg is not None:
         q_neg = self.norm_q(self.q(x_neg))
         k_neg = self.norm_k(self.k(context_neg))
         v_neg = self.v(context_neg)
-        x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
+        try:
+            x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads, transformer_options=transformer_options)
+        except: #backwards compatibility for now
+            x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
         x = torch.cat([x_pos_out, x_neg_out], dim=0)
     else:
         x = x_pos_out
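These patched forwards are module-level functions, so they only take effect once a node binds them onto the cross-attention modules of a cloned model and stores the nag_* state they read from self. The binding step might look roughly like the sketch below; the attribute path model_clone.model.diffusion_model.blocks[i].cross_attn and the attribute names are assumptions for illustration, not necessarily how the node wires it up.

# Hedged sketch: binding the patched cross-attention forward onto a cloned model.
# The block/attribute layout below is an assumption for illustration; the actual
# node may locate the attention modules differently.
import types

def apply_nag_patch_sketch(model_clone, nag_context, nag_scale, input_type="batch"):
    diffusion_model = model_clone.model.diffusion_model   # assumed attribute path
    for block in diffusion_model.blocks:                   # assumed block list
        attn = block.cross_attn                            # assumed module name
        attn.nag_context = nag_context                     # state read by the patch
        attn.nag_scale = nag_scale
        attn.input_type = input_type
        # Bind the module-level function as this module's forward method.
        attn.forward = types.MethodType(wan_crossattn_forward_nag, attn)
    return model_clone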
@@ -1647,7 +1668,7 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
     return self.o(x)
 
 
-def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
+def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len, transformer_options={}, **kwargs):
     r"""
     Args:
         x(Tensor): Shape [B, L1, C]
@@ -1659,7 +1680,10 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
     q_img = self.norm_q(self.q(x))
     k_img = self.norm_k_img(self.k_img(context_img))
     v_img = self.v_img(context_img)
-    img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
+    try:
+        img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
+    except: #backwards compatibility for now
+        img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
 
     if context.shape[0] == 2:
         x, x_real_negative = torch.chunk(x, 2, dim=0)
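For readers unfamiliar with the i2v layout: context_img_len marks how many of the packed context tokens are image-conditioning tokens, which the branch above attends over with its own k_img/v_img projections before the text branch runs. A shape-only illustration with dummy sizes follows; the image-tokens-first ordering is an assumption based on the context_img_len slicing convention.

# Shape-only illustration of splitting the packed i2v context into image and
# text tokens. Sizes are dummies; the image-tokens-first layout is an assumption
# for illustration, matching the context_img_len slicing convention.
import torch

B, L_img, L_txt, C = 1, 257, 512, 4096
context = torch.randn(B, L_img + L_txt, C)
context_img_len = L_img

context_img = context[:, :context_img_len]   # image-conditioning tokens
context_txt = context[:, context_img_len:]   # text tokens

print(context_img.shape, context_txt.shape)  # (1, 257, 4096) and (1, 512, 4096)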
@@ -1670,13 +1694,16 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
 
     q = self.norm_q(self.q(x))
 
-    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
+    x = normalized_attention_guidance(self, q, context_positive, self.nag_context, transformer_options=transformer_options)
 
     if context_negative is not None:
         q_real_negative = self.norm_q(self.q(x_real_negative))
         k_real_negative = self.norm_k(self.k(context_negative))
         v_real_negative = self.v(context_negative)
-        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
+        try:
+            x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads, transformer_options=transformer_options)
+        except: #backwards compatibility for now
+            x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
         x = torch.cat([x, x_real_negative], dim=0)
 
     # output
@@ -1766,7 +1793,7 @@ class SkipLayerGuidanceWanVideo:
     FUNCTION = "slg"
    EXPERIMENTAL = True
     DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
-
+    DEPRECATED = True
     CATEGORY = "advanced/guidance"
 
     def slg(self, model, start_percent, end_percent, blocks):