Add VAEDecodeLoopKJ

VAE decoding node that helps fix the seam that can appear when decoding looped videos (a sketch of the idea follows below).
kijai 2025-12-21 14:23:06 +02:00
parent 70a95fa264
commit 79f529a84a
2 changed files with 42 additions and 0 deletions
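
The core idea: with a causal video VAE, the first frames are decoded without the temporal context that the loop's end would provide, so looped playback shows a seam at the wrap-around point. The node re-decodes the first latent frames together with the last ones and splices the result over the start of the main decode. Below is a self-contained sketch of that splice; the ToyVAE class, the Wan-style 4x temporal compression, and all shapes are illustrative assumptions, not part of this commit:

import torch
import torch.nn.functional as F

# Toy stand-in for a causal video VAE: T latent frames decode to 4*(T-1)+1
# pixel frames (Wan-style 4x temporal compression). Purely illustrative.
class ToyVAE:
    def decode(self, latents):  # latents: [B, C, T, h, w]
        b, c, t, h, w = latents.shape
        x = F.interpolate(latents, size=(4 * (t - 1) + 1, h * 8, w * 8), mode="trilinear")
        return x.permute(0, 2, 3, 4, 1)  # [B, T_pixel, H, W, C]

vae = ToyVAE()
latents = torch.randn(1, 16, 21, 8, 8)  # 21 latent frames
overlap = 2

images = vae.decode(latents)  # main decode: [1, 81, 64, 64, 16]

# Wrap-around decode: last (overlap + 1) latents followed by first overlap latents,
# so the loop start is decoded with the end of the video as context.
wrap = torch.cat([latents[:, :, -(overlap + 1):], latents[:, :, :overlap]], dim=2)
wrap_images = vae.decode(wrap)  # [1, 17, 64, 64, 16]

# Splice the re-decoded loop start over the beginning of the main decode,
# using the same index arithmetic as the node below.
total = 2 * overlap + 1
fixed = torch.cat([wrap_images[:, total * 2 - 1:],
                   images[:, total + (overlap if overlap > 2 else 0):]], dim=1)
print(fixed.shape)  # torch.Size([1, 84, 64, 64, 16])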

@@ -197,6 +197,7 @@ NODE_CONFIG = {
    "PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
    "LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
    "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
    "VAEDecodeLoopKJ": {"class": VAEDecodeLoopKJ, "name": "VAE Decode Loop KJ"},
    "ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
    "ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
    "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},

@@ -2873,3 +2873,44 @@ class AddNoiseToTrackPath(io.ComfyNode):
            "track_visibility": mask,
        }
        return io.NodeOutput(out_track)

class VAEDecodeLoopKJ:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
                "vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."}),
                "overlap_latent_frames": ("INT", {"default": 2, "min": 2, "max": 8, "step": 1, "tooltip": "Number of latent frames to blend for a seamless loop; 2 works for Wan, HunyuanVideo 1.5 should use 4."}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    OUTPUT_TOOLTIPS = ("The decoded images.",)
    FUNCTION = "decode"
    CATEGORY = "KJNodes/vae"
    DESCRIPTION = "Video latent VAE decoding that fixes artifacts on loop seams."

    def decode(self, vae, samples, overlap_latent_frames):
        latents = samples["samples"]
        images = vae.decode(latents)

        if overlap_latent_frames <= 0:
            if len(images.shape) == 5:
                images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
            return (images,)

        # Decode the wrap-around region: the last (overlap + 1) latent frames
        # followed by the first overlap latent frames, so that the loop start
        # is decoded with the end of the video as temporal context.
        end_frames = overlap_latent_frames + 1
        start_frames = overlap_latent_frames
        temp_images = vae.decode(torch.cat([latents[:, :, -end_frames:], latents[:, :, :start_frames]], dim=2)).cpu().float()

        # Splice the re-decoded start frames over the seam-damaged beginning
        # of the main decode.
        total_concat = end_frames + start_frames
        temp_start = total_concat * 2 - 1
        main_start = total_concat + (overlap_latent_frames if overlap_latent_frames > 2 else 0)
        images = torch.cat([temp_images[:, temp_start:].to(images), images[:, main_start:]], dim=1)

        if len(images.shape) == 5:
            images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
        return (images,)
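
For reference, the splice indices at the default overlap work out as follows; the pixel-frame counts assume a Wan-style causal VAE where T latent frames decode to 4*(T-1)+1 pixel frames (an assumption for illustration, not stated in the commit):

# Index trace for overlap_latent_frames = 2 (default).
overlap = 2
end_frames = overlap + 1                  # 3 latents taken from the end of the video
start_frames = overlap                    # 2 latents taken from the start
total_concat = end_frames + start_frames  # 5 latents decoded together -> 17 pixel frames
temp_start = total_concat * 2 - 1         # 9: pixel frames 9..16 come from the 2 start latents
main_start = total_concat + (overlap if overlap > 2 else 0)  # 5: drop the seam-damaged start

# output = wrap-decode frames 9.. (loop start, decoded with end context)
#        + main-decode frames 5.. (rest of the video)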