mid_image

This commit is contained in:
kijai 2024-11-27 01:52:58 +02:00
parent 2b211b9d1b
commit f9c747eff5
2 changed files with 40 additions and 8 deletions

View File

@ -217,6 +217,7 @@ class CogVideoImageEncode:
"start_image": ("IMAGE", ),
},
"optional": {
"mid_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
"noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
@ -231,7 +232,7 @@ class CogVideoImageEncode:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, vae, start_image, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
def encode(self, vae, start_image, mid_image=None, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
@ -263,14 +264,20 @@ class CogVideoImageEncode:
start_latents = vae.encode(start_image).latent_dist.sample(generator)
start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
if mid_image is not None:
mid_image = (mid_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
if noise_aug_strength > 0:
mid_image = add_noise_to_reference_video(mid_image, ratio=noise_aug_strength)
mid_latents = vae.encode(mid_image).latent_dist.sample(generator)
mid_latents = mid_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
latents_list = [start_latents, mid_latents]
if end_image is not None:
end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
if noise_aug_strength > 0:
end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength)
end_latents = vae.encode(end_image).latent_dist.sample(generator)
end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
latents_list = [start_latents, end_latents]
latents_list.append(end_latents)
final_latents = torch.cat(latents_list, dim=1)
else:
final_latents = start_latents

View File

@ -473,22 +473,47 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if image_cond_latents is not None:
if image_cond_latents.shape[1] == 3:
logger.info("More than one image conditioning frame received, interpolating")
total_padding = latents.shape[1] - 3
half_padding = total_padding // 2
padding_shape = (
batch_size,
(latents.shape[1] - 3),
half_padding,
self.vae_latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
middle_frame = image_cond_latents[:, 2, :, :, :]
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
middle_frame_idx = image_cond_latents.shape[1] // 2
image_cond_latents = image_cond_latents[:, middle_frame_idx, :, :, :] = middle_frame
middle_frame = image_cond_latents[:, 1, :, :, :].unsqueeze(1)
image_cond_latents = torch.cat([
image_cond_latents[:, 0, :, :, :].unsqueeze(1),
latent_padding,
middle_frame,
latent_padding,
image_cond_latents[:, -1, :, :, :].unsqueeze(1)
], dim=1)
# If total_padding is odd, add one more padding after the middle frame
if total_padding % 2 != 0:
extra_padding = torch.zeros(
(batch_size, 1, self.vae_latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial),
device=device, dtype=self.vae_dtype
)
image_cond_latents = torch.cat([image_cond_latents, extra_padding], dim=1)
if self.transformer.config.patch_size_t is not None:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
middle_frame_idx = image_cond_latents.shape[1] // 2
print("middle_frame_idx", middle_frame_idx)
print(middle_frame.shape)
print(image_cond_latents.shape)
elif image_cond_latents.shape[1] == 2:
logger.info("More than one image conditioning frame received, interpolating")
padding_shape = (