This commit is contained in:
kijai 2024-12-17 09:10:03 +02:00
parent b5eefbf4d4
commit 0758d2d016

View File

@@ -473,7 +473,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if image_cond_latents is not None:
image_cond_frame_count = image_cond_latents.size(1)
patch_size_t = self.transformer.config.patch_size_t
if image_cond_latents.shape[1] == 2:
if image_cond_frame_count == 2:
logger.info("More than one image conditioning frame received, interpolating")
padding_shape = (
batch_size,
@@ -485,11 +485,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
if patch_size_t:
first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
logger.info(f"image cond latents shape: {image_cond_latents.shape}")
elif image_cond_latents.shape[1] == 1:
elif image_cond_frame_count == 1:
logger.info("Only one image conditioning frame received, img2vid")
if self.input_with_padding:
padding_shape = (
@@ -503,7 +503,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
# Select the first frame along the second dimension
if patch_size_t:
first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
else:
image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1)