From 3c8e939f8edf1f51a6394b54faec9fc4341f1d95 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 22 Sep 2024 18:12:50 +0300 Subject: [PATCH] Add more schedulers for "fun" model --- nodes.py | 84 +++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/nodes.py b/nodes.py index 882c7b6..59d5621 100644 --- a/nodes.py +++ b/nodes.py @@ -3,7 +3,34 @@ import torch import folder_paths import comfy.model_management as mm from comfy.utils import ProgressBar, load_torch_file -from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler +from diffusers.schedulers import ( + CogVideoXDDIMScheduler, + CogVideoXDPMScheduler, + DDIMScheduler, + PNDMScheduler, + DPMSolverMultistepScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + UniPCMultistepScheduler, + HeunDiscreteScheduler, + SASolverScheduler, + DEISMultistepScheduler, + DDIMInverseScheduler + ) + +scheduler_mapping = { + "DPM++": DPMSolverMultistepScheduler, + "Euler": EulerDiscreteScheduler, + "Euler A": EulerAncestralDiscreteScheduler, + "PNDM": PNDMScheduler, + "DDIM": DDIMScheduler, + "CogVideoXDDIM": CogVideoXDDIMScheduler, + "CogVideoXDPMScheduler": CogVideoXDPMScheduler, + "SASolverScheduler": SASolverScheduler, + "UniPCMultistepScheduler": UniPCMultistepScheduler, + "HeunDiscreteScheduler": HeunDiscreteScheduler, + "DEISMultistepScheduler": DEISMultistepScheduler +} from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from .pipeline_cogvideox import CogVideoXPipeline @@ -737,6 +764,10 @@ class CogVideoXFunSampler: "DPM++", "PNDM", "DDIM", + "SASolverScheduler", + "UniPCMultistepScheduler", + "HeunDiscreteScheduler", + "DEISMultistepScheduler", "CogVideoXDDIM", "CogVideoXDPMScheduler", ], @@ -787,21 +818,11 @@ class CogVideoXFunSampler: # Load Sampler scheduler_config = pipeline["scheduler_config"] - if scheduler == "DPM++": - noise_scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) - elif scheduler == "Euler": - noise_scheduler = EulerDiscreteScheduler.from_config(scheduler_config) - elif scheduler == "Euler A": - noise_scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) - elif scheduler == "PNDM": - noise_scheduler = PNDMScheduler.from_config(scheduler_config) - elif scheduler == "DDIM": - noise_scheduler = DDIMScheduler.from_config(scheduler_config) - elif scheduler == "CogVideoXDDIM": - noise_scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config) - elif scheduler == "CogVideoXDPMScheduler": - noise_scheduler = CogVideoXDPMScheduler.from_config(scheduler_config) - pipe.scheduler = noise_scheduler + if scheduler in scheduler_mapping: + noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config) + pipe.scheduler = noise_scheduler + else: + raise ValueError(f"Unknown scheduler: {scheduler}") #if not pipeline["cpu_offloading"]: # pipe.transformer.to(device) @@ -865,6 +886,10 @@ class CogVideoXFunVid2VidSampler: "DPM++", "PNDM", "DDIM", + "SASolverScheduler", + "UniPCMultistepScheduler", + "HeunDiscreteScheduler", + "DEISMultistepScheduler", "CogVideoXDDIM", "CogVideoXDPMScheduler", ], @@ -887,8 +912,12 @@ class CogVideoXFunVid2VidSampler: offload_device = mm.unet_offload_device() pipe = pipeline["pipe"] dtype = pipeline["dtype"] + base_path = pipeline["base_path"] - pipe.enable_model_cpu_offload(device=device) + assert "Fun" in base_path, "'Unfun' models not supported in 'CogVideoXFunSampler', use the 'CogVideoSampler'" + + if not pipeline["cpu_offloading"]: + pipe.enable_model_cpu_offload(device=device) mm.soft_empty_cache() @@ -902,21 +931,12 @@ class CogVideoXFunVid2VidSampler: base_path = pipeline["base_path"] # Load Sampler - if scheduler == "DPM++": - noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(base_path, subfolder= 'scheduler') - elif scheduler == "Euler": - noise_scheduler = EulerDiscreteScheduler.from_pretrained(base_path, subfolder= 'scheduler') - elif scheduler == "Euler A": - noise_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(base_path, subfolder= 'scheduler') - elif scheduler == "PNDM": - noise_scheduler = PNDMScheduler.from_pretrained(base_path, subfolder= 'scheduler') - elif scheduler == "DDIM": - noise_scheduler = DDIMScheduler.from_pretrained(base_path, subfolder= 'scheduler') - elif scheduler == "CogVideoXDDIM": - noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder= 'scheduler') - elif scheduler == "CogVideoXDPMScheduler": - noise_scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder= 'scheduler') - pipe.scheduler = noise_scheduler + scheduler_config = pipeline["scheduler_config"] + if scheduler in scheduler_mapping: + noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config) + pipe.scheduler = noise_scheduler + else: + raise ValueError(f"Unknown scheduler: {scheduler}") generator= torch.Generator(device).manual_seed(seed)