diff --git a/nodes/nodes.py b/nodes/nodes.py index 34ac57a..63efbea 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1761,19 +1761,24 @@ class CheckpointPerturbWeights: class DifferentialDiffusionAdvanced(): @classmethod def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001}), + return {"required": { + "model": ("MODEL", ), + "samples": ("LATENT",), + "mask": ("MASK",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001}), }} - RETURN_TYPES = ("MODEL",) + RETURN_TYPES = ("MODEL", "LATENT") FUNCTION = "apply" CATEGORY = "_for_testing" INIT = False - def apply(self, model, multiplier): + def apply(self, model, samples, mask, multiplier): self.multiplier = multiplier model = model.clone() model.set_model_denoise_mask_function(self.forward) - return (model,) + s = samples.copy() + s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) + return (model, s) def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): model = extra_options["model"]