Add new node: FilterZeroMasksAndCorrespondingImages

Filter out all the empty (i.e. all zero) mask in masks
Also filter out all the corresponding images in original_images by indexes if provide
This commit is contained in:
MrForExample 2023-12-27 19:24:28 +01:00
parent 6010472e5f
commit 2830476fd4

View File

@ -1804,6 +1804,72 @@ class BatchCropFromMaskAdvanced:
return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)
class FilterZeroMasksAndCorrespondingImages:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"masks": ("MASK",),
},
"optional": {
"original_images": ("IMAGE",),
},
}
RETURN_TYPES = (
"MASK",
"IMAGE",
"IMAGE",
)
RETURN_NAMES = (
"non_zero_masks_out",
"non_zero_mask_images_out",
"zero_mask_images_out",
)
FUNCTION = "filter"
CATEGORY = "KJNodes/masking"
def filter(self, masks, original_images=None):
"""
Filter out all the empty (i.e. all zero) mask in masks
Also filter out all the corresponding images in original_images by indexes if provide
Args:
original_images (optional): If provide, it need have same length as masks.
"""
non_zero_masks = []
non_zero_mask_images = []
zero_mask_images = []
masks_num = len(masks)
also_process_images = False
if original_images is not None:
imgs_num = len(original_images)
if len(original_images) == masks_num:
also_process_images = True
else:
print(f"[WARNING] ignore input: original_images, due to number of original_images ({imgs_num}) is not equal to number of masks ({masks_num})")
for i in range(masks_num):
non_zero_num = np.count_nonzero(np.array(masks[i]))
if non_zero_num > 0:
non_zero_masks.append(masks[i])
if also_process_images:
non_zero_mask_images.append(original_images[i])
else:
zero_mask_images.append(original_images[i])
non_zero_masks_out = torch.stack(non_zero_masks, dim=0)
if also_process_images:
non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0)
zero_mask_images_out = torch.stack(zero_mask_images, dim=0)
else:
non_zero_mask_images_out = zero_mask_images_out = None
return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out)
def bbox_to_region(bbox, target_size=None):
bbox = bbox_check(bbox, target_size)
return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
@ -2971,6 +3037,7 @@ NODE_CLASS_MAPPINGS = {
"ReplaceImagesInBatch": ReplaceImagesInBatch,
"BatchCropFromMask": BatchCropFromMask,
"BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced,
"FilterZeroMasksAndCorrespondingImages": FilterZeroMasksAndCorrespondingImages,
"BatchUncrop": BatchUncrop,
"BatchUncropAdvanced": BatchUncropAdvanced,
"BatchCLIPSeg": BatchCLIPSeg,
@ -3026,6 +3093,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ReplaceImagesInBatch": "ReplaceImagesInBatch",
"BatchCropFromMask": "BatchCropFromMask",
"BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced",
"FilterZeroMasksAndCorrespondingImages": "FilterZeroMasksAndCorrespondingImages",
"BatchUncrop": "BatchUncrop",
"BatchUncropAdvanced": "BatchUncropAdvanced",
"BatchCLIPSeg": "BatchCLIPSeg",