mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-25 12:09:10 +08:00
mid_image
This commit is contained in:
parent
2b211b9d1b
commit
f9c747eff5
13
nodes.py
13
nodes.py
@ -217,6 +217,7 @@ class CogVideoImageEncode:
|
|||||||
"start_image": ("IMAGE", ),
|
"start_image": ("IMAGE", ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
|
"mid_image": ("IMAGE", ),
|
||||||
"end_image": ("IMAGE", ),
|
"end_image": ("IMAGE", ),
|
||||||
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
|
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
|
||||||
"noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
|
"noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
|
||||||
@ -231,7 +232,7 @@ class CogVideoImageEncode:
|
|||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def encode(self, vae, start_image, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
|
def encode(self, vae, start_image, mid_image=None, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
generator = torch.Generator(device=device).manual_seed(0)
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
@ -263,14 +264,20 @@ class CogVideoImageEncode:
|
|||||||
start_latents = vae.encode(start_image).latent_dist.sample(generator)
|
start_latents = vae.encode(start_image).latent_dist.sample(generator)
|
||||||
start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
|
start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
|
||||||
|
|
||||||
|
if mid_image is not None:
|
||||||
|
mid_image = (mid_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
|
||||||
|
if noise_aug_strength > 0:
|
||||||
|
mid_image = add_noise_to_reference_video(mid_image, ratio=noise_aug_strength)
|
||||||
|
mid_latents = vae.encode(mid_image).latent_dist.sample(generator)
|
||||||
|
mid_latents = mid_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
|
||||||
|
latents_list = [start_latents, mid_latents]
|
||||||
if end_image is not None:
|
if end_image is not None:
|
||||||
end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
|
end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
|
||||||
if noise_aug_strength > 0:
|
if noise_aug_strength > 0:
|
||||||
end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength)
|
end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength)
|
||||||
end_latents = vae.encode(end_image).latent_dist.sample(generator)
|
end_latents = vae.encode(end_image).latent_dist.sample(generator)
|
||||||
end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
|
end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
|
||||||
latents_list = [start_latents, end_latents]
|
latents_list.append(end_latents)
|
||||||
final_latents = torch.cat(latents_list, dim=1)
|
final_latents = torch.cat(latents_list, dim=1)
|
||||||
else:
|
else:
|
||||||
final_latents = start_latents
|
final_latents = start_latents
|
||||||
|
|||||||
@ -473,22 +473,47 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
if image_cond_latents is not None:
|
if image_cond_latents is not None:
|
||||||
if image_cond_latents.shape[1] == 3:
|
if image_cond_latents.shape[1] == 3:
|
||||||
logger.info("More than one image conditioning frame received, interpolating")
|
logger.info("More than one image conditioning frame received, interpolating")
|
||||||
|
total_padding = latents.shape[1] - 3
|
||||||
|
half_padding = total_padding // 2
|
||||||
|
|
||||||
padding_shape = (
|
padding_shape = (
|
||||||
batch_size,
|
batch_size,
|
||||||
(latents.shape[1] - 3),
|
half_padding,
|
||||||
self.vae_latent_channels,
|
self.vae_latent_channels,
|
||||||
height // self.vae_scale_factor_spatial,
|
height // self.vae_scale_factor_spatial,
|
||||||
width // self.vae_scale_factor_spatial,
|
width // self.vae_scale_factor_spatial,
|
||||||
)
|
)
|
||||||
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
|
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
|
||||||
middle_frame = image_cond_latents[:, 2, :, :, :]
|
middle_frame = image_cond_latents[:, 1, :, :, :].unsqueeze(1)
|
||||||
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
|
|
||||||
middle_frame_idx = image_cond_latents.shape[1] // 2
|
image_cond_latents = torch.cat([
|
||||||
image_cond_latents = image_cond_latents[:, middle_frame_idx, :, :, :] = middle_frame
|
image_cond_latents[:, 0, :, :, :].unsqueeze(1),
|
||||||
|
latent_padding,
|
||||||
|
middle_frame,
|
||||||
|
latent_padding,
|
||||||
|
image_cond_latents[:, -1, :, :, :].unsqueeze(1)
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
# If total_padding is odd, add one more padding after the middle frame
|
||||||
|
if total_padding % 2 != 0:
|
||||||
|
extra_padding = torch.zeros(
|
||||||
|
(batch_size, 1, self.vae_latent_channels,
|
||||||
|
height // self.vae_scale_factor_spatial,
|
||||||
|
width // self.vae_scale_factor_spatial),
|
||||||
|
device=device, dtype=self.vae_dtype
|
||||||
|
)
|
||||||
|
image_cond_latents = torch.cat([image_cond_latents, extra_padding], dim=1)
|
||||||
|
|
||||||
if self.transformer.config.patch_size_t is not None:
|
if self.transformer.config.patch_size_t is not None:
|
||||||
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
|
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
|
||||||
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
|
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
|
||||||
|
|
||||||
|
middle_frame_idx = image_cond_latents.shape[1] // 2
|
||||||
|
print("middle_frame_idx", middle_frame_idx)
|
||||||
|
print(middle_frame.shape)
|
||||||
|
print(image_cond_latents.shape)
|
||||||
|
|
||||||
|
|
||||||
elif image_cond_latents.shape[1] == 2:
|
elif image_cond_latents.shape[1] == 2:
|
||||||
logger.info("More than one image conditioning frame received, interpolating")
|
logger.info("More than one image conditioning frame received, interpolating")
|
||||||
padding_shape = (
|
padding_shape = (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user