Initial support for Fun 1.5

I2V works, interpolation doesn't seem to (not sure if it should)
This commit is contained in:
kijai 2024-12-17 01:16:11 +02:00
parent 795f8b0565
commit b5eefbf4d4
3 changed files with 19 additions and 10 deletions

View File

@ -147,6 +147,7 @@ class DownloadAndLoadCogVideoModel:
"alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose",
"alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose",
"alibaba-pai/CogVideoX-Fun-V1.1-5b-Control",
"alibaba-pai/CogVideoX-Fun-V1.5-5b-InP",
"feizhengcong/CogvideoX-Interpolation",
"NimVideo/cogvideox-2b-img2vid"
],
@ -215,7 +216,7 @@ class DownloadAndLoadCogVideoModel:
download_path = folder_paths.get_folder_paths("CogVideo")[0]
if "Fun" in model:
if not "1.1" in model:
if not "1.1" and not "1.5" in model:
repo_id = "kijai/CogVideoX-Fun-pruned"
if "2b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP") # location of the official model
@ -225,7 +226,7 @@ class DownloadAndLoadCogVideoModel:
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") # location of the official model
if not os.path.exists(base_path):
base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP")
elif "1.1" in model:
else:
repo_id = model
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", (model.split("/")[-1])) # location of the official model
if not os.path.exists(base_path):
@ -278,7 +279,7 @@ class DownloadAndLoadCogVideoModel:
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
transformer = transformer.to(dtype).to(transformer_load_device)
if "1.5" in model:
if "1.5" in model and not "fun" in model:
transformer.config.sample_height = 300
transformer.config.sample_width = 300

View File

@ -360,8 +360,8 @@ class CogVideoImageEncodeFunInP:
masked_image_latents = masked_image_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
mask = torch.zeros_like(masked_image_latents[:, :, :1, :, :])
if end_image is not None:
mask[:, -1, :, :, :] = 0
#if end_image is not None:
# mask[:, -1, :, :, :] = 0
mask[:, 0, :, :, :] = vae_scaling_factor
final_latents = masked_image_latents * vae_scaling_factor
@ -623,7 +623,7 @@ class CogVideoSampler:
image_conds = image_cond_latents["samples"]
image_cond_start_percent = image_cond_latents.get("start_percent", 0.0)
image_cond_end_percent = image_cond_latents.get("end_percent", 1.0)
if "1.5" in model_name or "1_5" in model_name:
if ("1.5" in model_name or "1_5" in model_name) and not "fun" in model_name.lower():
image_conds = image_conds / 0.7 # needed for 1.5 models
else:
if not "fun" in model_name.lower():

View File

@ -471,6 +471,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
# 5.5.
if image_cond_latents is not None:
image_cond_frame_count = image_cond_latents.size(1)
patch_size_t = self.transformer.config.patch_size_t
if image_cond_latents.shape[1] == 2:
logger.info("More than one image conditioning frame received, interpolating")
padding_shape = (
@ -482,8 +484,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
if self.transformer.config.patch_size_t is not None:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
if patch_size_t:
first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
logger.info(f"image cond latents shape: {image_cond_latents.shape}")
@ -500,13 +502,19 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
# Select the first frame along the second dimension
if self.transformer.config.patch_size_t is not None:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
if patch_size_t:
first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
else:
image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1)
else:
logger.info(f"Received {image_cond_latents.shape[1]} image conditioning frames")
if fun_mask is not None and patch_size_t:
logger.info(f"1.5 model received {fun_mask.shape[1]} masks")
first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
fun_mask_first_frame = fun_mask[:, : image_cond_frame_count % patch_size_t, ...]
fun_mask = torch.cat([fun_mask_first_frame, fun_mask], dim=1)
image_cond_latents = image_cond_latents.to(self.vae_dtype)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline