From be8e44db3937e0b1729f30c7a1e75efa62629992 Mon Sep 17 00:00:00 2001 From: Kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 29 Jan 2024 19:07:53 +0200 Subject: [PATCH] Update nodes.py --- nodes.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 10 deletions(-) diff --git a/nodes.py b/nodes.py index 01e4cec..34750ba 100644 --- a/nodes.py +++ b/nodes.py @@ -34,7 +34,7 @@ class INTConstant: RETURN_NAMES = ("value",) FUNCTION = "get_value" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/constants" def get_value(self, value): return (value,) @@ -51,16 +51,27 @@ class FloatConstant: RETURN_NAMES = ("value",) FUNCTION = "get_value" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/constants" def get_value(self, value): return (value,) -def gaussian_kernel(kernel_size: int, sigma: float, device=None): - x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij") - d = torch.sqrt(x * x + y * y) - g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) - return g / g.sum() +class StringConstant: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "string": ("STRING", {"default": '', "multiline": False}), + } + } + RETURN_TYPES = ("STRING",) + FUNCTION = "passtring" + + CATEGORY = "KJNodes/constants" + + def passtring(self, string): + return (string, ) class CreateFluidMask: @@ -526,6 +537,38 @@ class GetImagesFromBatchIndexed: return (chosen_images,) +class GetLatentsFromBatchIndexed: + + RETURN_TYPES = ("LATENT",) + FUNCTION = "indexedlatentsfrombatch" + CATEGORY = "KJNodes" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "latents": ("LATENT",), + "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}), + }, + } + + def indexedlatentsfrombatch(self, latents, indexes): + + samples = latents.copy() + latent_samples = samples["samples"] + + # Parse the indexes string into a list of integers + index_list = [int(index.strip()) for index in indexes.split(',')] + + # Convert list of indices to a PyTorch tensor + indices_tensor = torch.tensor(index_list, dtype=torch.long) + + # Select the latents at the specified indices + chosen_latents = latent_samples[indices_tensor] + + samples["samples"] = chosen_latents + return (samples,) + class ReplaceImagesInBatch: RETURN_TYPES = ("IMAGE",) @@ -3271,7 +3314,7 @@ class ImageTransformByNormalizedAmplitude: return (transformed_batch,) - + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -3331,7 +3374,9 @@ NODE_CLASS_MAPPINGS = { "ImageBatchRepeatInterleaving": ImageBatchRepeatInterleaving, "NormalizedAmplitudeToMask": NormalizedAmplitudeToMask, "OffsetMaskByNormalizedAmplitude": OffsetMaskByNormalizedAmplitude, - "ImageTransformByNormalizedAmplitude": ImageTransformByNormalizedAmplitude + "ImageTransformByNormalizedAmplitude": ImageTransformByNormalizedAmplitude, + "GetLatentsFromBatchIndexed": GetLatentsFromBatchIndexed, + "StringConstant": StringConstant } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -3391,5 +3436,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageBatchRepeatInterleaving": "ImageBatchRepeatInterleaving", "NormalizedAmplitudeToMask": "NormalizedAmplitudeToMask", "OffsetMaskByNormalizedAmplitude": "OffsetMaskByNormalizedAmplitude", - "ImageTransformByNormalizedAmplitude": "ImageTransformByNormalizedAmplitude" + "ImageTransformByNormalizedAmplitude": "ImageTransformByNormalizedAmplitude", + "GetLatentsFromBatchIndexed": "GetLatentsFromBatchIndexed", + "StringConstant": "StringConstant" } \ No newline at end of file