diff --git a/nodes.py b/nodes.py index 5bb7995..d698120 100644 --- a/nodes.py +++ b/nodes.py @@ -1804,6 +1804,72 @@ class BatchCropFromMaskAdvanced: return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size) +class FilterZeroMasksAndCorrespondingImages: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "masks": ("MASK",), + }, + "optional": { + "original_images": ("IMAGE",), + }, + } + + RETURN_TYPES = ( + "MASK", + "IMAGE", + "IMAGE", + ) + RETURN_NAMES = ( + "non_zero_masks_out", + "non_zero_mask_images_out", + "zero_mask_images_out", + ) + FUNCTION = "filter" + CATEGORY = "KJNodes/masking" + + def filter(self, masks, original_images=None): + """ + Filter out all the empty (i.e. all zero) mask in masks + Also filter out all the corresponding images in original_images by indexes if provide + + Args: + original_images (optional): If provide, it need have same length as masks. + """ + non_zero_masks = [] + non_zero_mask_images = [] + zero_mask_images = [] + + masks_num = len(masks) + also_process_images = False + if original_images is not None: + imgs_num = len(original_images) + if len(original_images) == masks_num: + also_process_images = True + else: + print(f"[WARNING] ignore input: original_images, due to number of original_images ({imgs_num}) is not equal to number of masks ({masks_num})") + + for i in range(masks_num): + non_zero_num = np.count_nonzero(np.array(masks[i])) + if non_zero_num > 0: + non_zero_masks.append(masks[i]) + if also_process_images: + non_zero_mask_images.append(original_images[i]) + else: + zero_mask_images.append(original_images[i]) + + non_zero_masks_out = torch.stack(non_zero_masks, dim=0) + + if also_process_images: + non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0) + zero_mask_images_out = torch.stack(zero_mask_images, dim=0) + else: + non_zero_mask_images_out = zero_mask_images_out = None + + return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out) + def bbox_to_region(bbox, target_size=None): bbox = bbox_check(bbox, target_size) return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]) @@ -2971,6 +3037,7 @@ NODE_CLASS_MAPPINGS = { "ReplaceImagesInBatch": ReplaceImagesInBatch, "BatchCropFromMask": BatchCropFromMask, "BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced, + "FilterZeroMasksAndCorrespondingImages": FilterZeroMasksAndCorrespondingImages, "BatchUncrop": BatchUncrop, "BatchUncropAdvanced": BatchUncropAdvanced, "BatchCLIPSeg": BatchCLIPSeg, @@ -3026,6 +3093,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ReplaceImagesInBatch": "ReplaceImagesInBatch", "BatchCropFromMask": "BatchCropFromMask", "BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced", + "FilterZeroMasksAndCorrespondingImages": "FilterZeroMasksAndCorrespondingImages", "BatchUncrop": "BatchUncrop", "BatchUncropAdvanced": "BatchUncropAdvanced", "BatchCLIPSeg": "BatchCLIPSeg",