diff --git a/nodes.py b/nodes.py index 3d4b415..0822865 100644 --- a/nodes.py +++ b/nodes.py @@ -2524,16 +2524,20 @@ class FlipSigmasAdjusted: return (adjusted_sigmas,) + class InjectNoiseToLatent: @classmethod def INPUT_TYPES(s): return {"required": { "latents":("LATENT",), - "strength": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 200.0, "step": 0.001}), + "strength": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 200.0, "step": 0.0001}), "noise": ("LATENT",), "normalize": ("BOOLEAN", {"default": False}), "average": ("BOOLEAN", {"default": False}), }, + "optional":{ + "mask": ("MASK", ), + } } RETURN_TYPES = ("LATENT",) @@ -2541,7 +2545,7 @@ class InjectNoiseToLatent: CATEGORY = "KJNodes/noise" - def injectnoise(self, latents, strength, noise, normalize, average): + def injectnoise(self, latents, strength, noise, normalize, average, mask=None): samples = latents.copy() if latents["samples"].shape != noise["samples"].shape: raise ValueError("InjectNoiseToLatent: Latent and noise must have the same shape") @@ -2551,7 +2555,12 @@ class InjectNoiseToLatent: noised = samples["samples"].clone() + noise["samples"].clone() * strength if normalize: noised = noised / noised.std() - + if mask is not None: + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(noised.shape[2], noised.shape[3]), mode="bilinear") + mask = mask.expand((-1,noised.shape[1],-1,-1)) + if mask.shape[0] < noised.shape[0]: + mask = mask.repeat((noised.shape[0] -1) // mask.shape[0] + 1, 1, 1, 1)[:noised.shape[0]] + noised = mask * noised + (1-mask) * latents["samples"] samples["samples"] = noised return (samples,) @@ -2606,7 +2615,55 @@ class AddLabel: combined_images = torch.cat((label_batch, image), dim=1) return (combined_images,) - + +class ReferenceOnlySimple3: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "reference": ("LATENT",), + "reference2": ("LATENT",), + "input": ("LATENT",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}) + }} + RETURN_TYPES = ("MODEL", "LATENT") + FUNCTION = "reference_only" + + CATEGORY = "KJNodes/experiments" + + def reference_only(self, model, reference, reference2, input, batch_size): + model_reference = model.clone() + size_latent = list(reference["samples"].shape) + size_latent[0] = batch_size + latent = input + + batch = latent["samples"].shape[0] + reference["samples"].shape[0] + reference2["samples"].shape[0] + + def reference_apply(q, k, v, extra_options): + k = k.clone().repeat(1, 2, 1) + offset = 0 + if q.shape[0] > batch: + offset = batch + + re = extra_options["transformer_index"] % 2 + + for o in range(0, q.shape[0], batch): + for x in range(1, batch): + k[x + o, q.shape[1]:] = q[o + re,:] + return q, k, k + + model_reference.set_model_attn1_patch(reference_apply) + + out_latent = torch.cat((reference["samples"], reference2["samples"], latent["samples"])) + if "noise_mask" in latent: + mask = latent["noise_mask"] + else: + mask = torch.ones((64,64), dtype=torch.float32, device="cpu") + mask = mask.repeat(latent["samples"].shape[0], 1, 1) + + out_mask = torch.zeros((1,mask.shape[1],mask.shape[2]), dtype=torch.float32, device="cpu") + return (model_reference, {"samples": out_latent, "noise_mask": torch.cat((out_mask,out_mask, mask))}) + + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -2655,7 +2712,8 @@ NODE_CLASS_MAPPINGS = { "NormalizeLatent": NormalizeLatent, "FlipSigmasAdjusted": FlipSigmasAdjusted, "InjectNoiseToLatent": InjectNoiseToLatent, - "AddLabel": AddLabel + "AddLabel": AddLabel, + "ReferenceOnlySimple3": ReferenceOnlySimple3 } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -2704,6 +2762,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "NormalizeLatent": "NormalizeLatent", "FlipSigmasAdjusted": "FlipSigmasAdjusted", "InjectNoiseToLatent": "InjectNoiseToLatent", - "AddLabel": "AddLabel" + "AddLabel": "AddLabel", + "ReferenceOnlySimple3": "ReferenceOnlySimple3" } \ No newline at end of file