diff --git a/__init__.py b/__init__.py index 02157b7..2f2ebb5 100644 --- a/__init__.py +++ b/__init__.py @@ -209,6 +209,7 @@ NODE_CONFIG = { "ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"}, "WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"}, "GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"}, + "LatentInpaintTTM": {"class": LatentInpaintTTM, "name": "Latent Inpaint TTM"}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, diff --git a/nodes/nodes.py b/nodes/nodes.py index 9caec1c..d04053c 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -2623,3 +2623,77 @@ class LazySwitchKJ: def switch(self, switch, on_false = None, on_true=None): value = on_true if switch else on_false return (value,) + + +from comfy.patcher_extension import WrappersMP +from comfy.sampler_helpers import prepare_mask +class TTM_SampleWrapper: + def __init__(self, mask, end_step): + self.mask = mask + self.end_step = end_step + + def __call__(self, sampler, guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar): + model_options = extra_args["model_options"] + wrappers = model_options["transformer_options"]["wrappers"] + w = wrappers.setdefault(WrappersMP.APPLY_MODEL, {}) + + if self.mask is not None: + motion_mask = self.mask.reshape((-1, 1, self.mask.shape[-2], self.mask.shape[-1])) + motion_mask = prepare_mask(motion_mask, noise.shape, noise.device) + + scale_latent_inpaint = guider.model_patcher.model.scale_latent_inpaint + w["TTM_ApplyModel_Wrapper"] = [TTM_ApplyModel_Wrapper(latent_image, noise, motion_mask, self.end_step, scale_latent_inpaint)] + + out = sampler(guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) + + return out + + +class TTM_ApplyModel_Wrapper: + def __init__(self, reference_samples, noise, motion_mask, end_step, scale_latent_inpaint): + self.reference_samples = reference_samples + self.noise = noise + self.motion_mask = motion_mask + self.end_step = end_step + self.scale_latent_inpaint = scale_latent_inpaint + + def __call__(self, executor, x, t, c_concat, c_crossattn, control, transformer_options, **kwargs): + sigmas = transformer_options["sample_sigmas"] + + matched = (sigmas == t).nonzero(as_tuple=True)[0] + if matched.numel() > 0: + current_step_index = matched.item() + else: + crossing = ((sigmas[:-1] - t) * (sigmas[1:] - t) <= 0).nonzero(as_tuple=True)[0] + current_step_index = crossing.item() if crossing.numel() > 0 else 0 + + next_sigma = sigmas[current_step_index + 1] if current_step_index < len(sigmas) - 1 else sigmas[current_step_index] + + if current_step_index != 0 and current_step_index < self.end_step: + noisy_latent = self.scale_latent_inpaint(x=x, sigma=torch.tensor([next_sigma]), noise=self.noise.to(x), latent_image=self.reference_samples.to(x)) + x = x * (1-self.motion_mask).to(x) + noisy_latent * self.motion_mask.to(x) + + return executor(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + +class LatentInpaintTTM: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": ("MODEL", ), + "end_step": ("INT", {"default": 7, "min": 0, "max": 888, "step": 1}), + }, + "optional": { + "mask": ("MASK", {"tooltip": "Latent mask where white (1.0) is the area to inpaint and black (0.0) is the area to keep unchanged."}), + } + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + EXPERIMENTAL = True + DESCRIPTION = "https://github.com/time-to-move/TTM" + CATEGORY = "KJNodes/experimental" + + def patch(self, model, end_step, mask=None): + m = model.clone() + m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, end_step)) + return (m, ) \ No newline at end of file