ComfyUI-KJNodes: Add transformer_options for enhance-a-video and NAG patches

Commit: 468fcc86f0
Parent: e833a3f7df
@@ -1316,6 +1316,7 @@ class WanVideoTeaCacheKJ:
     RETURN_NAMES = ("model",)
     FUNCTION = "patch_teacache"
     CATEGORY = "KJNodes/teacache"
+    DEPRECATED = True
     DESCRIPTION = """
 Patch WanVideo model to use TeaCache. Speeds up inference by caching the output and
 applying it instead of doing the step. Best results are achieved by choosing the
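The DESCRIPTION above sums up the approach: cache the output of a step and apply it again instead of recomputing when the input has barely changed. A minimal sketch of that idea follows, for illustration only; StepCache, rel_l1_thresh and compute_fn are illustrative names and this is not the node's actual implementation.

import torch

class StepCache:
    def __init__(self, rel_l1_thresh=0.15):
        self.rel_l1_thresh = rel_l1_thresh
        self.prev_input = None       # input seen at the last fully computed step
        self.cached_output = None    # output produced at that step

    def step(self, x, compute_fn):
        if self.prev_input is not None:
            # relative L1 change of the input since the last computed step
            rel_change = ((x - self.prev_input).abs().mean() / self.prev_input.abs().mean()).item()
            if rel_change < self.rel_l1_thresh:
                return self.cached_output   # apply the cached output instead of doing the step
        self.cached_output = compute_fn(x)   # do the step and refresh the cache
        self.prev_input = x
        return self.cached_output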
@@ -1450,7 +1451,7 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
 
 from comfy.ldm.flux.math import apply_rope
 
-def modified_wan_self_attention_forward(self, x, freqs):
+def modified_wan_self_attention_forward(self, x, freqs, transformer_options={}):
     r"""
     Args:
         x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -1470,13 +1471,23 @@ def modified_wan_self_attention_forward(self, x, freqs):
     q, k = apply_rope(q, k, freqs)
 
     feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight)
 
-    x = comfy.ldm.modules.attention.optimized_attention(
-        q.view(b, s, n * d),
-        k.view(b, s, n * d),
-        v,
-        heads=self.num_heads,
-    )
+    try:
+        x = comfy.ldm.modules.attention.optimized_attention(
+            q.view(b, s, n * d),
+            k.view(b, s, n * d),
+            v,
+            heads=self.num_heads,
+            transformer_options=transformer_options,
+        )
+    except:
+        # backward compatibility for now
+        x = comfy.ldm.modules.attention.optimized_attention(
+            q.view(b, s, n * d),
+            k.view(b, s, n * d),
+            v,
+            heads=self.num_heads,
+        )
 
     x = self.o(x)
 
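The same call pattern recurs in every patched forward below: the commit assumes newer ComfyUI builds accept a transformer_options keyword on optimized_attention, while older ones do not, so it tries the new signature first and falls back inside a bare except. A condensed sketch of that pattern, where call_attention and attention_fn are illustrative stand-ins (the sketch narrows the bare except to TypeError):

def call_attention(attention_fn, q, k, v, heads, transformer_options=None):
    # Prefer the newer signature that forwards transformer_options, so patches
    # such as enhance-a-video / NAG can read per-run state; fall back to the
    # older signature when the installed ComfyUI does not accept the keyword.
    try:
        return attention_fn(q, k, v, heads=heads, transformer_options=transformer_options)
    except TypeError:
        return attention_fn(q, k, v, heads=heads)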
@@ -1581,14 +1592,18 @@ class WanVideoEnhanceAVideoKJ:
 
         return (model_clone,)
 
-def normalized_attention_guidance(self, query, context_positive, context_negative):
+def normalized_attention_guidance(self, query, context_positive, context_negative, transformer_options={}):
     k_positive = self.norm_k(self.k(context_positive))
     v_positive = self.v(context_positive)
     k_negative = self.norm_k(self.k(context_negative))
     v_negative = self.v(context_negative)
 
-    x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
-    x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
+    try:
+        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
+        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads, transformer_options=transformer_options).flatten(2)
+    except: #backwards compatibility for now
+        x_positive = comfy.ldm.modules.attention.optimized_attention(query, k_positive, v_positive, heads=self.num_heads).flatten(2)
+        x_negative = comfy.ldm.modules.attention.optimized_attention(query, k_negative, v_negative, heads=self.num_heads).flatten(2)
 
     nag_guidance = x_positive * self.nag_scale - x_negative * (self.nag_scale - 1)
 
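The last context line of this hunk is the usual guidance extrapolation: with nag_scale = 1 it returns the positive branch unchanged, and larger values push the result away from the negative branch. A minimal numeric check of just that line, with made-up values (NAG's later norm-based rescaling happens outside this hunk):

import torch

nag_scale = 3.0
x_positive = torch.tensor([1.0, 2.0])
x_negative = torch.tensor([0.5, 0.5])

nag_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
print(nag_guidance)  # tensor([2., 5.]), i.e. 3*x_positive - 2*x_negative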
@@ -1607,7 +1622,7 @@ def normalized_attention_guidance(self, query, context_positive, context_negativ
     return x
 
 #region NAG
-def wan_crossattn_forward_nag(self, x, context, **kwargs):
+def wan_crossattn_forward_nag(self, x, context, transformer_options={}, **kwargs):
     r"""
     Args:
         x(Tensor): Shape [B, L1, C]
@@ -1632,14 +1647,20 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
         nag_context = self.nag_context
         if self.input_type == "batch":
             nag_context = nag_context.repeat(x_pos.shape[0], 1, 1)
-        x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
+        try:
+            x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context, transformer_options=transformer_options)
+        except: #backwards compatibility for now
+            x_pos_out = normalized_attention_guidance(self, q_pos, context_pos, nag_context)
 
         # Negative branch
         if x_neg is not None and context_neg is not None:
             q_neg = self.norm_q(self.q(x_neg))
             k_neg = self.norm_k(self.k(context_neg))
             v_neg = self.v(context_neg)
-            x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
+            try:
+                x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads, transformer_options=transformer_options)
+            except: #backwards compatibility for now
+                x_neg_out = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.num_heads)
             x = torch.cat([x_pos_out, x_neg_out], dim=0)
         else:
             x = x_pos_out
@@ -1647,7 +1668,7 @@ def wan_crossattn_forward_nag(self, x, context, **kwargs):
     return self.o(x)
 
 
-def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
+def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len, transformer_options={}, **kwargs):
     r"""
     Args:
         x(Tensor): Shape [B, L1, C]
@@ -1659,7 +1680,10 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
     q_img = self.norm_q(self.q(x))
     k_img = self.norm_k_img(self.k_img(context_img))
     v_img = self.v_img(context_img)
-    img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
+    try:
+        img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
+    except: #backwards compatibility for now
+        img_x = comfy.ldm.modules.attention.optimized_attention(q_img, k_img, v_img, heads=self.num_heads)
 
     if context.shape[0] == 2:
         x, x_real_negative = torch.chunk(x, 2, dim=0)
@@ -1670,13 +1694,16 @@ def wan_i2v_crossattn_forward_nag(self, x, context, context_img_len):
 
     q = self.norm_q(self.q(x))
 
-    x = normalized_attention_guidance(self, q, context_positive, self.nag_context)
+    x = normalized_attention_guidance(self, q, context_positive, self.nag_context, transformer_options=transformer_options)
 
     if context_negative is not None:
         q_real_negative = self.norm_q(self.q(x_real_negative))
         k_real_negative = self.norm_k(self.k(context_negative))
         v_real_negative = self.v(context_negative)
-        x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
+        try:
+            x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads, transformer_options=transformer_options)
+        except: #backwards compatibility for now
+            x_real_negative = comfy.ldm.modules.attention.optimized_attention(q_real_negative, k_real_negative, v_real_negative, heads=self.num_heads)
         x = torch.cat([x, x_real_negative], dim=0)
 
     # output
@@ -1766,7 +1793,7 @@ class SkipLayerGuidanceWanVideo:
     FUNCTION = "slg"
    EXPERIMENTAL = True
     DESCRIPTION = "Simplified skip layer guidance that only skips the uncond on selected blocks"
-
+    DEPRECATED = True
     CATEGORY = "advanced/guidance"
 
     def slg(self, model, start_percent, end_percent, blocks):
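The diff shows the replacement forward functions but not the code that installs them on the model. Judging from the earlier hunk that returns a cloned model (return (model_clone,)), they are bound onto each attention module of the clone; the sketch below is an assumption about how such binding typically looks, and install_patch, the attribute names and the commented block path are illustrative, not taken from this commit.

import types

def install_patch(module, new_forward, num_frames, enhance_weight):
    # Attach the state the patched forward reads from `self`, then rebind
    # forward so that `self` inside new_forward is this attention module.
    module.num_frames = num_frames
    module.enhance_weight = enhance_weight
    module.forward = types.MethodType(new_forward, module)

# e.g. for every WanVideo self-attention block on a cloned model (block path is a guess):
# for block in model_clone.model.diffusion_model.blocks:
#     install_patch(block.self_attn, modified_wan_self_attention_forward, num_frames, weight)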