Add InsertImageBatchByIndexes

This commit is contained in:
MrForExample 2024-01-03 11:43:45 +01:00
parent ccd89290ed
commit 7e2709425a

View File

@ -1821,11 +1821,13 @@ class FilterZeroMasksAndCorrespondingImages:
"MASK",
"IMAGE",
"IMAGE",
"INDEXES"
)
RETURN_NAMES = (
"non_zero_masks_out",
"non_zero_mask_images_out",
"zero_mask_images_out",
"zero_mask_images_out_indexes"
)
FUNCTION = "filter"
CATEGORY = "KJNodes/masking"
@ -1841,6 +1843,7 @@ class FilterZeroMasksAndCorrespondingImages:
non_zero_masks = []
non_zero_mask_images = []
zero_mask_images = []
zero_mask_images_indexes = []
masks_num = len(masks)
also_process_images = False
@ -1859,16 +1862,75 @@ class FilterZeroMasksAndCorrespondingImages:
non_zero_mask_images.append(original_images[i])
else:
zero_mask_images.append(original_images[i])
zero_mask_images_indexes.append(i)
non_zero_masks_out = torch.stack(non_zero_masks, dim=0)
non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None
if also_process_images:
non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0)
zero_mask_images_out = torch.stack(zero_mask_images, dim=0) if len(zero_mask_images) > 0 else None
else:
non_zero_mask_images_out = zero_mask_images_out = None
if len(zero_mask_images) > 0:
zero_mask_images_out = torch.stack(zero_mask_images, dim=0)
zero_mask_images_out_indexes = zero_mask_images_indexes
return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out)
return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes)
class InsertImageBatchByIndexes:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"images_to_insert": ("IMAGE",),
"insert_indexes": ("INDEXES",),
},
}
RETURN_TYPES = (
"IMAGE",
)
RETURN_NAMES = (
"images_after_insert",
)
FUNCTION = "insert"
CATEGORY = "KJNodes"
def insert(self, images, images_to_insert, insert_indexes):
"""_summary_
Args:
images (_type_): _description_
images_to_insert (_type_): _description_
insert_indexes (_type_): _description_
Returns:
_type_: _description_
"""
images_after_insert = images
if images_to_insert is not None and insert_indexes is not None:
images_to_insert_num = len(images_to_insert)
insert_indexes_num = len(insert_indexes)
if images_to_insert_num == insert_indexes_num:
images_after_insert = []
i_images = 0
for i in range(len(images) + images_to_insert_num):
if i in insert_indexes:
images_after_insert.append(images_to_insert[insert_indexes.index(i)])
else:
images_after_insert.append(images[i_images])
i_images += 1
images_after_insert = torch.stack(images_after_insert, dim=0)
else:
print(f"[WARNING] skip this node, due to number of images_to_insert ({images_to_insert_num}) is not equal to number of insert_indexes ({insert_indexes_num})")
return (images_after_insert, )
def bbox_to_region(bbox, target_size=None):
bbox = bbox_check(bbox, target_size)
@ -3038,6 +3100,7 @@ NODE_CLASS_MAPPINGS = {
"BatchCropFromMask": BatchCropFromMask,
"BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced,
"FilterZeroMasksAndCorrespondingImages": FilterZeroMasksAndCorrespondingImages,
"InsertImageBatchByIndexes": InsertImageBatchByIndexes,
"BatchUncrop": BatchUncrop,
"BatchUncropAdvanced": BatchUncropAdvanced,
"BatchCLIPSeg": BatchCLIPSeg,
@ -3094,6 +3157,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"BatchCropFromMask": "BatchCropFromMask",
"BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced",
"FilterZeroMasksAndCorrespondingImages": "FilterZeroMasksAndCorrespondingImages",
"InsertImageBatchByIndexes": "InsertImageBatchByIndexes",
"BatchUncrop": "BatchUncrop",
"BatchUncropAdvanced": "BatchUncropAdvanced",
"BatchCLIPSeg": "BatchCLIPSeg",