diff --git a/nodes.py b/nodes.py index 89e6532..e824bd5 100644 --- a/nodes.py +++ b/nodes.py @@ -1797,6 +1797,131 @@ class BatchCropFromMaskAdvanced: return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size) +class FilterZeroMasksAndCorrespondingImages: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "masks": ("MASK",), + }, + "optional": { + "original_images": ("IMAGE",), + }, + } + + RETURN_TYPES = ( + "MASK", + "IMAGE", + "IMAGE", + "INDEXES" + ) + RETURN_NAMES = ( + "non_zero_masks_out", + "non_zero_mask_images_out", + "zero_mask_images_out", + "zero_mask_images_out_indexes" + ) + FUNCTION = "filter" + CATEGORY = "KJNodes/masking" + + def filter(self, masks, original_images=None): + """ + Filter out all the empty (i.e. all zero) mask in masks + Also filter out all the corresponding images in original_images by indexes if provide + + Args: + original_images (optional): If provide, it need have same length as masks. + """ + non_zero_masks = [] + non_zero_mask_images = [] + zero_mask_images = [] + zero_mask_images_indexes = [] + + masks_num = len(masks) + also_process_images = False + if original_images is not None: + imgs_num = len(original_images) + if len(original_images) == masks_num: + also_process_images = True + else: + print(f"[WARNING] ignore input: original_images, due to number of original_images ({imgs_num}) is not equal to number of masks ({masks_num})") + + for i in range(masks_num): + non_zero_num = np.count_nonzero(np.array(masks[i])) + if non_zero_num > 0: + non_zero_masks.append(masks[i]) + if also_process_images: + non_zero_mask_images.append(original_images[i]) + else: + zero_mask_images.append(original_images[i]) + zero_mask_images_indexes.append(i) + + non_zero_masks_out = torch.stack(non_zero_masks, dim=0) + non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None + + if also_process_images: + non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0) + if len(zero_mask_images) > 0: + zero_mask_images_out = torch.stack(zero_mask_images, dim=0) + zero_mask_images_out_indexes = zero_mask_images_indexes + + return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes) + +class InsertImageBatchByIndexes: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + "images_to_insert": ("IMAGE",), + "insert_indexes": ("INDEXES",), + }, + } + + RETURN_TYPES = ( + "IMAGE", + ) + RETURN_NAMES = ( + "images_after_insert", + ) + FUNCTION = "insert" + CATEGORY = "KJNodes" + + def insert(self, images, images_to_insert, insert_indexes): + """ + This node is designed to be use with node FilterZeroMasksAndCorrespondingImages + It inserts the images_to_insert into images according to insert_indexes + + Returns: + images_after_insert: updated original images with origonal sequence order + """ + + images_after_insert = images + + if images_to_insert is not None and insert_indexes is not None: + images_to_insert_num = len(images_to_insert) + insert_indexes_num = len(insert_indexes) + if images_to_insert_num == insert_indexes_num: + images_after_insert = [] + + i_images = 0 + for i in range(len(images) + images_to_insert_num): + if i in insert_indexes: + images_after_insert.append(images_to_insert[insert_indexes.index(i)]) + else: + images_after_insert.append(images[i_images]) + i_images += 1 + + images_after_insert = torch.stack(images_after_insert, dim=0) + + else: + print(f"[WARNING] skip this node, due to number of images_to_insert ({images_to_insert_num}) is not equal to number of insert_indexes ({insert_indexes_num})") + + + return (images_after_insert, ) + def bbox_to_region(bbox, target_size=None): bbox = bbox_check(bbox, target_size) return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]) @@ -2973,6 +3098,8 @@ NODE_CLASS_MAPPINGS = { "ReplaceImagesInBatch": ReplaceImagesInBatch, "BatchCropFromMask": BatchCropFromMask, "BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced, + "FilterZeroMasksAndCorrespondingImages": FilterZeroMasksAndCorrespondingImages, + "InsertImageBatchByIndexes": InsertImageBatchByIndexes, "BatchUncrop": BatchUncrop, "BatchUncropAdvanced": BatchUncropAdvanced, "BatchCLIPSeg": BatchCLIPSeg, @@ -3028,6 +3155,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ReplaceImagesInBatch": "ReplaceImagesInBatch", "BatchCropFromMask": "BatchCropFromMask", "BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced", + "FilterZeroMasksAndCorrespondingImages": "FilterZeroMasksAndCorrespondingImages", + "InsertImageBatchByIndexes": "InsertImageBatchByIndexes", "BatchUncrop": "BatchUncrop", "BatchUncropAdvanced": "BatchUncropAdvanced", "BatchCLIPSeg": "BatchCLIPSeg",