From f9c747eff51522c96c15d69d84c78a7b9d018423 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 27 Nov 2024 01:52:58 +0200
Subject: [PATCH] mid_image

---
 nodes.py              | 13 ++++++++++---
 pipeline_cogvideox.py | 35 ++++++++++++++++++++++++++++++-----
 2 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/nodes.py b/nodes.py
index 3819b54..8e3df39 100644
--- a/nodes.py
+++ b/nodes.py
@@ -217,6 +217,7 @@ class CogVideoImageEncode:
                 "start_image": ("IMAGE", ),
             },
             "optional": {
+                "mid_image": ("IMAGE", ),
                 "end_image": ("IMAGE", ),
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
@@ -231,7 +232,7 @@ class CogVideoImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, vae, start_image, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
+    def encode(self, vae, start_image, mid_image=None, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
@@ -263,14 +264,20 @@ class CogVideoImageEncode:
 
         start_latents = vae.encode(start_image).latent_dist.sample(generator)
         start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
-
+        if mid_image is not None:
+            mid_image = (mid_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
+            if noise_aug_strength > 0:
+                mid_image = add_noise_to_reference_video(mid_image, ratio=noise_aug_strength)
+            mid_latents = vae.encode(mid_image).latent_dist.sample(generator)
+            mid_latents = mid_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
+            latents_list = [start_latents, mid_latents]
         if end_image is not None:
             end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
             if noise_aug_strength > 0:
                 end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength)
             end_latents = vae.encode(end_image).latent_dist.sample(generator)
             end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
-            latents_list = [start_latents, end_latents]
+            latents_list.append(end_latents)
             final_latents = torch.cat(latents_list, dim=1)
         else:
             final_latents = start_latents
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 33a1047..d1391d3 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -473,22 +473,47 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         if image_cond_latents is not None:
             if image_cond_latents.shape[1] == 3:
                 logger.info("More than one image conditioning frame received, interpolating")
+                total_padding = latents.shape[1] - 3
+                half_padding = total_padding // 2
+
                 padding_shape = (
                     batch_size,
-                    (latents.shape[1] - 3),
+                    half_padding,
                     self.vae_latent_channels,
                     height // self.vae_scale_factor_spatial,
                     width // self.vae_scale_factor_spatial,
                 )
                 latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
-                middle_frame = image_cond_latents[:, 2, :, :, :]
-                image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
-                middle_frame_idx = image_cond_latents.shape[1] // 2
-                image_cond_latents = image_cond_latents[:, middle_frame_idx, :, :, :] = middle_frame
+                middle_frame = image_cond_latents[:, 1, :, :, :].unsqueeze(1)
+
+                image_cond_latents = torch.cat([
+                    image_cond_latents[:, 0, :, :, :].unsqueeze(1),
+                    latent_padding,
+                    middle_frame,
+                    latent_padding,
+                    image_cond_latents[:, -1, :, :, :].unsqueeze(1)
+                ], dim=1)
+
+                # If total_padding is odd, add one more padding after the middle frame
+                if total_padding % 2 != 0:
+                    extra_padding = torch.zeros(
+                        (batch_size, 1, self.vae_latent_channels,
+                         height // self.vae_scale_factor_spatial,
+                         width // self.vae_scale_factor_spatial),
+                        device=device, dtype=self.vae_dtype
+                    )
+                    image_cond_latents = torch.cat([image_cond_latents, extra_padding], dim=1)
 
                 if self.transformer.config.patch_size_t is not None:
                     first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
                     image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
+
+                middle_frame_idx = image_cond_latents.shape[1] // 2
+                print("middle_frame_idx", middle_frame_idx)
+                print(middle_frame.shape)
+                print(image_cond_latents.shape)
+
+
             elif image_cond_latents.shape[1] == 2:
                 logger.info("More than one image conditioning frame received, interpolating")
                 padding_shape = (
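
Note: as a quick sanity check of the frame layout the pipeline hunk above builds, here is a minimal standalone sketch. The shapes (13 latent frames, 16 channels, 60x90 spatial) are illustrative assumptions, not values taken from the patch; only the padding arithmetic mirrors the patched code.

import torch

# Assumed, illustrative shapes: batch=1, 13 latent frames, 16 channels, 60x90 spatial.
batch_size, num_latent_frames, channels, h, w = 1, 13, 16, 60, 90
image_cond_latents = torch.randn(batch_size, 3, channels, h, w)  # start, mid, end frames

# Mirrors the patched branch: split the empty frames evenly around the middle frame.
total_padding = num_latent_frames - 3
half_padding = total_padding // 2
latent_padding = torch.zeros(batch_size, half_padding, channels, h, w)

middle_frame = image_cond_latents[:, 1].unsqueeze(1)
cond = torch.cat([
    image_cond_latents[:, 0].unsqueeze(1),   # start frame
    latent_padding,
    middle_frame,                            # middle frame, roughly centered
    latent_padding,
    image_cond_latents[:, -1].unsqueeze(1),  # end frame
], dim=1)

# When total_padding is odd, one extra zero frame brings the length back to num_latent_frames.
if total_padding % 2 != 0:
    cond = torch.cat([cond, torch.zeros(batch_size, 1, channels, h, w)], dim=1)

print(cond.shape)  # torch.Size([1, 13, 16, 60, 90])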