mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-16 08:24:25 +08:00
Update nodes.py
This commit is contained in:
parent
c3eae2edc1
commit
62ebaf986a
48
nodes.py
48
nodes.py
@ -274,6 +274,37 @@ class CogVideoTextEncode:
|
||||
|
||||
return (embeds, )
|
||||
|
||||
class CogVideoTextEncodeCombine:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"conditioning_1": ("CONDITIONING",),
|
||||
"conditioning_2": ("CONDITIONING",),
|
||||
"combination_mode": (["average", "weighted_average", "concatenate"], {"default": "weighted_average"}),
|
||||
"weighted_average_ratio": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
RETURN_NAMES = ("conditioning",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def process(self, conditioning_1, conditioning_2, combination_mode, weighted_average_ratio):
|
||||
if conditioning_1.shape != conditioning_2.shape:
|
||||
raise ValueError("conditioning_1 and conditioning_2 must have the same shape")
|
||||
|
||||
if combination_mode == "average":
|
||||
embeds = (conditioning_1 + conditioning_2) / 2
|
||||
elif combination_mode == "weighted_average":
|
||||
embeds = conditioning_1 * (1 - weighted_average_ratio) + conditioning_2 * weighted_average_ratio
|
||||
elif combination_mode == "concatenate":
|
||||
embeds = torch.cat((conditioning_1, conditioning_2), dim=-1)
|
||||
else:
|
||||
raise ValueError("Invalid combination mode")
|
||||
|
||||
return (embeds, )
|
||||
|
||||
class CogVideoImageEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -390,6 +421,9 @@ class CogVideoSampler:
|
||||
denoise_strength=1.0, image_cond_latents=None):
|
||||
mm.soft_empty_cache()
|
||||
|
||||
base_path = pipeline["base_path"]
|
||||
|
||||
assert "Fun" not in base_path, "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
|
||||
assert t_tile_length > t_tile_overlap, "t_tile_length must be greater than t_tile_overlap"
|
||||
assert t_tile_length <= num_frames, "t_tile_length must be equal or less than num_frames"
|
||||
t_tile_length = t_tile_length // 4
|
||||
@ -399,7 +433,7 @@ class CogVideoSampler:
|
||||
offload_device = mm.unet_offload_device()
|
||||
pipe = pipeline["pipe"]
|
||||
dtype = pipeline["dtype"]
|
||||
base_path = pipeline["base_path"]
|
||||
|
||||
|
||||
if not pipeline["cpu_offloading"]:
|
||||
pipe.transformer.to(device)
|
||||
@ -556,6 +590,8 @@ class CogVideoXFunSampler:
|
||||
offload_device = mm.unet_offload_device()
|
||||
pipe = pipeline["pipe"]
|
||||
dtype = pipeline["dtype"]
|
||||
base_path = pipeline["base_path"]
|
||||
assert "Fun" in base_path, "'Unfun' models not supported in 'CogVideoXFunSampler', use the 'CogVideoSampler'"
|
||||
|
||||
pipe.enable_model_cpu_offload(device=device)
|
||||
|
||||
@ -575,8 +611,6 @@ class CogVideoXFunSampler:
|
||||
height, width = [int(x / 16) * 16 for x in closest_size]
|
||||
print(f"Closest size: {width}x{height}")
|
||||
|
||||
base_path = pipeline["base_path"]
|
||||
|
||||
# Load Sampler
|
||||
if scheduler == "DPM++":
|
||||
noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(base_path, subfolder= 'scheduler')
|
||||
@ -619,7 +653,7 @@ class CogVideoXFunSampler:
|
||||
comfyui_progressbar = True,
|
||||
)
|
||||
#if not pipeline["cpu_offloading"]:
|
||||
# pipe.transformer.to(offload_device)
|
||||
# pipe.transformer.to(offload_device)
|
||||
mm.soft_empty_cache()
|
||||
print(latents.shape)
|
||||
|
||||
@ -748,7 +782,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CogVideoDualTextEncode_311": CogVideoDualTextEncode_311,
|
||||
"CogVideoImageEncode": CogVideoImageEncode,
|
||||
"CogVideoXFunSampler": CogVideoXFunSampler,
|
||||
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler
|
||||
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
|
||||
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
|
||||
@ -758,5 +793,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CogVideoDualTextEncode_311": "CogVideo DualTextEncode",
|
||||
"CogVideoImageEncode": "CogVideo ImageEncode",
|
||||
"CogVideoXFunSampler": "CogVideoXFun Sampler",
|
||||
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler"
|
||||
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
|
||||
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine"
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user