mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 04:44:30 +08:00
Add LatentInpaintTTM
Can be used to mimic: https://github.com/time-to-move/TTM
This commit is contained in:
parent
246920d8b9
commit
3b9c1b49ab
@ -209,6 +209,7 @@ NODE_CONFIG = {
|
|||||||
"ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
|
"ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
|
||||||
"WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
|
"WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
|
||||||
"GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"},
|
"GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"},
|
||||||
|
"LatentInpaintTTM": {"class": LatentInpaintTTM, "name": "Latent Inpaint TTM"},
|
||||||
|
|
||||||
#instance diffusion
|
#instance diffusion
|
||||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||||
|
|||||||
@ -2623,3 +2623,77 @@ class LazySwitchKJ:
|
|||||||
def switch(self, switch, on_false = None, on_true=None):
|
def switch(self, switch, on_false = None, on_true=None):
|
||||||
value = on_true if switch else on_false
|
value = on_true if switch else on_false
|
||||||
return (value,)
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
|
from comfy.patcher_extension import WrappersMP
|
||||||
|
from comfy.sampler_helpers import prepare_mask
|
||||||
|
class TTM_SampleWrapper:
|
||||||
|
def __init__(self, mask, end_step):
|
||||||
|
self.mask = mask
|
||||||
|
self.end_step = end_step
|
||||||
|
|
||||||
|
def __call__(self, sampler, guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar):
|
||||||
|
model_options = extra_args["model_options"]
|
||||||
|
wrappers = model_options["transformer_options"]["wrappers"]
|
||||||
|
w = wrappers.setdefault(WrappersMP.APPLY_MODEL, {})
|
||||||
|
|
||||||
|
if self.mask is not None:
|
||||||
|
motion_mask = self.mask.reshape((-1, 1, self.mask.shape[-2], self.mask.shape[-1]))
|
||||||
|
motion_mask = prepare_mask(motion_mask, noise.shape, noise.device)
|
||||||
|
|
||||||
|
scale_latent_inpaint = guider.model_patcher.model.scale_latent_inpaint
|
||||||
|
w["TTM_ApplyModel_Wrapper"] = [TTM_ApplyModel_Wrapper(latent_image, noise, motion_mask, self.end_step, scale_latent_inpaint)]
|
||||||
|
|
||||||
|
out = sampler(guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class TTM_ApplyModel_Wrapper:
|
||||||
|
def __init__(self, reference_samples, noise, motion_mask, end_step, scale_latent_inpaint):
|
||||||
|
self.reference_samples = reference_samples
|
||||||
|
self.noise = noise
|
||||||
|
self.motion_mask = motion_mask
|
||||||
|
self.end_step = end_step
|
||||||
|
self.scale_latent_inpaint = scale_latent_inpaint
|
||||||
|
|
||||||
|
def __call__(self, executor, x, t, c_concat, c_crossattn, control, transformer_options, **kwargs):
|
||||||
|
sigmas = transformer_options["sample_sigmas"]
|
||||||
|
|
||||||
|
matched = (sigmas == t).nonzero(as_tuple=True)[0]
|
||||||
|
if matched.numel() > 0:
|
||||||
|
current_step_index = matched.item()
|
||||||
|
else:
|
||||||
|
crossing = ((sigmas[:-1] - t) * (sigmas[1:] - t) <= 0).nonzero(as_tuple=True)[0]
|
||||||
|
current_step_index = crossing.item() if crossing.numel() > 0 else 0
|
||||||
|
|
||||||
|
next_sigma = sigmas[current_step_index + 1] if current_step_index < len(sigmas) - 1 else sigmas[current_step_index]
|
||||||
|
|
||||||
|
if current_step_index != 0 and current_step_index < self.end_step:
|
||||||
|
noisy_latent = self.scale_latent_inpaint(x=x, sigma=torch.tensor([next_sigma]), noise=self.noise.to(x), latent_image=self.reference_samples.to(x))
|
||||||
|
x = x * (1-self.motion_mask).to(x) + noisy_latent * self.motion_mask.to(x)
|
||||||
|
|
||||||
|
return executor(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LatentInpaintTTM:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"model": ("MODEL", ),
|
||||||
|
"end_step": ("INT", {"default": 7, "min": 0, "max": 888, "step": 1}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"mask": ("MASK", {"tooltip": "Latent mask where white (1.0) is the area to inpaint and black (0.0) is the area to keep unchanged."}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "https://github.com/time-to-move/TTM"
|
||||||
|
CATEGORY = "KJNodes/experimental"
|
||||||
|
|
||||||
|
def patch(self, model, end_step, mask=None):
|
||||||
|
m = model.clone()
|
||||||
|
m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, end_step))
|
||||||
|
return (m, )
|
||||||
Loading…
x
Reference in New Issue
Block a user