diff --git a/nodes.py b/nodes.py index 1eeaa40..2ac8d9d 100644 --- a/nodes.py +++ b/nodes.py @@ -356,6 +356,7 @@ class MochiSampler: }, "optional": { "cfg_schedule": ("FLOAT", {"forceInput": True,}), + "opt_sigmas": ("SIGMAS",), } } @@ -364,11 +365,15 @@ class MochiSampler: FUNCTION = "process" CATEGORY = "MochiWrapper" - def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None): + def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None): mm.soft_empty_cache() - device = mm.get_torch_device() - offload_device = mm.unet_offload_device() + if opt_sigmas is not None: + sigma_schedule = opt_sigmas.tolist() + logging.info(f"Using sigmas: {sigma_schedule}") + else: + sigma_schedule = linear_quadratic_schedule(steps, 0.025) + logging.info(f"Using sigmas: {sigma_schedule}") cfg_schedule = cfg_schedule or [cfg] * steps logging.info(f"Using cfg schedule: {cfg_schedule}") @@ -378,7 +383,7 @@ class MochiSampler: "width": width, "num_frames": num_frames, "mochi_args": { - "sigma_schedule": linear_quadratic_schedule(steps, 0.025), + "sigma_schedule": sigma_schedule, "cfg_schedule": cfg_schedule, "num_inference_steps": steps, "batch_cfg": False,