From 263961539ee6d374fa145175fee9b21dbd88f00c Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 9 Mar 2025 16:53:51 +0200
Subject: [PATCH] Add WanVideoEnhanceAVideoKJ

---
 __init__.py                       |   1 +
 nodes/model_optimization_nodes.py | 131 ++++++++++++++++++++++++++++++
 2 files changed, 132 insertions(+)

diff --git a/__init__.py b/__init__.py
index 9f5507a..c18ce7c 100644
--- a/__init__.py
+++ b/__init__.py
@@ -184,6 +184,7 @@ NODE_CONFIG = {
     "ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
     "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
     "WanVideoTeaCacheKJ": {"class": WanVideoTeaCacheKJ, "name": "WanVideo Tea Cache (native)"},
+    "WanVideoEnhanceAVideoKJ": {"class": WanVideoEnhanceAVideoKJ, "name": "WanVideo Enhance A Video (native)"},
     "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
 
     #instance diffusion
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 7f53c14..418ddb4 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -945,3 +945,134 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
         model_clone.set_model_unet_function_wrapper(outer_wrapper(start_percent=start_percent, end_percent=end_percent))
 
         return (model_clone,)
+
+
+
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.wan.model import WanSelfAttention
+def modified_wan_self_attention_forward(self, x, freqs):
+    r"""
+    Args:
+        x(Tensor): Shape [B, L, num_heads, C / num_heads]
+        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+    """
+    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+    # query, key, value function
+    def qkv_fn(x):
+        q = self.norm_q(self.q(x)).view(b, s, n, d)
+        k = self.norm_k(self.k(x)).view(b, s, n, d)
+        v = self.v(x).view(b, s, n * d)
+        return q, k, v
+
+    q, k, v = qkv_fn(x)
+
+    q, k = apply_rope(q, k, freqs)
+
+    feta_scores = get_feta_scores(q, k, self.num_frames, self.enhance_weight)
+
+    x = optimized_attention(
+        q.view(b, s, n * d),
+        k.view(b, s, n * d),
+        v,
+        heads=self.num_heads,
+    )
+
+    x = self.o(x)
+
+    x *= feta_scores
+
+    return x
+
+from einops import rearrange
+def get_feta_scores(query, key, num_frames, enhance_weight):
+    img_q, img_k = query, key  # torch.Size([2, 9216, 12, 128])
+
+    _, ST, num_heads, head_dim = img_q.shape
+    spatial_dim = ST / num_frames
+    spatial_dim = int(spatial_dim)
+
+    query_image = rearrange(
+        img_q, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
+    )
+    key_image = rearrange(
+        img_k, "B (T S) N C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
+    )
+
+    return feta_score(query_image, key_image, head_dim, num_frames, enhance_weight)
+
+def feta_score(query_image, key_image, head_dim, num_frames, enhance_weight):
+    scale = head_dim**-0.5
+    query_image = query_image * scale
+    attn_temp = query_image @ key_image.transpose(-2, -1)  # translate attn to float32
+    attn_temp = attn_temp.to(torch.float32)
+    attn_temp = attn_temp.softmax(dim=-1)
+
+    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
+    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
+
+    # Create a mask for diagonal elements
+    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
+    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
+
+    # Zero out diagonal elements
+    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
+
+    # Calculate mean for each token's attention matrix
+    # Number of off-diagonal elements per matrix is n*n - n
+    num_off_diag = num_frames * num_frames - num_frames
+    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
+
+    enhance_scores = mean_scores.mean() * (num_frames + enhance_weight)
+    enhance_scores = enhance_scores.clamp(min=1)
+    return enhance_scores
+
+import types
+class WanAttentionPatch:
+    def __init__(self, num_frames, weight):
+        self.num_frames = num_frames
+        self.enhance_weight = weight
+
+    def __get__(self, obj, objtype=None):
+        # Create bound method with stored parameters
+        def wrapped_attention(self_module, *args, **kwargs):
+            self_module.num_frames = self.num_frames
+            self_module.enhance_weight = self.enhance_weight
+            return modified_wan_self_attention_forward(self_module, *args, **kwargs)
+        return types.MethodType(wrapped_attention, obj)
+
+class WanVideoEnhanceAVideoKJ:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}),
+                "weight": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of the enhance effect"}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    RETURN_NAMES = ("model",)
+    FUNCTION = "enhance"
+    CATEGORY = "KJNodes/experimental"
+    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"
+    EXPERIMENTAL = True
+
+    def enhance(self, model, weight, latent):
+        if weight == 0:
+            return (model,)
+
+        num_frames = latent["samples"].shape[2]
+
+        model_clone = model.clone()
+        if 'transformer_options' not in model_clone.model_options:
+            model_clone.model_options['transformer_options'] = {}
+        model_clone.model_options["transformer_options"]["enhance_weight"] = weight
+        diffusion_model = model_clone.get_model_object("diffusion_model")
+        for idx, block in enumerate(diffusion_model.blocks):
+            self_attn = WanAttentionPatch(num_frames, weight).__get__(block.self_attn, block.__class__)
+            model_clone.add_object_patch(f"diffusion_model.blocks.{idx}.self_attn.forward", self_attn)
+
+        return (model_clone,)
\ No newline at end of file
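
Note: below is a condensed, standalone sketch of the scoring that get_feta_scores/feta_score in this patch compute, useful for sanity-checking the behavior outside ComfyUI. It assumes only torch and einops are installed; the helper name enhance_score and the dummy tensor shapes are illustrative and not part of the patch:

    import torch
    from einops import rearrange

    def enhance_score(q, k, num_frames, enhance_weight):
        # q, k: [B, T*S, H, D], the layout the patched WanSelfAttention sees
        B, TS, H, D = q.shape
        S = TS // num_frames
        # group tokens so attention runs frame-to-frame per spatial location
        q = rearrange(q, "b (t s) h d -> (b s) h t d", t=num_frames, s=S)
        k = rearrange(k, "b (t s) h d -> (b s) h t d", t=num_frames, s=S)
        attn = ((q * D**-0.5) @ k.transpose(-2, -1)).float().softmax(dim=-1)
        attn = attn.reshape(-1, num_frames, num_frames)
        # mean cross-frame attention, with diagonal (same-frame) entries excluded
        off_diag = attn.masked_fill(torch.eye(num_frames, dtype=torch.bool), 0)
        mean = off_diag.sum(dim=(1, 2)) / (num_frames * num_frames - num_frames)
        # scaled by frame count plus the node's weight, clamped to >= 1
        return (mean.mean() * (num_frames + enhance_weight)).clamp(min=1)

    q = torch.randn(1, 8 * 16, 4, 32)  # 8 latent frames, 16 spatial tokens, 4 heads, head_dim 32
    k = torch.randn(1, 8 * 16, 4, 32)
    print(enhance_score(q, k, num_frames=8, enhance_weight=0.2))  # scalar tensor >= 1

The node itself only stores the weight in transformer_options, reads the frame count from the latent, and patches each block's self_attn.forward via add_object_patch; the scalar sketched above is what that patched forward multiplies its attention output by.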