mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-02 14:46:49 +08:00
Initial support for Fun 1.5
I2V works; interpolation doesn't seem to work (not sure whether it should)
This commit is contained in:
parent
795f8b0565
commit
b5eefbf4d4
@ -147,6 +147,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
"alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose",
|
"alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose",
|
||||||
"alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose",
|
"alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose",
|
||||||
"alibaba-pai/CogVideoX-Fun-V1.1-5b-Control",
|
"alibaba-pai/CogVideoX-Fun-V1.1-5b-Control",
|
||||||
|
"alibaba-pai/CogVideoX-Fun-V1.5-5b-InP",
|
||||||
"feizhengcong/CogvideoX-Interpolation",
|
"feizhengcong/CogvideoX-Interpolation",
|
||||||
"NimVideo/cogvideox-2b-img2vid"
|
"NimVideo/cogvideox-2b-img2vid"
|
||||||
],
|
],
|
||||||
@ -215,7 +216,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
download_path = folder_paths.get_folder_paths("CogVideo")[0]
|
download_path = folder_paths.get_folder_paths("CogVideo")[0]
|
||||||
|
|
||||||
if "Fun" in model:
|
if "Fun" in model:
|
||||||
if not "1.1" in model:
|
if not "1.1" and not "1.5" in model:
|
||||||
repo_id = "kijai/CogVideoX-Fun-pruned"
|
repo_id = "kijai/CogVideoX-Fun-pruned"
|
||||||
if "2b" in model:
|
if "2b" in model:
|
||||||
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP") # location of the official model
|
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP") # location of the official model
|
||||||
@ -225,7 +226,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") # location of the official model
|
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") # location of the official model
|
||||||
if not os.path.exists(base_path):
|
if not os.path.exists(base_path):
|
||||||
base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP")
|
base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP")
|
||||||
elif "1.1" in model:
|
else:
|
||||||
repo_id = model
|
repo_id = model
|
||||||
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", (model.split("/")[-1])) # location of the official model
|
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", (model.split("/")[-1])) # location of the official model
|
||||||
if not os.path.exists(base_path):
|
if not os.path.exists(base_path):
|
||||||
@ -278,7 +279,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
|
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
|
||||||
transformer = transformer.to(dtype).to(transformer_load_device)
|
transformer = transformer.to(dtype).to(transformer_load_device)
|
||||||
|
|
||||||
if "1.5" in model:
|
if "1.5" in model and not "fun" in model:
|
||||||
transformer.config.sample_height = 300
|
transformer.config.sample_height = 300
|
||||||
transformer.config.sample_width = 300
|
transformer.config.sample_width = 300
|
||||||
|
|
||||||
|
|||||||
6
nodes.py
6
nodes.py
@ -360,8 +360,8 @@ class CogVideoImageEncodeFunInP:
|
|||||||
masked_image_latents = masked_image_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
|
masked_image_latents = masked_image_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
|
||||||
|
|
||||||
mask = torch.zeros_like(masked_image_latents[:, :, :1, :, :])
|
mask = torch.zeros_like(masked_image_latents[:, :, :1, :, :])
|
||||||
if end_image is not None:
|
#if end_image is not None:
|
||||||
mask[:, -1, :, :, :] = 0
|
# mask[:, -1, :, :, :] = 0
|
||||||
mask[:, 0, :, :, :] = vae_scaling_factor
|
mask[:, 0, :, :, :] = vae_scaling_factor
|
||||||
|
|
||||||
final_latents = masked_image_latents * vae_scaling_factor
|
final_latents = masked_image_latents * vae_scaling_factor
|
||||||
@ -623,7 +623,7 @@ class CogVideoSampler:
|
|||||||
image_conds = image_cond_latents["samples"]
|
image_conds = image_cond_latents["samples"]
|
||||||
image_cond_start_percent = image_cond_latents.get("start_percent", 0.0)
|
image_cond_start_percent = image_cond_latents.get("start_percent", 0.0)
|
||||||
image_cond_end_percent = image_cond_latents.get("end_percent", 1.0)
|
image_cond_end_percent = image_cond_latents.get("end_percent", 1.0)
|
||||||
if "1.5" in model_name or "1_5" in model_name:
|
if ("1.5" in model_name or "1_5" in model_name) and not "fun" in model_name.lower():
|
||||||
image_conds = image_conds / 0.7 # needed for 1.5 models
|
image_conds = image_conds / 0.7 # needed for 1.5 models
|
||||||
else:
|
else:
|
||||||
if not "fun" in model_name.lower():
|
if not "fun" in model_name.lower():
|
||||||
|
|||||||
@ -471,6 +471,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
|
|
||||||
# 5.5.
|
# 5.5.
|
||||||
if image_cond_latents is not None:
|
if image_cond_latents is not None:
|
||||||
|
image_cond_frame_count = image_cond_latents.size(1)
|
||||||
|
patch_size_t = self.transformer.config.patch_size_t
|
||||||
if image_cond_latents.shape[1] == 2:
|
if image_cond_latents.shape[1] == 2:
|
||||||
logger.info("More than one image conditioning frame received, interpolating")
|
logger.info("More than one image conditioning frame received, interpolating")
|
||||||
padding_shape = (
|
padding_shape = (
|
||||||
@ -482,8 +484,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
)
|
)
|
||||||
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
|
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
|
||||||
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
|
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
|
||||||
if self.transformer.config.patch_size_t is not None:
|
if patch_size_t:
|
||||||
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
|
first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
|
||||||
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
|
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
|
||||||
|
|
||||||
logger.info(f"image cond latents shape: {image_cond_latents.shape}")
|
logger.info(f"image cond latents shape: {image_cond_latents.shape}")
|
||||||
@ -500,13 +502,19 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
|
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
|
||||||
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
|
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
|
||||||
# Select the first frame along the second dimension
|
# Select the first frame along the second dimension
|
||||||
if self.transformer.config.patch_size_t is not None:
|
if patch_size_t:
|
||||||
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
|
first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
|
||||||
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
|
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
|
||||||
else:
|
else:
|
||||||
image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1)
|
image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Received {image_cond_latents.shape[1]} image conditioning frames")
|
logger.info(f"Received {image_cond_latents.shape[1]} image conditioning frames")
|
||||||
|
if fun_mask is not None and patch_size_t:
|
||||||
|
logger.info(f"1.5 model received {fun_mask.shape[1]} masks")
|
||||||
|
first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
|
||||||
|
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
|
||||||
|
fun_mask_first_frame = fun_mask[:, : image_cond_frame_count % patch_size_t, ...]
|
||||||
|
fun_mask = torch.cat([fun_mask_first_frame, fun_mask], dim=1)
|
||||||
image_cond_latents = image_cond_latents.to(self.vae_dtype)
|
image_cond_latents = image_cond_latents.to(self.vae_dtype)
|
||||||
|
|
||||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user