Update nodes.py

This commit is contained in:
kijai 2024-09-19 17:29:22 +03:00
parent c3eae2edc1
commit 62ebaf986a

View File

@ -274,6 +274,37 @@ class CogVideoTextEncode:
return (embeds, )
class CogVideoTextEncodeCombine:
    """Node that merges two conditioning inputs into a single one.

    Three merge strategies are offered: a plain average, a weighted
    average controlled by ``weighted_average_ratio``, and concatenation
    along the last dimension.
    """

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("conditioning",)
    FUNCTION = "process"
    CATEGORY = "CogVideoWrapper"

    @classmethod
    def INPUT_TYPES(s):
        """Return the declaration of the node's four required inputs."""
        mode_choices = ["average", "weighted_average", "concatenate"]
        ratio_options = {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}
        required = {
            "conditioning_1": ("CONDITIONING",),
            "conditioning_2": ("CONDITIONING",),
            "combination_mode": (mode_choices, {"default": "weighted_average"}),
            "weighted_average_ratio": ("FLOAT", ratio_options),
        }
        return {"required": required}

    def process(self, conditioning_1, conditioning_2, combination_mode, weighted_average_ratio):
        """Combine the two conditionings according to ``combination_mode``.

        Args:
            conditioning_1: First input; must expose ``.shape`` and support
                the arithmetic / ``torch.cat`` used below.
            conditioning_2: Second input, required to have the same shape
                as ``conditioning_1`` for every mode.
            combination_mode: One of ``"average"``, ``"weighted_average"``
                or ``"concatenate"``.
            weighted_average_ratio: Weight given to ``conditioning_2`` in
                ``"weighted_average"`` mode; ``conditioning_1`` receives
                ``1 - weighted_average_ratio``. Ignored by the other modes.

        Returns:
            A one-element tuple holding the combined result.

        Raises:
            ValueError: If the shapes differ, or if ``combination_mode`` is
                not one of the supported values.
        """
        first, second = conditioning_1, conditioning_2

        # The shape check applies to all modes and runs before the mode
        # is validated.
        if first.shape != second.shape:
            raise ValueError("conditioning_1 and conditioning_2 must have the same shape")

        if combination_mode == "average":
            return ((first + second) / 2,)

        if combination_mode == "weighted_average":
            ratio = weighted_average_ratio
            return (first * (1 - ratio) + second * ratio,)

        if combination_mode == "concatenate":
            # NOTE(review): this joins along the LAST dimension, doubling its
            # size. Confirm that is the intended axis for these embeddings.
            return (torch.cat((first, second), dim=-1),)

        raise ValueError("Invalid combination mode")
class CogVideoImageEncode:
@classmethod
def INPUT_TYPES(s):
@ -390,6 +421,9 @@ class CogVideoSampler:
denoise_strength=1.0, image_cond_latents=None):
mm.soft_empty_cache()
base_path = pipeline["base_path"]
assert "Fun" not in base_path, "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
assert t_tile_length > t_tile_overlap, "t_tile_length must be greater than t_tile_overlap"
assert t_tile_length <= num_frames, "t_tile_length must be equal or less than num_frames"
t_tile_length = t_tile_length // 4
@ -399,7 +433,7 @@ class CogVideoSampler:
offload_device = mm.unet_offload_device()
pipe = pipeline["pipe"]
dtype = pipeline["dtype"]
base_path = pipeline["base_path"]
if not pipeline["cpu_offloading"]:
pipe.transformer.to(device)
@ -556,6 +590,8 @@ class CogVideoXFunSampler:
offload_device = mm.unet_offload_device()
pipe = pipeline["pipe"]
dtype = pipeline["dtype"]
base_path = pipeline["base_path"]
assert "Fun" in base_path, "'Unfun' models not supported in 'CogVideoXFunSampler', use the 'CogVideoSampler'"
pipe.enable_model_cpu_offload(device=device)
@ -575,8 +611,6 @@ class CogVideoXFunSampler:
height, width = [int(x / 16) * 16 for x in closest_size]
print(f"Closest size: {width}x{height}")
base_path = pipeline["base_path"]
# Load Sampler
if scheduler == "DPM++":
noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(base_path, subfolder= 'scheduler')
@ -619,7 +653,7 @@ class CogVideoXFunSampler:
comfyui_progressbar = True,
)
#if not pipeline["cpu_offloading"]:
# pipe.transformer.to(offload_device)
# pipe.transformer.to(offload_device)
mm.soft_empty_cache()
print(latents.shape)
@ -748,7 +782,8 @@ NODE_CLASS_MAPPINGS = {
"CogVideoDualTextEncode_311": CogVideoDualTextEncode_311,
"CogVideoImageEncode": CogVideoImageEncode,
"CogVideoXFunSampler": CogVideoXFunSampler,
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@ -758,5 +793,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoDualTextEncode_311": "CogVideo DualTextEncode",
"CogVideoImageEncode": "CogVideo ImageEncode",
"CogVideoXFunSampler": "CogVideoXFun Sampler",
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler"
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine"
}