From 17a7aed013ae3d0f12d91425701b524fdaa33816 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 28 Nov 2024 18:03:47 +0200 Subject: [PATCH] CogVideoConcatLatent --- model_loading.py | 14 +++++++++++++- nodes.py | 38 +++++++++++++++++++++++++++++--------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/model_loading.py b/model_loading.py index a8f26d3..932bd99 100644 --- a/model_loading.py +++ b/model_loading.py @@ -637,7 +637,19 @@ class CogVideoXModelLoader: "block_edit": ("TRANSFORMERBLOCKS", {"default": None}), "lora": ("COGLORA", {"default": None}), "compile_args":("COMPILEARGS", ), - "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}), + "attention_mode": ([ + "sdpa", + "fused_sdpa", + "sageattn", + "fused_sageattn", + "sageattn_qk_int8_pv_fp8_cuda", + "sageattn_qk_int8_pv_fp16_cuda", + "sageattn_qk_int8_pv_fp16_triton", + "fused_sageattn_qk_int8_pv_fp8_cuda", + "fused_sageattn_qk_int8_pv_fp16_cuda", + "fused_sageattn_qk_int8_pv_fp16_triton", + "comfy" + ], {"default": "sdpa"}), } } diff --git a/nodes.py b/nodes.py index 8e3df39..d84854f 100644 --- a/nodes.py +++ b/nodes.py @@ -217,7 +217,6 @@ class CogVideoImageEncode: "start_image": ("IMAGE", ), }, "optional": { - "mid_image": ("IMAGE", ), "end_image": ("IMAGE", ), "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}), "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}), @@ -232,7 +231,7 @@ class CogVideoImageEncode: FUNCTION = "encode" CATEGORY = "CogVideoWrapper" - def encode(self, vae, start_image, mid_image=None, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0): + def encode(self, vae, start_image, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0): device 
class CogVideoConcatLatent:
    """Node that appends one LATENT's tensor to another along dimension 1.

    The tensor stored under ``"samples"`` in ``samples_from`` is concatenated
    after the one in ``samples_to``. Every other key of ``samples_to`` is
    carried over to the result unchanged.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples_to": ("LATENT", ),
                "samples_from": ("LATENT", ),
            },
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "encode"
    CATEGORY = "CogVideoWrapper"

    def encode(self, samples_from, samples_to):
        """Concatenate ``samples_from`` after ``samples_to`` on dim 1.

        Args:
            samples_from: Mapping whose ``"samples"`` tensor is appended.
            samples_to: Mapping whose ``"samples"`` tensor comes first; its
                remaining entries are copied into the result.

        Returns:
            A 1-tuple holding a new dict: a shallow copy of ``samples_to``
            with ``"samples"`` replaced by the concatenated tensor.

        Raises:
            KeyError: If either input lacks a ``"samples"`` entry.
            RuntimeError: Raised by ``torch.cat`` when the tensors do not
                match on every dimension other than 1.
        """
        insert_from = samples_from["samples"]
        insert_to = samples_to["samples"]

        # NOTE(review): dim=1 is presumably the frame axis, given the
        # "B, T, C, H, W" layout the image-encode node produces - confirm.
        new_latents = torch.cat((insert_to, insert_from), dim=1)

        print("new latents shape: ", new_latents.shape)

        # Build a fresh dict instead of calling samples_to.update(): mutating
        # the input would change the object the caller still holds, so running
        # again with the same dict would concatenate onto already-concatenated
        # latents.
        merged = {**samples_to, "samples": new_latents}
        return (merged, )
NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoLatentPreview": "CogVideo LatentPreview", "CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings", "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP", + "CogVideoConcatLatent": "CogVideo Concat Latent", }