class VAEDecodeLoopKJ:
    """VAE decoding node that helps fix the seam that can appear when decoding looped videos.

    The full latent is decoded once. A short clip built from the last
    ``overlap_latent_frames + 1`` latent frames followed by the first
    ``overlap_latent_frames`` latent frames is then decoded separately, and
    the tail of that clip replaces the leading frames of the main decode.

    The accompanying change registers this class in ``NODE_CONFIG`` under the
    key ``"VAEDecodeLoopKJ"`` with the display name ``"VAE Decode Loop KJ"``.
    """

    RETURN_TYPES = ("IMAGE",)
    OUTPUT_TOOLTIPS = ("The decoded images.",)
    FUNCTION = "decode"
    CATEGORY = "KJNodes/vae"
    DESCRIPTION = "Video latent VAE decoding to fix artifacts on loop seams."

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (names, types, widget options) of the node."""
        required_inputs = {
            "samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
            "vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."}),
            "overlap_latent_frames": (
                "INT",
                {
                    "default": 2,
                    "min": 2,
                    "max": 8,
                    "step": 1,
                    "tooltip": "Number of frames to blend for seamless loop, for Wan 2 works and HunyuanVideo 1.5 should use 4",
                },
            ),
        }
        return {"required": required_inputs}

    @staticmethod
    def _flatten_frames(frames):
        """Collapse a 5-D tensor into 4-D by merging all leading dims; pass others through."""
        if frames.dim() != 5:
            return frames
        return frames.reshape(-1, *frames.shape[-3:])

    def decode(self, vae, samples, overlap_latent_frames):
        """Decode ``samples["samples"]`` and splice in a separately decoded seam clip.

        Args:
            vae: Object exposing ``decode(latent) -> tensor``.
            samples: Mapping whose ``"samples"`` entry is the latent tensor.
                The latent is sliced and concatenated along dim 2, and the
                decoded result along dim 1 (assumed to be the frame axes of a
                video latent / decoded video -- confirm against the VAE used).
            overlap_latent_frames: Number of latent frames taken from the
                start of the latent for the seam clip; one more than that is
                taken from the end. Values <= 0 disable the seam handling.

        Returns:
            A 1-tuple holding the decoded images; a 5-D result is reshaped
            to 4-D before being returned.
        """
        latent_video = samples["samples"]
        decoded = vae.decode(latent_video)

        if overlap_latent_frames > 0:
            tail_count = overlap_latent_frames + 1
            head_count = overlap_latent_frames

            # Seam clip: end of the latent followed by its beginning.
            seam_latents = torch.cat(
                (latent_video[:, :, -tail_count:], latent_video[:, :, :head_count]),
                dim=2,
            )
            seam_frames = vae.decode(seam_latents).cpu().float()

            seam_len = tail_count + head_count
            # Index of the first seam-clip frame that is kept.
            seam_offset = 2 * seam_len - 1
            # Overlaps larger than 2 drop additional leading frames of the main decode.
            extra_skip = overlap_latent_frames if overlap_latent_frames > 2 else 0
            main_offset = seam_len + extra_skip

            # .to(decoded) matches the seam clip to the main decode's dtype/device.
            decoded = torch.cat(
                (seam_frames[:, seam_offset:].to(decoded), decoded[:, main_offset:]),
                dim=1,
            )

        return (self._flatten_frames(decoded),)