From 2830476fd4a302afe1f8bd9b2ff839021fec0c3b Mon Sep 17 00:00:00 2001 From: MrForExample <62230687+MrForExample@users.noreply.github.com> Date: Wed, 27 Dec 2023 19:24:28 +0100 Subject: [PATCH 1/4] Add new node: FilterZeroMasksAndCorrespondingImages Filter out all the empty (i.e. all zero) mask in masks Also filter out all the corresponding images in original_images by indexes if provide --- nodes.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/nodes.py b/nodes.py index 5bb7995..d698120 100644 --- a/nodes.py +++ b/nodes.py @@ -1804,6 +1804,72 @@ class BatchCropFromMaskAdvanced: return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size) +class FilterZeroMasksAndCorrespondingImages: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "masks": ("MASK",), + }, + "optional": { + "original_images": ("IMAGE",), + }, + } + + RETURN_TYPES = ( + "MASK", + "IMAGE", + "IMAGE", + ) + RETURN_NAMES = ( + "non_zero_masks_out", + "non_zero_mask_images_out", + "zero_mask_images_out", + ) + FUNCTION = "filter" + CATEGORY = "KJNodes/masking" + + def filter(self, masks, original_images=None): + """ + Filter out all the empty (i.e. all zero) mask in masks + Also filter out all the corresponding images in original_images by indexes if provide + + Args: + original_images (optional): If provide, it need have same length as masks. + """ + non_zero_masks = [] + non_zero_mask_images = [] + zero_mask_images = [] + + masks_num = len(masks) + also_process_images = False + if original_images is not None: + imgs_num = len(original_images) + if len(original_images) == masks_num: + also_process_images = True + else: + print(f"[WARNING] ignore input: original_images, due to number of original_images ({imgs_num}) is not equal to number of masks ({masks_num})") + + for i in range(masks_num): + non_zero_num = np.count_nonzero(np.array(masks[i])) + if non_zero_num > 0: + non_zero_masks.append(masks[i]) + if also_process_images: + non_zero_mask_images.append(original_images[i]) + else: + zero_mask_images.append(original_images[i]) + + non_zero_masks_out = torch.stack(non_zero_masks, dim=0) + + if also_process_images: + non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0) + zero_mask_images_out = torch.stack(zero_mask_images, dim=0) + else: + non_zero_mask_images_out = zero_mask_images_out = None + + return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out) + def bbox_to_region(bbox, target_size=None): bbox = bbox_check(bbox, target_size) return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]) @@ -2971,6 +3037,7 @@ NODE_CLASS_MAPPINGS = { "ReplaceImagesInBatch": ReplaceImagesInBatch, "BatchCropFromMask": BatchCropFromMask, "BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced, + "FilterZeroMasksAndCorrespondingImages": FilterZeroMasksAndCorrespondingImages, "BatchUncrop": BatchUncrop, "BatchUncropAdvanced": BatchUncropAdvanced, "BatchCLIPSeg": BatchCLIPSeg, @@ -3026,6 +3093,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ReplaceImagesInBatch": "ReplaceImagesInBatch", "BatchCropFromMask": "BatchCropFromMask", "BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced", + "FilterZeroMasksAndCorrespondingImages": "FilterZeroMasksAndCorrespondingImages", "BatchUncrop": "BatchUncrop", "BatchUncropAdvanced": "BatchUncropAdvanced", "BatchCLIPSeg": "BatchCLIPSeg", From ccd89290ed387c3b639ef7f6a9d4295be0a97165 Mon Sep 17 00:00:00 2001 From: MrForExample <62230687+MrForExample@users.noreply.github.com> Date: Tue, 2 Jan 2024 20:27:46 +0100 Subject: [PATCH 2/4] Node FilterZeroMasksAndCorrespondingImages Edge case handle --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index d698120..d062c8f 100644 --- a/nodes.py +++ b/nodes.py @@ -1864,7 +1864,7 @@ class FilterZeroMasksAndCorrespondingImages: if also_process_images: non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0) - zero_mask_images_out = torch.stack(zero_mask_images, dim=0) + zero_mask_images_out = torch.stack(zero_mask_images, dim=0) if len(zero_mask_images) > 0 else None else: non_zero_mask_images_out = zero_mask_images_out = None From 7e2709425ab1f4d76ea7c899f9910de6261bcc45 Mon Sep 17 00:00:00 2001 From: MrForExample <62230687+MrForExample@users.noreply.github.com> Date: Wed, 3 Jan 2024 11:43:45 +0100 Subject: [PATCH 3/4] Add InsertImageBatchByIndexes --- nodes.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index d062c8f..e0fe382 100644 --- a/nodes.py +++ b/nodes.py @@ -1821,11 +1821,13 @@ class FilterZeroMasksAndCorrespondingImages: "MASK", "IMAGE", "IMAGE", + "INDEXES" ) RETURN_NAMES = ( "non_zero_masks_out", "non_zero_mask_images_out", "zero_mask_images_out", + "zero_mask_images_out_indexes" ) FUNCTION = "filter" CATEGORY = "KJNodes/masking" @@ -1841,6 +1843,7 @@ class FilterZeroMasksAndCorrespondingImages: non_zero_masks = [] non_zero_mask_images = [] zero_mask_images = [] + zero_mask_images_indexes = [] masks_num = len(masks) also_process_images = False @@ -1859,16 +1862,75 @@ class FilterZeroMasksAndCorrespondingImages: non_zero_mask_images.append(original_images[i]) else: zero_mask_images.append(original_images[i]) + zero_mask_images_indexes.append(i) non_zero_masks_out = torch.stack(non_zero_masks, dim=0) + non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None if also_process_images: non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0) - zero_mask_images_out = torch.stack(zero_mask_images, dim=0) if len(zero_mask_images) > 0 else None - else: - non_zero_mask_images_out = zero_mask_images_out = None + if len(zero_mask_images) > 0: + zero_mask_images_out = torch.stack(zero_mask_images, dim=0) + zero_mask_images_out_indexes = zero_mask_images_indexes - return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out) + return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes) + +class InsertImageBatchByIndexes: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + "images_to_insert": ("IMAGE",), + "insert_indexes": ("INDEXES",), + }, + } + + RETURN_TYPES = ( + "IMAGE", + ) + RETURN_NAMES = ( + "images_after_insert", + ) + FUNCTION = "insert" + CATEGORY = "KJNodes" + + def insert(self, images, images_to_insert, insert_indexes): + """_summary_ + + Args: + images (_type_): _description_ + images_to_insert (_type_): _description_ + insert_indexes (_type_): _description_ + + Returns: + _type_: _description_ + """ + + images_after_insert = images + + if images_to_insert is not None and insert_indexes is not None: + images_to_insert_num = len(images_to_insert) + insert_indexes_num = len(insert_indexes) + if images_to_insert_num == insert_indexes_num: + images_after_insert = [] + + i_images = 0 + for i in range(len(images) + images_to_insert_num): + if i in insert_indexes: + images_after_insert.append(images_to_insert[insert_indexes.index(i)]) + else: + images_after_insert.append(images[i_images]) + i_images += 1 + + images_after_insert = torch.stack(images_after_insert, dim=0) + + else: + print(f"[WARNING] skip this node, due to number of images_to_insert ({images_to_insert_num}) is not equal to number of insert_indexes ({insert_indexes_num})") + + + return (images_after_insert, ) def bbox_to_region(bbox, target_size=None): bbox = bbox_check(bbox, target_size) @@ -3038,6 +3100,7 @@ NODE_CLASS_MAPPINGS = { "BatchCropFromMask": BatchCropFromMask, "BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced, "FilterZeroMasksAndCorrespondingImages": FilterZeroMasksAndCorrespondingImages, + "InsertImageBatchByIndexes": InsertImageBatchByIndexes, "BatchUncrop": BatchUncrop, "BatchUncropAdvanced": BatchUncropAdvanced, "BatchCLIPSeg": BatchCLIPSeg, @@ -3094,6 +3157,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "BatchCropFromMask": "BatchCropFromMask", "BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced", "FilterZeroMasksAndCorrespondingImages": "FilterZeroMasksAndCorrespondingImages", + "InsertImageBatchByIndexes": "InsertImageBatchByIndexes", "BatchUncrop": "BatchUncrop", "BatchUncropAdvanced": "BatchUncropAdvanced", "BatchCLIPSeg": "BatchCLIPSeg", From 983b7e683d8537ab66361311407616abbca25b0c Mon Sep 17 00:00:00 2001 From: MrForExample <62230687+MrForExample@users.noreply.github.com> Date: Wed, 3 Jan 2024 11:49:39 +0100 Subject: [PATCH 4/4] Update comment for node InsertImageBatchByIndexes --- nodes.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/nodes.py b/nodes.py index e0fe382..b7897c1 100644 --- a/nodes.py +++ b/nodes.py @@ -1897,15 +1897,12 @@ class InsertImageBatchByIndexes: CATEGORY = "KJNodes" def insert(self, images, images_to_insert, insert_indexes): - """_summary_ - - Args: - images (_type_): _description_ - images_to_insert (_type_): _description_ - insert_indexes (_type_): _description_ + """ + This node is designed to be use with node FilterZeroMasksAndCorrespondingImages + It inserts the images_to_insert into images according to insert_indexes Returns: - _type_: _description_ + images_after_insert: updated original images with origonal sequence order """ images_after_insert = images