custom sigmas

This commit is contained in:
kijai 2024-10-27 12:41:33 +02:00
parent 3613700752
commit 185f4e0bee

View File

@ -356,6 +356,7 @@ class MochiSampler:
}, },
"optional": { "optional": {
"cfg_schedule": ("FLOAT", {"forceInput": True,}), "cfg_schedule": ("FLOAT", {"forceInput": True,}),
"opt_sigmas": ("SIGMAS",),
} }
} }
@ -364,11 +365,15 @@ class MochiSampler:
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "MochiWrapper" CATEGORY = "MochiWrapper"
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None): def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None):
mm.soft_empty_cache() mm.soft_empty_cache()
device = mm.get_torch_device() if opt_sigmas is not None:
offload_device = mm.unet_offload_device() sigma_schedule = opt_sigmas.tolist()
logging.info(f"Using sigmas: {sigma_schedule}")
else:
sigma_schedule = linear_quadratic_schedule(steps, 0.025)
logging.info(f"Using sigmas: {sigma_schedule}")
cfg_schedule = cfg_schedule or [cfg] * steps cfg_schedule = cfg_schedule or [cfg] * steps
logging.info(f"Using cfg schedule: {cfg_schedule}") logging.info(f"Using cfg schedule: {cfg_schedule}")
@ -378,7 +383,7 @@ class MochiSampler:
"width": width, "width": width,
"num_frames": num_frames, "num_frames": num_frames,
"mochi_args": { "mochi_args": {
"sigma_schedule": linear_quadratic_schedule(steps, 0.025), "sigma_schedule": sigma_schedule,
"cfg_schedule": cfg_schedule, "cfg_schedule": cfg_schedule,
"num_inference_steps": steps, "num_inference_steps": steps,
"batch_cfg": False, "batch_cfg": False,