Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git (synced 2026-01-23 06:54:28 +08:00)
Add VAEDecodeLoopKJ
VAE decoding node that helps fix the seam that can appear when decoding looped videos.
parent 70a95fa264
commit 79f529a84a
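The fix works by decoding a second, wrapped latent segment alongside the normal decode: the last overlap_latent_frames + 1 latent frames followed by the first overlap_latent_frames, so the VAE renders the transition across the loop point as one continuous clip, and the tail of that wrapped decode then replaces the opening frames of the straight decode. Below is a minimal sketch of just the wrapping step, assuming a [B, C, T, H, W] latent layout (matching the dim=2 slicing in the diff); wrap_latent is an illustrative helper, not part of the commit.

import torch

def wrap_latent(latents: torch.Tensor, overlap: int = 2) -> torch.Tensor:
    # Last (overlap + 1) latent frames followed by the first overlap frames,
    # i.e. the segment the node hands to vae.decode() to render the seam.
    return torch.cat([latents[:, :, -(overlap + 1):], latents[:, :, :overlap]], dim=2)

# Dummy 16-frame latent -> a 5-frame wrapped segment for the default overlap of 2.
dummy = torch.randn(1, 16, 16, 32, 32)   # [B, C, T, H, W]
print(wrap_latent(dummy, 2).shape)        # torch.Size([1, 16, 5, 32, 32])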
@@ -197,6 +197,7 @@ NODE_CONFIG = {
     "PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
     "LeapfusionHunyuanI2VPatcher": {"class": LeapfusionHunyuanI2V, "name": "Leapfusion Hunyuan I2V Patcher"},
     "VAELoaderKJ": {"class": VAELoaderKJ, "name": "VAELoader KJ"},
+    "VAEDecodeLoopKJ": {"class": VAEDecodeLoopKJ, "name": "VAE Decode Loop KJ"},
     "ScheduledCFGGuidance": {"class": ScheduledCFGGuidance, "name": "Scheduled CFG Guidance"},
     "ApplyRifleXRoPE_HunuyanVideo": {"class": ApplyRifleXRoPE_HunuyanVideo, "name": "Apply RifleXRoPE HunuyanVideo"},
     "ApplyRifleXRoPE_WanVideo": {"class": ApplyRifleXRoPE_WanVideo, "name": "Apply RifleXRoPE WanVideo"},
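For context, NODE_CONFIG is the repository's single registration table, and ComfyUI itself loads NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS, which are presumably derived from this dict elsewhere in nodes.py; the one-line entry above is therefore all the registration the new node needs. A rough sketch of that derivation, with the exact mechanics assumed rather than taken from this diff:

# Illustrative only: how a NODE_CONFIG table typically becomes the mappings ComfyUI loads.
NODE_CLASS_MAPPINGS = {name: cfg["class"] for name, cfg in NODE_CONFIG.items()}
NODE_DISPLAY_NAME_MAPPINGS = {name: cfg.get("name", name) for name, cfg in NODE_CONFIG.items()}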
@@ -2873,3 +2873,44 @@ class AddNoiseToTrackPath(io.ComfyNode):
             "track_visibility": mask,
         }
         return io.NodeOutput(out_track)
+
+
+class VAEDecodeLoopKJ:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
+                "vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."}),
"overlap_latent_frames": ("INT", {"default": 2, "min": 2, "max": 8, "step": 1, "tooltip": "Number of frames to blend for seamless loop, for Wan 2 works and HunyuanVideo 1.5 should use 4"}),
|
||||
+            }
+        }
+    RETURN_TYPES = ("IMAGE",)
+    OUTPUT_TOOLTIPS = ("The decoded images.",)
+    FUNCTION = "decode"
+    CATEGORY = "KJNodes/vae"
+    DESCRIPTION = "Video latent VAE decoding to fix artifacts on loop seams."
+
+    def decode(self, vae, samples, overlap_latent_frames):
+        latents = samples["samples"]
+
+        images = vae.decode(latents)
+        if overlap_latent_frames <= 0:
+            if len(images.shape) == 5:
+                images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
+            return (images, )
+
+        end_frames = overlap_latent_frames + 1
+        start_frames = overlap_latent_frames
+
+        temp_images = vae.decode(torch.cat([latents[:, :, -end_frames:]] + [latents[:, :, :start_frames]], dim=2)).cpu().float()
+
+        total_concat = end_frames + start_frames
+        temp_start = total_concat * 2 - 1
+        main_start = total_concat + (overlap_latent_frames if overlap_latent_frames > 2 else 0)
+
+        images = torch.cat([temp_images[:, temp_start:].to(images), images[:, main_start:]], dim=1)
+        if len(images.shape) == 5:
+            images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
+
+        return (images, )
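To make the slicing offsets in decode() concrete, the integer arithmetic can be evaluated on its own; loop_offsets below is an illustrative helper that simply mirrors the lines added above, not part of the commit.

def loop_offsets(overlap_latent_frames=2):
    # Mirrors the arithmetic in VAEDecodeLoopKJ.decode().
    end_frames = overlap_latent_frames + 1    # latent frames taken from the end of the video
    start_frames = overlap_latent_frames      # latent frames taken from the start
    total_concat = end_frames + start_frames  # length of the wrapped latent segment
    temp_start = total_concat * 2 - 1         # first pixel frame kept from the wrapped decode
    main_start = total_concat + (overlap_latent_frames if overlap_latent_frames > 2 else 0)
    return end_frames, start_frames, total_concat, temp_start, main_start

print(loop_offsets(2))  # (3, 2, 5, 9, 5)
print(loop_offsets(4))  # (5, 4, 9, 17, 13)

Assuming a Wan-style video VAE where the first latent frame decodes to 1 pixel frame and each later one to 4, the default 5-frame wrapped segment decodes to 17 pixel frames; temp_images[:, 9:] then keeps the 8 frames coming from the wrapped clip's copy of the loop start, and those stand in for the first main_start = 5 frames of the straight decode.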