Merge pull request #12 from MrForExample/main

Adding a new node to bypass the error caused by all-zero mask in BatchCropFromMaskAdvanced
This commit is contained in:
Jukka Seppänen 2024-01-06 17:10:48 +02:00 committed by GitHub
commit a0a7e8556c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

129
nodes.py
View File

@ -1797,6 +1797,131 @@ class BatchCropFromMaskAdvanced:
return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)
class FilterZeroMasksAndCorrespondingImages:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"masks": ("MASK",),
},
"optional": {
"original_images": ("IMAGE",),
},
}
RETURN_TYPES = (
"MASK",
"IMAGE",
"IMAGE",
"INDEXES"
)
RETURN_NAMES = (
"non_zero_masks_out",
"non_zero_mask_images_out",
"zero_mask_images_out",
"zero_mask_images_out_indexes"
)
FUNCTION = "filter"
CATEGORY = "KJNodes/masking"
def filter(self, masks, original_images=None):
"""
Filter out all the empty (i.e. all zero) mask in masks
Also filter out all the corresponding images in original_images by indexes if provide
Args:
original_images (optional): If provide, it need have same length as masks.
"""
non_zero_masks = []
non_zero_mask_images = []
zero_mask_images = []
zero_mask_images_indexes = []
masks_num = len(masks)
also_process_images = False
if original_images is not None:
imgs_num = len(original_images)
if len(original_images) == masks_num:
also_process_images = True
else:
print(f"[WARNING] ignore input: original_images, due to number of original_images ({imgs_num}) is not equal to number of masks ({masks_num})")
for i in range(masks_num):
non_zero_num = np.count_nonzero(np.array(masks[i]))
if non_zero_num > 0:
non_zero_masks.append(masks[i])
if also_process_images:
non_zero_mask_images.append(original_images[i])
else:
zero_mask_images.append(original_images[i])
zero_mask_images_indexes.append(i)
non_zero_masks_out = torch.stack(non_zero_masks, dim=0)
non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None
if also_process_images:
non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0)
if len(zero_mask_images) > 0:
zero_mask_images_out = torch.stack(zero_mask_images, dim=0)
zero_mask_images_out_indexes = zero_mask_images_indexes
return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes)
class InsertImageBatchByIndexes:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"images_to_insert": ("IMAGE",),
"insert_indexes": ("INDEXES",),
},
}
RETURN_TYPES = (
"IMAGE",
)
RETURN_NAMES = (
"images_after_insert",
)
FUNCTION = "insert"
CATEGORY = "KJNodes"
def insert(self, images, images_to_insert, insert_indexes):
"""
This node is designed to be use with node FilterZeroMasksAndCorrespondingImages
It inserts the images_to_insert into images according to insert_indexes
Returns:
images_after_insert: updated original images with origonal sequence order
"""
images_after_insert = images
if images_to_insert is not None and insert_indexes is not None:
images_to_insert_num = len(images_to_insert)
insert_indexes_num = len(insert_indexes)
if images_to_insert_num == insert_indexes_num:
images_after_insert = []
i_images = 0
for i in range(len(images) + images_to_insert_num):
if i in insert_indexes:
images_after_insert.append(images_to_insert[insert_indexes.index(i)])
else:
images_after_insert.append(images[i_images])
i_images += 1
images_after_insert = torch.stack(images_after_insert, dim=0)
else:
print(f"[WARNING] skip this node, due to number of images_to_insert ({images_to_insert_num}) is not equal to number of insert_indexes ({insert_indexes_num})")
return (images_after_insert, )
def bbox_to_region(bbox, target_size=None):
bbox = bbox_check(bbox, target_size)
return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
@ -2973,6 +3098,8 @@ NODE_CLASS_MAPPINGS = {
"ReplaceImagesInBatch": ReplaceImagesInBatch,
"BatchCropFromMask": BatchCropFromMask,
"BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced,
"FilterZeroMasksAndCorrespondingImages": FilterZeroMasksAndCorrespondingImages,
"InsertImageBatchByIndexes": InsertImageBatchByIndexes,
"BatchUncrop": BatchUncrop,
"BatchUncropAdvanced": BatchUncropAdvanced,
"BatchCLIPSeg": BatchCLIPSeg,
@ -3028,6 +3155,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ReplaceImagesInBatch": "ReplaceImagesInBatch",
"BatchCropFromMask": "BatchCropFromMask",
"BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced",
"FilterZeroMasksAndCorrespondingImages": "FilterZeroMasksAndCorrespondingImages",
"InsertImageBatchByIndexes": "InsertImageBatchByIndexes",
"BatchUncrop": "BatchUncrop",
"BatchUncropAdvanced": "BatchUncropAdvanced",
"BatchCLIPSeg": "BatchCLIPSeg",