diff --git a/nodes.py b/nodes.py index 7c4a875..219129a 100644 --- a/nodes.py +++ b/nodes.py @@ -647,7 +647,44 @@ Selects and returns the images at the specified indices as an image batch. chosen_images = images[indices_tensor] return (chosen_images,) + +class InsertImagesToBatchIndexed: + RETURN_TYPES = ("IMAGE",) + FUNCTION = "insertimagesfrombatch" + CATEGORY = "KJNodes/image" + DESCRIPTION = """ +Inserts images at the specified indices into the original image batch. +""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "original_images": ("IMAGE",), + "images_to_insert": ("IMAGE",), + "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}), + }, + } + + def insertimagesfrombatch(self, original_images, images_to_insert, indexes): + + # Parse the indexes string into a list of integers + index_list = [int(index.strip()) for index in indexes.split(',')] + + # Convert list of indices to a PyTorch tensor + indices_tensor = torch.tensor(index_list, dtype=torch.long) + + # Ensure the images_to_insert is a tensor + if not isinstance(images_to_insert, torch.Tensor): + images_to_insert = torch.tensor(images_to_insert) + + # Insert the images at the specified indices + for index, image in zip(indices_tensor, images_to_insert): + original_images[index] = image + + return (original_images,) + class GetLatentsFromBatchIndexed: RETURN_TYPES = ("LATENT",) @@ -1674,7 +1711,7 @@ class ImageBatchTestPattern: def INPUT_TYPES(s): return {"required": { "batch_size": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}), - "start_from": ("INT", {"default": 1,"min": 1, "max": 255, "step": 1}), + "start_from": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}), "text_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}), "text_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}), "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), @@ -4522,6 +4559,7 @@ NODE_CLASS_MAPPINGS = { "StableZero123_BatchSchedule": StableZero123_BatchSchedule, "SV3D_BatchSchedule": SV3D_BatchSchedule, "GetImagesFromBatchIndexed": GetImagesFromBatchIndexed, + "InsertImagesToBatchIndexed": InsertImagesToBatchIndexed, "ImageBatchRepeatInterleaving": ImageBatchRepeatInterleaving, "NormalizedAmplitudeToMask": NormalizedAmplitudeToMask, "OffsetMaskByNormalizedAmplitude": OffsetMaskByNormalizedAmplitude, @@ -4567,6 +4605,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "EmptyLatentImagePresets": "EmptyLatentImagePresets", "ColorMatch": "ColorMatch", "GetImageRangeFromBatch": "GetImageRangeFromBatch", + "InsertImagesToBatchIndexed": "InsertImagesToBatchIndexed", "SaveImageWithAlpha": "SaveImageWithAlpha", "ReverseImageBatch": "ReverseImageBatch", "ImageGridComposite2x2": "ImageGridComposite2x2",