Add LatentInpaintTTM

Can be used to mimic:
https://github.com/time-to-move/TTM
kijai committed 2025-11-23 01:47:56 +02:00
parent 246920d8b9, commit 3b9c1b49ab
2 changed files with 75 additions and 0 deletions


@@ -209,6 +209,7 @@ NODE_CONFIG = {
    "ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
    "WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
    "GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"},
    "LatentInpaintTTM": {"class": LatentInpaintTTM, "name": "Latent Inpaint TTM"},
    #instance diffusion
    "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},


@@ -2623,3 +2623,77 @@ class LazySwitchKJ:
    def switch(self, switch, on_false = None, on_true=None):
        value = on_true if switch else on_false
        return (value,)
from comfy.patcher_extension import WrappersMP
from comfy.sampler_helpers import prepare_mask

class TTM_SampleWrapper:
    # Wraps the sampler call (SAMPLER_SAMPLE) to install the TTM apply-model
    # wrapper with the prepared motion mask before sampling starts.
    def __init__(self, mask, end_step):
        self.mask = mask
        self.end_step = end_step

    def __call__(self, sampler, guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar):
        model_options = extra_args["model_options"]
        wrappers = model_options["transformer_options"]["wrappers"]
        w = wrappers.setdefault(WrappersMP.APPLY_MODEL, {})
        if self.mask is not None:
            # (frames, H, W) MASK -> (frames, 1, H, W), then broadcast/resize
            # to the latent shape so it can gate the latents elementwise.
            motion_mask = self.mask.reshape((-1, 1, self.mask.shape[-2], self.mask.shape[-1]))
            motion_mask = prepare_mask(motion_mask, noise.shape, noise.device)
            scale_latent_inpaint = guider.model_patcher.model.scale_latent_inpaint
            w["TTM_ApplyModel_Wrapper"] = [TTM_ApplyModel_Wrapper(latent_image, noise, motion_mask, self.end_step, scale_latent_inpaint)]
        out = sampler(guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
        return out
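
To make the mask plumbing concrete, a standalone sketch with plain torch and made-up shapes (prepare_mask then broadcasts the reshaped mask to the latent shape):

import torch

mask = torch.ones(16, 60, 104)  # MASK input: frames x H x W
motion_mask = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
print(motion_mask.shape)  # torch.Size([16, 1, 60, 104])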
class TTM_ApplyModel_Wrapper:
    # Wraps each model forward (APPLY_MODEL): during the early steps it re-noises
    # the reference latents to the upcoming noise level and pastes them into the
    # masked region, steering motion there while leaving the rest untouched.
    def __init__(self, reference_samples, noise, motion_mask, end_step, scale_latent_inpaint):
        self.reference_samples = reference_samples
        self.noise = noise
        self.motion_mask = motion_mask
        self.end_step = end_step
        self.scale_latent_inpaint = scale_latent_inpaint

    def __call__(self, executor, x, t, c_concat, c_crossattn, control, transformer_options, **kwargs):
        sigmas = transformer_options["sample_sigmas"]
        # Find the current step: exact sigma match first, otherwise the interval
        # of the schedule that brackets t (signed distances change sign there).
        matched = (sigmas == t).nonzero(as_tuple=True)[0]
        if matched.numel() > 0:
            current_step_index = matched.item()
        else:
            crossing = ((sigmas[:-1] - t) * (sigmas[1:] - t) <= 0).nonzero(as_tuple=True)[0]
            current_step_index = crossing.item() if crossing.numel() > 0 else 0
        next_sigma = sigmas[current_step_index + 1] if current_step_index < len(sigmas) - 1 else sigmas[current_step_index]
        if current_step_index != 0 and current_step_index < self.end_step:
            noisy_latent = self.scale_latent_inpaint(x=x, sigma=torch.tensor([next_sigma]), noise=self.noise.to(x), latent_image=self.reference_samples.to(x))
            x = x * (1 - self.motion_mask).to(x) + noisy_latent * self.motion_mask.to(x)
        return executor(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
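
The step lookup can be checked in isolation; a toy run with a made-up schedule, where the product of signed distances is <= 0 exactly where the schedule crosses t:

import torch

sigmas = torch.tensor([1.0, 0.7, 0.4, 0.1, 0.0])
t = torch.tensor(0.55)  # lies between sigmas[1] and sigmas[2]
crossing = ((sigmas[:-1] - t) * (sigmas[1:] - t) <= 0).nonzero(as_tuple=True)[0]
print(crossing.item())  # 1 -> current_step_index = 1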
class LatentInpaintTTM:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "end_step": ("INT", {"default": 7, "min": 0, "max": 888, "step": 1}),
            },
            "optional": {
                "mask": ("MASK", {"tooltip": "Latent mask where white (1.0) is the area to inpaint and black (0.0) is the area to keep unchanged."}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    EXPERIMENTAL = True
    DESCRIPTION = "https://github.com/time-to-move/TTM"
    CATEGORY = "KJNodes/experimental"

    def patch(self, model, end_step, mask=None):
        # Clone so the original model is untouched, then attach the sampler-level
        # wrapper that drives the TTM latent injection during sampling.
        m = model.clone()
        m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, end_step))
        return (m,)
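
In a workflow the node only patches the model; a hedged sketch of the equivalent direct call, assuming model is an existing ModelPatcher and mask a MASK tensor:

node = LatentInpaintTTM()
(patched_model,) = node.patch(model, end_step=7, mask=mask)  # mask=None leaves sampling unchanged
# patched_model now carries the SAMPLER_SAMPLE wrapper; pass it to any sampler as usual.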