add TCD scheduler to CogVideoXFun sampler

This commit is contained in:
Phr00t 2024-09-21 13:49:00 -04:00
parent df3f210287
commit b36d1b7a4a

View File

@ -3,7 +3,7 @@ import torch
import folder_paths
import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler, TCDScheduler
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from .pipeline_cogvideox import CogVideoXPipeline
@ -700,6 +700,7 @@ class CogVideoXFunSampler:
"DDIM",
"LCM",
"UniPC",
"TCD",
"CogVideoXDDIM",
"CogVideoXDPMScheduler",
],
@ -759,6 +760,8 @@ class CogVideoXFunSampler:
noise_scheduler = LCMScheduler.from_config(scheduler_config)
elif scheduler == "UniPC":
noise_scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
elif scheduler == "TCD":
noise_scheduler = TCDScheduler.from_config(scheduler_config)
elif scheduler == "CogVideoXDDIM":
noise_scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
elif scheduler == "CogVideoXDPMScheduler":
@ -819,6 +822,7 @@ class CogVideoXFunVid2VidSampler:
"DDIM",
"LCM",
"UniPC",
"TCD",
"CogVideoXDDIM",
"CogVideoXDPMScheduler",
],
@ -867,6 +871,8 @@ class CogVideoXFunVid2VidSampler:
noise_scheduler = LCMScheduler.from_pretrained(base_path, subfolder= 'scheduler')
elif scheduler == "UniPC":
noise_scheduler = UniPCMultistepScheduler.from_pretrained(base_path, subfolder= 'scheduler')
elif scheduler == "TCD":
noise_scheduler = TCDScheduler.from_pretrained(base_path, subfolder= 'scheduler')
elif scheduler == "CogVideoXDDIM":
noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder= 'scheduler')
elif scheduler == "CogVideoXDPMScheduler":