diff --git a/__init__.py b/__init__.py index 7aa99e5..34366e1 100644 --- a/__init__.py +++ b/__init__.py @@ -134,6 +134,7 @@ NODE_CONFIG = { "CheckpointPerturbWeights": {"class": CheckpointPerturbWeights, "name": "CheckpointPerturbWeights"}, "Screencap_mss": {"class": Screencap_mss, "name": "Screencap mss"}, "WebcamCaptureCV2": {"class": WebcamCaptureCV2, "name": "Webcam Capture CV2"}, + "DifferentialDiffusionAdvanced": {"class": DifferentialDiffusionAdvanced, "name": "Differential Diffusion Advanced"}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, diff --git a/nodes/nodes.py b/nodes/nodes.py index 41dbb37..34ac57a 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1756,4 +1756,37 @@ class CheckpointPerturbWeights: dict[k] += torch.normal(torch.zeros_like(v) * v.mean(), torch.ones_like(v) * v.std() * multiplier).to(device) pbar.update(1) model_copy.model.diffusion_model.load_state_dict(dict) - return model_copy, \ No newline at end of file + return model_copy, + +class DifferentialDiffusionAdvanced(): + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply" + CATEGORY = "_for_testing" + INIT = False + + def apply(self, model, multiplier): + self.multiplier = multiplier + model = model.clone() + model.set_model_denoise_mask_function(self.forward) + return (model,) + + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + model = extra_options["model"] + step_sigmas = extra_options["sigmas"] + sigma_to = model.inner_model.model_sampling.sigma_min + if step_sigmas[-1] > sigma_to: + sigma_to = step_sigmas[-1] + sigma_from = step_sigmas[0] + + ts_from = model.inner_model.model_sampling.timestep(sigma_from) + ts_to = model.inner_model.model_sampling.timestep(sigma_to) + current_ts = model.inner_model.model_sampling.timestep(sigma[0]) + + threshold = (current_ts - ts_to) / (ts_from - ts_to) / self.multiplier + + return (denoise_mask >= threshold).to(denoise_mask.dtype) \ No newline at end of file