diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 5d0a553..4103faf 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -2007,25 +2007,43 @@ with the replacement images. def INPUT_TYPES(s): return { "required": { - "original_images": ("IMAGE",), - "replacement_images": ("IMAGE",), "start_index": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}), }, "optional": { + "original_images": ("IMAGE",), + "replacement_images": ("IMAGE",), "original_masks": ("MASK",), "replacement_masks": ("MASK",), } } - def replace(self, original_images, replacement_images, start_index, original_masks=None, replacement_masks=None): + def replace(self, original_images=None, replacement_images=None, start_index=1, original_masks=None, replacement_masks=None): images = None - if start_index >= len(original_images): - raise ValueError("GetImageRangeFromBatch: Start index is out of range") - end_index = start_index + len(replacement_images) - if end_index > len(original_images): - raise ValueError("GetImageRangeFromBatch: End index is out of range") + masks = None + + if original_images is not None and replacement_images is not None: + if start_index >= len(original_images): + raise ValueError("ReplaceImagesInBatch: Start index is out of range") + end_index = start_index + len(replacement_images) + if end_index > len(original_images): + raise ValueError("ReplaceImagesInBatch: End index is out of range") + + original_images_copy = original_images.clone() + if original_images_copy.shape[2] != replacement_images.shape[2] or original_images_copy.shape[3] != replacement_images.shape[3]: + replacement_images = common_upscale(replacement_images.movedim(-1, 1), original_images_copy.shape[1], original_images_copy.shape[2], "lanczos", "center").movedim(1, -1) + + original_images_copy[start_index:end_index] = replacement_images + images = original_images_copy + else: + images = torch.zeros((1, 64, 64, 3)) if original_masks is not None and replacement_masks is not None: + if start_index >= len(original_masks): + raise ValueError("ReplaceImagesInBatch: Start index is out of range") + end_index = start_index + len(replacement_masks) + if end_index > len(original_masks): + raise ValueError("ReplaceImagesInBatch: End index is out of range") + original_masks_copy = original_masks.clone() if original_masks_copy.shape[1] != replacement_masks.shape[1] or original_masks_copy.shape[2] != replacement_masks.shape[2]: replacement_masks = common_upscale(replacement_masks.unsqueeze(1), original_masks_copy.shape[1], original_masks_copy.shape[2], "nearest-exact", "center").squeeze(0) @@ -2033,15 +2051,8 @@ with the replacement images. original_masks_copy[start_index:end_index] = replacement_masks masks = original_masks_copy else: - masks = torch.zeros(1,64,64, device=original_images.device, dtype=original_images.dtype) + masks = torch.zeros((1, 64, 64)) - original_images_copy = original_images.clone() - - if original_images_copy.shape[2] != replacement_images.shape[2] or original_images_copy.shape[3] != replacement_images.shape[3]: - replacement_images = common_upscale(replacement_images.movedim(-1, 1), original_images_copy.shape[1], original_images_copy.shape[2], "lanczos", "center").movedim(1, -1) - - original_images_copy[start_index:end_index] = replacement_images - images = original_images_copy return (images, masks)