CogVideoConcatLatent

This commit is contained in:
kijai 2024-11-28 18:03:47 +02:00
parent f9c747eff5
commit 17a7aed013
2 changed files with 42 additions and 10 deletions

View File

@@ -637,7 +637,19 @@ class CogVideoXModelLoader:
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
                 "compile_args":("COMPILEARGS", ),
-                "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
+                "attention_mode": ([
+                    "sdpa",
+                    "fused_sdpa",
+                    "sageattn",
+                    "fused_sageattn",
+                    "sageattn_qk_int8_pv_fp8_cuda",
+                    "sageattn_qk_int8_pv_fp16_cuda",
+                    "sageattn_qk_int8_pv_fp16_triton",
+                    "fused_sageattn_qk_int8_pv_fp8_cuda",
+                    "fused_sageattn_qk_int8_pv_fp16_cuda",
+                    "fused_sageattn_qk_int8_pv_fp16_triton",
+                    "comfy"
+                ], {"default": "sdpa"}),
             }
         }
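The expanded list swaps the single auto-dispatching sageattn mode for SageAttention's per-kernel entry points, plus ComfyUI's own attention ("comfy"). A minimal sketch of how such mode strings could be routed, assuming the sageattention package's public kernels (2.x API); run_attention and the "fused_" handling here are illustrative, not the wrapper's actual code:

    # Illustrative dispatch only -- not the wrapper's implementation.
    import torch
    import torch.nn.functional as F

    def run_attention(q, k, v, attention_mode="sdpa"):
        # q, k, v: [batch, heads, seq_len, head_dim]
        if attention_mode in ("sdpa", "fused_sdpa", "comfy"):
            # "comfy" defers to ComfyUI's attention; plain SDPA stands in here.
            return F.scaled_dot_product_attention(q, k, v)
        from sageattention import (
            sageattn,
            sageattn_qk_int8_pv_fp8_cuda,
            sageattn_qk_int8_pv_fp16_cuda,
            sageattn_qk_int8_pv_fp16_triton,
        )
        kernels = {
            "sageattn": sageattn,  # auto-selects a kernel for the GPU
            "sageattn_qk_int8_pv_fp8_cuda": sageattn_qk_int8_pv_fp8_cuda,
            "sageattn_qk_int8_pv_fp16_cuda": sageattn_qk_int8_pv_fp16_cuda,
            "sageattn_qk_int8_pv_fp16_triton": sageattn_qk_int8_pv_fp16_triton,
        }
        # A "fused_" prefix selects the same kernel; in the wrapper it
        # signals fused QKV projections, which doesn't change this call.
        kernel = kernels[attention_mode.removeprefix("fused_")]
        return kernel(q, k, v, tensor_layout="HND", is_causal=False)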

View File

@@ -217,7 +217,6 @@ class CogVideoImageEncode:
                 "start_image": ("IMAGE", ),
             },
             "optional": {
-                "mid_image": ("IMAGE", ),
                 "end_image": ("IMAGE", ),
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
@@ -232,7 +231,7 @@ class CogVideoImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, vae, start_image, mid_image=None, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
+    def encode(self, vae, start_image, end_image=None, enable_tiling=False, noise_aug_strength=0.0, strength=1.0, start_percent=0.0, end_percent=1.0):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
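The fixed-seed generator is what keeps encoding deterministic: the VAE's latent_dist.sample draws fresh noise on every call, so an unseeded encode of the same image would give slightly different latents on each run. A sketch of the equivalent diagonal-Gaussian sampling, with illustrative shapes:

    # Sketch of what latent_dist.sample(generator) boils down to for a
    # diagonal Gaussian posterior; shapes here are illustrative.
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(0)

    mean = torch.zeros(1, 16, 1, 60, 90, device=device)  # posterior mean
    std = torch.ones_like(mean)                          # posterior std
    latents = mean + std * torch.randn(
        mean.shape, generator=generator, device=device
    )  # identical on every run thanks to the fixed seed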
@@ -264,13 +263,6 @@ class CogVideoImageEncode:
         start_latents = vae.encode(start_image).latent_dist.sample(generator)
         start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
 
-        if mid_image is not None:
-            mid_image = (mid_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
-            if noise_aug_strength > 0:
-                mid_image = add_noise_to_reference_video(mid_image, ratio=noise_aug_strength)
-            mid_latents = vae.encode(mid_image).latent_dist.sample(generator)
-            mid_latents = mid_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
-            latents_list = [start_latents, mid_latents]
         if end_image is not None:
             end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
             if noise_aug_strength > 0:
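For reference, the preprocessing in these branches maps ComfyUI's [frames, height, width, channels] image tensor in [0, 1] to the [batch, channels, frames, height, width] layout the VAE expects. A quick shape walkthrough (sizes illustrative):

    # Shape walkthrough of the preprocessing above; assumes the ComfyUI
    # convention of IMAGE tensors shaped [T, H, W, C] in [0, 1].
    import torch

    frames = torch.rand(1, 480, 720, 3)    # [T, H, W, C], a single frame
    pixels = frames * 2.0 - 1.0            # rescale [0, 1] -> [-1, 1]
    pixels = pixels.unsqueeze(0)           # [B, T, H, W, C]
    pixels = pixels.permute(0, 4, 1, 2, 3) # [B, C, T, H, W] for the VAE
    print(pixels.shape)                    # torch.Size([1, 3, 1, 480, 720])

    # vae.encode(pixels).latent_dist.sample(generator) returns latents in
    # [B, C, T, H, W]; permute(0, 2, 1, 3, 4) puts them in [B, T, C, H, W],
    # the layout the concat node below relies on.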
@@ -292,6 +284,32 @@ class CogVideoImageEncode:
             "start_percent": start_percent,
             "end_percent": end_percent
         }, )
 
+class CogVideoConcatLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "samples_to": ("LATENT", ),
+            "samples_from": ("LATENT",),
+            },
+        }
+
+    RETURN_TYPES = ("LATENT",)
+    RETURN_NAMES = ("samples",)
+    FUNCTION = "encode"
+    CATEGORY = "CogVideoWrapper"
+
+    def encode(self, samples_from, samples_to):
+        insert_from = samples_from["samples"]
+        insert_to = samples_to["samples"]
+
+        new_latents = torch.cat((insert_to, insert_from), dim=1)
+        print("new latents shape: ", new_latents.shape)
+        samples_to.update({"samples": new_latents})
+
+        return (samples_to, )
+
 class CogVideoImageEncodeFunInP:
     @classmethod
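Since the encoder stores latents as [B, T, C, H, W] (note the permutes above), dim=1 in the new CogVideoConcatLatent node is the frame axis: samples_from is appended after samples_to in time. A dummy-tensor sketch of the concat, with illustrative channel and spatial sizes:

    # Minimal sketch of what CogVideoConcatLatent does, with dummy data.
    # Latents are assumed [B, T, C, H, W], so dim=1 is the frame axis;
    # the 16-channel, 60x90 sizes are illustrative only.
    import torch

    samples_to = {"samples": torch.randn(1, 13, 16, 60, 90)}   # 13 latent frames
    samples_from = {"samples": torch.randn(1, 4, 16, 60, 90)}  # 4 latent frames

    new_latents = torch.cat(
        (samples_to["samples"], samples_from["samples"]), dim=1
    )
    print(new_latents.shape)  # torch.Size([1, 17, 16, 60, 90])

Because samples_to is updated in place and returned, any keys the encoder attached alongside "samples" (strength, start_percent, end_percent) carry through to the concatenated output.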
@@ -967,6 +985,7 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoLatentPreview": CogVideoLatentPreview,
     "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
     "CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
+    "CogVideoConcatLatent": CogVideoConcatLatent,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoSampler": "CogVideo Sampler",
@@ -983,4 +1002,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoLatentPreview": "CogVideo LatentPreview",
     "CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
     "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
+    "CogVideoConcatLatent": "CogVideo Concat Latent",
 }