mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00
Add LatentInpaintTTM
Can be used to mimic: https://github.com/time-to-move/TTM
This commit is contained in:
parent
246920d8b9
commit
3b9c1b49ab
@ -209,6 +209,7 @@ NODE_CONFIG = {
|
||||
"ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
|
||||
"WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
|
||||
"GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"},
|
||||
"LatentInpaintTTM": {"class": LatentInpaintTTM, "name": "Latent Inpaint TTM"},
|
||||
|
||||
#instance diffusion
|
||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||
|
||||
@ -2623,3 +2623,77 @@ class LazySwitchKJ:
|
||||
def switch(self, switch, on_false = None, on_true=None):
|
||||
value = on_true if switch else on_false
|
||||
return (value,)
|
||||
|
||||
|
||||
from comfy.patcher_extension import WrappersMP
|
||||
from comfy.sampler_helpers import prepare_mask
|
||||
class TTM_SampleWrapper:
|
||||
def __init__(self, mask, end_step):
|
||||
self.mask = mask
|
||||
self.end_step = end_step
|
||||
|
||||
def __call__(self, sampler, guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar):
|
||||
model_options = extra_args["model_options"]
|
||||
wrappers = model_options["transformer_options"]["wrappers"]
|
||||
w = wrappers.setdefault(WrappersMP.APPLY_MODEL, {})
|
||||
|
||||
if self.mask is not None:
|
||||
motion_mask = self.mask.reshape((-1, 1, self.mask.shape[-2], self.mask.shape[-1]))
|
||||
motion_mask = prepare_mask(motion_mask, noise.shape, noise.device)
|
||||
|
||||
scale_latent_inpaint = guider.model_patcher.model.scale_latent_inpaint
|
||||
w["TTM_ApplyModel_Wrapper"] = [TTM_ApplyModel_Wrapper(latent_image, noise, motion_mask, self.end_step, scale_latent_inpaint)]
|
||||
|
||||
out = sampler(guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TTM_ApplyModel_Wrapper:
|
||||
def __init__(self, reference_samples, noise, motion_mask, end_step, scale_latent_inpaint):
|
||||
self.reference_samples = reference_samples
|
||||
self.noise = noise
|
||||
self.motion_mask = motion_mask
|
||||
self.end_step = end_step
|
||||
self.scale_latent_inpaint = scale_latent_inpaint
|
||||
|
||||
def __call__(self, executor, x, t, c_concat, c_crossattn, control, transformer_options, **kwargs):
|
||||
sigmas = transformer_options["sample_sigmas"]
|
||||
|
||||
matched = (sigmas == t).nonzero(as_tuple=True)[0]
|
||||
if matched.numel() > 0:
|
||||
current_step_index = matched.item()
|
||||
else:
|
||||
crossing = ((sigmas[:-1] - t) * (sigmas[1:] - t) <= 0).nonzero(as_tuple=True)[0]
|
||||
current_step_index = crossing.item() if crossing.numel() > 0 else 0
|
||||
|
||||
next_sigma = sigmas[current_step_index + 1] if current_step_index < len(sigmas) - 1 else sigmas[current_step_index]
|
||||
|
||||
if current_step_index != 0 and current_step_index < self.end_step:
|
||||
noisy_latent = self.scale_latent_inpaint(x=x, sigma=torch.tensor([next_sigma]), noise=self.noise.to(x), latent_image=self.reference_samples.to(x))
|
||||
x = x * (1-self.motion_mask).to(x) + noisy_latent * self.motion_mask.to(x)
|
||||
|
||||
return executor(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
|
||||
class LatentInpaintTTM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model": ("MODEL", ),
|
||||
"end_step": ("INT", {"default": 7, "min": 0, "max": 888, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK", {"tooltip": "Latent mask where white (1.0) is the area to inpaint and black (0.0) is the area to keep unchanged."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "https://github.com/time-to-move/TTM"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch(self, model, end_step, mask=None):
|
||||
m = model.clone()
|
||||
m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, end_step))
|
||||
return (m, )
|
||||
Loading…
x
Reference in New Issue
Block a user