mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
mid_image

This commit is contained in:
parent 2b211b9d1b
commit f9c747eff5
nodes.py (13)
@@ -217,6 +217,7 @@ class CogVideoImageEncode:
                 "start_image": ("IMAGE", ),
             },
             "optional": {
+                "mid_image": ("IMAGE", ),
                 "end_image": ("IMAGE", ),
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
@@ -231,7 +232,7 @@ class CogVideoImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, vae, start_image, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
+    def encode(self, vae, start_image, mid_image=None, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
@@ -263,14 +264,20 @@ class CogVideoImageEncode:
         start_latents = vae.encode(start_image).latent_dist.sample(generator)
         start_latents = start_latents.permute(0, 2, 1, 3, 4)  # B, T, C, H, W
 
 
+        if mid_image is not None:
+            mid_image = (mid_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
+            if noise_aug_strength > 0:
+                mid_image = add_noise_to_reference_video(mid_image, ratio=noise_aug_strength)
+            mid_latents = vae.encode(mid_image).latent_dist.sample(generator)
+            mid_latents = mid_latents.permute(0, 2, 1, 3, 4)  # B, T, C, H, W
+            latents_list = [start_latents, mid_latents]
         if end_image is not None:
             end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
             if noise_aug_strength > 0:
                 end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength)
             end_latents = vae.encode(end_image).latent_dist.sample(generator)
             end_latents = end_latents.permute(0, 2, 1, 3, 4)  # B, T, C, H, W
-            latents_list = [start_latents, end_latents]
+            latents_list.append(end_latents)
             final_latents = torch.cat(latents_list, dim=1)
         else:
             final_latents = start_latents
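For reference, a minimal standalone sketch (not part of the commit) of how the optional mid/end latents end up stacked along the frame axis after this change; the tensor sizes are made-up placeholders standing in for whatever vae.encode() actually returns.

import torch

# hypothetical latent dimensions: batch, channels, height, width
B, C, H, W = 1, 16, 60, 90
start_latents = torch.zeros(B, 1, C, H, W)  # B, T, C, H, W after permute
mid_latents   = torch.zeros(B, 1, C, H, W)  # present only if a mid_image was given
end_latents   = torch.zeros(B, 1, C, H, W)  # present only if an end_image was given

latents_list = [start_latents]
if mid_latents is not None:
    latents_list.append(mid_latents)        # new: mid-frame conditioning latent
if end_latents is not None:
    latents_list.append(end_latents)

final_latents = torch.cat(latents_list, dim=1)  # frames stacked on the T axis
print(final_latents.shape)                      # torch.Size([1, 3, 16, 60, 90])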
@@ -473,22 +473,47 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         if image_cond_latents is not None:
             if image_cond_latents.shape[1] == 3:
                 logger.info("More than one image conditioning frame received, interpolating")
+                total_padding = latents.shape[1] - 3
+                half_padding = total_padding // 2
+
                 padding_shape = (
                     batch_size,
-                    (latents.shape[1] - 3),
+                    half_padding,
                     self.vae_latent_channels,
                     height // self.vae_scale_factor_spatial,
                     width // self.vae_scale_factor_spatial,
                 )
                 latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
-                middle_frame = image_cond_latents[:, 2, :, :, :]
-                image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
-                middle_frame_idx = image_cond_latents.shape[1] // 2
-                image_cond_latents = image_cond_latents[:, middle_frame_idx, :, :, :] = middle_frame
+                middle_frame = image_cond_latents[:, 1, :, :, :].unsqueeze(1)
+
+                image_cond_latents = torch.cat([
+                    image_cond_latents[:, 0, :, :, :].unsqueeze(1),
+                    latent_padding,
+                    middle_frame,
+                    latent_padding,
+                    image_cond_latents[:, -1, :, :, :].unsqueeze(1)
+                ], dim=1)
+
+                # If total_padding is odd, add one more padding after the middle frame
+                if total_padding % 2 != 0:
+                    extra_padding = torch.zeros(
+                        (batch_size, 1, self.vae_latent_channels,
+                        height // self.vae_scale_factor_spatial,
+                        width // self.vae_scale_factor_spatial),
+                        device=device, dtype=self.vae_dtype
+                    )
+                    image_cond_latents = torch.cat([image_cond_latents, extra_padding], dim=1)
+
+                if self.transformer.config.patch_size_t is not None:
+                    first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
+                    image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
+
+                middle_frame_idx = image_cond_latents.shape[1] // 2
+                print("middle_frame_idx", middle_frame_idx)
+                print(middle_frame.shape)
+                print(image_cond_latents.shape)
 
 
             elif image_cond_latents.shape[1] == 2:
                 logger.info("More than one image conditioning frame received, interpolating")
                 padding_shape = (
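Likewise, a standalone sketch (assumed shapes, not the pipeline code itself) of the frame layout this hunk builds for three conditioning images: start frame, half_padding zero frames, middle frame, half_padding zero frames, end frame, plus one extra zero frame when total_padding is odd.

import torch

# hypothetical latent dimensions and timeline length
B, C, H, W = 1, 16, 60, 90
T = 13                             # number of latent frames in `latents`
cond = torch.ones(B, 3, C, H, W)   # start / mid / end conditioning latents

total_padding = T - 3
half_padding = total_padding // 2
pad = torch.zeros(B, half_padding, C, H, W)

out = torch.cat([
    cond[:, 0:1],   # start frame
    pad,            # zeros up to the middle
    cond[:, 1:2],   # middle frame
    pad,            # zeros up to the end
    cond[:, 2:3],   # end frame
], dim=1)

# if total_padding is odd, one extra zero frame keeps the length at T
if total_padding % 2 != 0:
    out = torch.cat([out, torch.zeros(B, 1, C, H, W)], dim=1)

print(out.shape)  # torch.Size([1, 13, 16, 60, 90])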