Add more schedulers for "fun" model

This commit is contained in:
kijai 2024-09-22 18:12:50 +03:00
parent d3d7f043cd
commit 3c8e939f8e

View File

@ -3,7 +3,34 @@ import torch
import folder_paths import folder_paths
import comfy.model_management as mm import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file from comfy.utils import ProgressBar, load_torch_file
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler from diffusers.schedulers import (
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
DDIMScheduler,
PNDMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
UniPCMultistepScheduler,
HeunDiscreteScheduler,
SASolverScheduler,
DEISMultistepScheduler,
DDIMInverseScheduler
)
scheduler_mapping = {
"DPM++": DPMSolverMultistepScheduler,
"Euler": EulerDiscreteScheduler,
"Euler A": EulerAncestralDiscreteScheduler,
"PNDM": PNDMScheduler,
"DDIM": DDIMScheduler,
"CogVideoXDDIM": CogVideoXDDIMScheduler,
"CogVideoXDPMScheduler": CogVideoXDPMScheduler,
"SASolverScheduler": SASolverScheduler,
"UniPCMultistepScheduler": UniPCMultistepScheduler,
"HeunDiscreteScheduler": HeunDiscreteScheduler,
"DEISMultistepScheduler": DEISMultistepScheduler
}
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from .pipeline_cogvideox import CogVideoXPipeline from .pipeline_cogvideox import CogVideoXPipeline
@ -737,6 +764,10 @@ class CogVideoXFunSampler:
"DPM++", "DPM++",
"PNDM", "PNDM",
"DDIM", "DDIM",
"SASolverScheduler",
"UniPCMultistepScheduler",
"HeunDiscreteScheduler",
"DEISMultistepScheduler",
"CogVideoXDDIM", "CogVideoXDDIM",
"CogVideoXDPMScheduler", "CogVideoXDPMScheduler",
], ],
@ -787,21 +818,11 @@ class CogVideoXFunSampler:
# Load Sampler # Load Sampler
scheduler_config = pipeline["scheduler_config"] scheduler_config = pipeline["scheduler_config"]
if scheduler == "DPM++": if scheduler in scheduler_mapping:
noise_scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
elif scheduler == "Euler":
noise_scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
elif scheduler == "Euler A":
noise_scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
elif scheduler == "PNDM":
noise_scheduler = PNDMScheduler.from_config(scheduler_config)
elif scheduler == "DDIM":
noise_scheduler = DDIMScheduler.from_config(scheduler_config)
elif scheduler == "CogVideoXDDIM":
noise_scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
elif scheduler == "CogVideoXDPMScheduler":
noise_scheduler = CogVideoXDPMScheduler.from_config(scheduler_config)
pipe.scheduler = noise_scheduler pipe.scheduler = noise_scheduler
else:
raise ValueError(f"Unknown scheduler: {scheduler}")
#if not pipeline["cpu_offloading"]: #if not pipeline["cpu_offloading"]:
# pipe.transformer.to(device) # pipe.transformer.to(device)
@ -865,6 +886,10 @@ class CogVideoXFunVid2VidSampler:
"DPM++", "DPM++",
"PNDM", "PNDM",
"DDIM", "DDIM",
"SASolverScheduler",
"UniPCMultistepScheduler",
"HeunDiscreteScheduler",
"DEISMultistepScheduler",
"CogVideoXDDIM", "CogVideoXDDIM",
"CogVideoXDPMScheduler", "CogVideoXDPMScheduler",
], ],
@ -887,7 +912,11 @@ class CogVideoXFunVid2VidSampler:
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
pipe = pipeline["pipe"] pipe = pipeline["pipe"]
dtype = pipeline["dtype"] dtype = pipeline["dtype"]
base_path = pipeline["base_path"]
assert "Fun" in base_path, "'Unfun' models not supported in 'CogVideoXFunSampler', use the 'CogVideoSampler'"
if not pipeline["cpu_offloading"]:
pipe.enable_model_cpu_offload(device=device) pipe.enable_model_cpu_offload(device=device)
mm.soft_empty_cache() mm.soft_empty_cache()
@ -902,21 +931,12 @@ class CogVideoXFunVid2VidSampler:
base_path = pipeline["base_path"] base_path = pipeline["base_path"]
# Load Sampler # Load Sampler
if scheduler == "DPM++": scheduler_config = pipeline["scheduler_config"]
noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(base_path, subfolder= 'scheduler') if scheduler in scheduler_mapping:
elif scheduler == "Euler": noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
noise_scheduler = EulerDiscreteScheduler.from_pretrained(base_path, subfolder= 'scheduler')
elif scheduler == "Euler A":
noise_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(base_path, subfolder= 'scheduler')
elif scheduler == "PNDM":
noise_scheduler = PNDMScheduler.from_pretrained(base_path, subfolder= 'scheduler')
elif scheduler == "DDIM":
noise_scheduler = DDIMScheduler.from_pretrained(base_path, subfolder= 'scheduler')
elif scheduler == "CogVideoXDDIM":
noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder= 'scheduler')
elif scheduler == "CogVideoXDPMScheduler":
noise_scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder= 'scheduler')
pipe.scheduler = noise_scheduler pipe.scheduler = noise_scheduler
else:
raise ValueError(f"Unknown scheduler: {scheduler}")
generator= torch.Generator(device).manual_seed(seed) generator= torch.Generator(device).manual_seed(seed)