From de7e069286978fb2af6c0685e0cf57acbfa55463 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 20 Nov 2024 02:12:56 +0200 Subject: [PATCH] fix noise augment --- nodes.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index 4af7b3e..ebd3102 100644 --- a/nodes.py +++ b/nodes.py @@ -254,19 +254,20 @@ class CogVideoImageEncode: except: pass - if noise_aug_strength > 0: - start_image = add_noise_to_reference_video(start_image, ratio=noise_aug_strength) - if end_image is not None: - end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength) latents_list = [] start_image = (start_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W + if noise_aug_strength > 0: + start_image = add_noise_to_reference_video(start_image, ratio=noise_aug_strength) start_latents = vae.encode(start_image).latent_dist.sample(generator) start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W + if end_image is not None: end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) + if noise_aug_strength > 0: + end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength) end_latents = vae.encode(end_image).latent_dist.sample(generator) end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W latents_list = [start_latents, end_latents]