Add WanVideoEnhanceAVideoKJ
This commit is contained in:
parent 665f59fae3
commit 263961539e
@@ -184,6 +184,7 @@ NODE_CONFIG = {
    "ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
    "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
    "WanVideoTeaCacheKJ": {"class": WanVideoTeaCacheKJ, "name": "WanVideo Tea Cache (native)"},
    "WanVideoEnhanceAVideoKJ": {"class": WanVideoEnhanceAVideoKJ, "name": "WanVideo Enhance A Video (native)"},
    "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},

    #instance diffusion
@@ -945,3 +945,134 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
        model_clone.set_model_unet_function_wrapper(outer_wrapper(start_percent=start_percent, end_percent=end_percent))

        return (model_clone,)

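# --- Enhance-A-Video for native WanVideo (https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video) ---
# A patched WanSelfAttention.forward that computes a cross-frame "enhance" score from the
# temporal attention map and multiplies the block output by it.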
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.wan.model import WanSelfAttention
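# Note: torch is used by the helpers below but not imported in this hunk; it is assumed to be
# imported at the top of nodes.py, of which this diff is a fragment.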
def modified_wan_self_attention_forward(self, x, freqs):
    r"""
    Args:
        x(Tensor): Shape [B, L, num_heads, C / num_heads]
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
    """
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n * d)
        return q, k, v

    q, k, v = qkv_fn(x)

    q, k = apply_rope(q, k, freqs)

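    # Compute the Enhance-A-Video score from the rotary-embedded queries/keys before the
    # regular attention pass; the score rescales this block's output below.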
    feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight)

    x = optimized_attention(
        q.view(b, s, n * d),
        k.view(b, s, n * d),
        v,
        heads=self.num_heads,
    )

    x = self.o(x)

    x *= feta_scores

    return x

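# Helpers for the cross-frame ("feta") score: tokens are regrouped per spatial position so
# attention can be measured across the num_frames temporal positions.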
from einops import rearrange
def get_feta_scores(query, key, num_frames, enhance_weight):
    img_q, img_k = query, key #torch.Size([2, 9216, 12, 128])

    _, ST, num_heads, head_dim = img_q.shape
    spatial_dim = ST / num_frames
    spatial_dim = int(spatial_dim)

    query_image = rearrange(
        img_q, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
    )
    key_image = rearrange(
        img_k, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
    )

    return feta_score(query_image, key_image, head_dim, num_frames, enhance_weight)

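# Softmax temporal attention per (sample, spatial position, head), average the off-diagonal
# (cross-frame) weights, then scale by (num_frames + enhance_weight) and clamp to at least 1.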
def feta_score(query_image, key_image, head_dim, num_frames, enhance_weight):
    scale = head_dim**-0.5
    query_image = query_image * scale
    attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32
    attn_temp = attn_temp.to(torch.float32)
    attn_temp = attn_temp.softmax(dim=-1)

    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Create a mask for diagonal elements
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)

    # Zero out diagonal elements
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Calculate mean for each token's attention matrix
    # Number of off-diagonal elements per matrix is n*n - n
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    enhance_scores = mean_scores.mean() * (num_frames + enhance_weight)
    enhance_scores = enhance_scores.clamp(min=1)
    return enhance_scores

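# Descriptor whose __get__ returns a bound replacement for WanSelfAttention.forward, carrying
# num_frames and the enhance weight into modified_wan_self_attention_forward.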
import types
class WanAttentionPatch:
    def __init__(self, num_frames, weight):
        self.num_frames = num_frames
        self.enhance_weight = weight

    def __get__(self, obj, objtype=None):
        # Create bound method with stored parameters
        def wrapped_attention(self_module, *args, **kwargs):
            self_module.num_frames = self.num_frames
            self_module.enhance_weight = self.enhance_weight
            return modified_wan_self_attention_forward(self_module, *args, **kwargs)
        return types.MethodType(wrapped_attention, obj)

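# ComfyUI node: clones the model and object-patches the forward of every WanVideo
# self-attention block so the enhance score is applied during sampling.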
class WanVideoEnhanceAVideoKJ:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
                "weight": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of the enhance effect"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "enhance"
    CATEGORY = "KJNodes/experimental"
    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"
    EXPERIMENTAL = True

    def enhance(self, model, weight, latent):
        if weight == 0:
            return (model,)

        num_frames = latent["samples"].shape[2]

        model_clone = model.clone()
        if 'transformer_options' not in model_clone.model_options:
            model_clone.model_options['transformer_options'] = {}
        model_clone.model_options["transformer_options"]["enhance_weight"] = weight
        diffusion_model = model_clone.get_model_object("diffusion_model")
        for idx, block in enumerate(diffusion_model.blocks):
            self_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn)

        return (model_clone,)