From 7e2709425ab1f4d76ea7c899f9910de6261bcc45 Mon Sep 17 00:00:00 2001 From: MrForExample <62230687+MrForExample@users.noreply.github.com> Date: Wed, 3 Jan 2024 11:43:45 +0100 Subject: [PATCH] Add InsertImageBatchByIndexes --- nodes.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index d062c8f..e0fe382 100644 --- a/nodes.py +++ b/nodes.py @@ -1821,11 +1821,13 @@ class FilterZeroMasksAndCorrespondingImages: "MASK", "IMAGE", "IMAGE", + "INDEXES" ) RETURN_NAMES = ( "non_zero_masks_out", "non_zero_mask_images_out", "zero_mask_images_out", + "zero_mask_images_out_indexes" ) FUNCTION = "filter" CATEGORY = "KJNodes/masking" @@ -1841,6 +1843,7 @@ class FilterZeroMasksAndCorrespondingImages: non_zero_masks = [] non_zero_mask_images = [] zero_mask_images = [] + zero_mask_images_indexes = [] masks_num = len(masks) also_process_images = False @@ -1859,16 +1862,75 @@ class FilterZeroMasksAndCorrespondingImages: non_zero_mask_images.append(original_images[i]) else: zero_mask_images.append(original_images[i]) + zero_mask_images_indexes.append(i) non_zero_masks_out = torch.stack(non_zero_masks, dim=0) + non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None if also_process_images: non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0) - zero_mask_images_out = torch.stack(zero_mask_images, dim=0) if len(zero_mask_images) > 0 else None - else: - non_zero_mask_images_out = zero_mask_images_out = None + if len(zero_mask_images) > 0: + zero_mask_images_out = torch.stack(zero_mask_images, dim=0) + zero_mask_images_out_indexes = zero_mask_images_indexes - return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out) + return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes) + +class InsertImageBatchByIndexes: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + "images_to_insert": ("IMAGE",), + "insert_indexes": ("INDEXES",), + }, + } + + RETURN_TYPES = ( + "IMAGE", + ) + RETURN_NAMES = ( + "images_after_insert", + ) + FUNCTION = "insert" + CATEGORY = "KJNodes" + + def insert(self, images, images_to_insert, insert_indexes): + """_summary_ + + Args: + images (_type_): _description_ + images_to_insert (_type_): _description_ + insert_indexes (_type_): _description_ + + Returns: + _type_: _description_ + """ + + images_after_insert = images + + if images_to_insert is not None and insert_indexes is not None: + images_to_insert_num = len(images_to_insert) + insert_indexes_num = len(insert_indexes) + if images_to_insert_num == insert_indexes_num: + images_after_insert = [] + + i_images = 0 + for i in range(len(images) + images_to_insert_num): + if i in insert_indexes: + images_after_insert.append(images_to_insert[insert_indexes.index(i)]) + else: + images_after_insert.append(images[i_images]) + i_images += 1 + + images_after_insert = torch.stack(images_after_insert, dim=0) + + else: + print(f"[WARNING] skip this node, due to number of images_to_insert ({images_to_insert_num}) is not equal to number of insert_indexes ({insert_indexes_num})") + + + return (images_after_insert, ) def bbox_to_region(bbox, target_size=None): bbox = bbox_check(bbox, target_size) @@ -3038,6 +3100,7 @@ NODE_CLASS_MAPPINGS = { "BatchCropFromMask": BatchCropFromMask, "BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced, "FilterZeroMasksAndCorrespondingImages": FilterZeroMasksAndCorrespondingImages, + "InsertImageBatchByIndexes": InsertImageBatchByIndexes, "BatchUncrop": BatchUncrop, "BatchUncropAdvanced": BatchUncropAdvanced, "BatchCLIPSeg": BatchCLIPSeg, @@ -3094,6 +3157,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "BatchCropFromMask": "BatchCropFromMask", "BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced", "FilterZeroMasksAndCorrespondingImages": "FilterZeroMasksAndCorrespondingImages", + "InsertImageBatchByIndexes": "InsertImageBatchByIndexes", "BatchUncrop": "BatchUncrop", "BatchUncropAdvanced": "BatchUncropAdvanced", "BatchCLIPSeg": "BatchCLIPSeg",