diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 9b07600..40bc4ed 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -2156,8 +2156,8 @@ class ReplaceImagesInBatch: FUNCTION = "replace" CATEGORY = "KJNodes/image" DESCRIPTION = """ -Replaces the images in a batch, starting from the specified start index, -with the replacement images. +Replaces the images in a batch, starting from the specified start index with step stride, +using the replacement images. """ @classmethod @@ -2165,6 +2165,7 @@ with the replacement images. return { "required": { "start_index": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}), + "step": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}), }, "optional": { "original_images": ("IMAGE",), @@ -2174,14 +2175,14 @@ with the replacement images. } } - def replace(self, original_images=None, replacement_images=None, start_index=1, original_masks=None, replacement_masks=None): + def replace(self, original_images=None, replacement_images=None, start_index=1, step=1, original_masks=None, replacement_masks=None): images = None masks = None if original_images is not None and replacement_images is not None: if start_index >= len(original_images): raise ValueError("ReplaceImagesInBatch: Start index is out of range") - end_index = start_index + len(replacement_images) + end_index = start_index + len(replacement_images) * step if end_index > len(original_images): raise ValueError("ReplaceImagesInBatch: End index is out of range") @@ -2189,7 +2190,7 @@ with the replacement images. if original_images_copy.shape[2] != replacement_images.shape[2] or original_images_copy.shape[3] != replacement_images.shape[3]: replacement_images = common_upscale(replacement_images.movedim(-1, 1), original_images_copy.shape[1], original_images_copy.shape[2], "lanczos", "center").movedim(1, -1) - original_images_copy[start_index:end_index] = replacement_images + original_images_copy[start_index:end_index:step] = replacement_images images = original_images_copy else: images = torch.zeros((1, 64, 64, 3)) @@ -2197,7 +2198,7 @@ with the replacement images. if original_masks is not None and replacement_masks is not None: if start_index >= len(original_masks): raise ValueError("ReplaceImagesInBatch: Start index is out of range") - end_index = start_index + len(replacement_masks) + end_index = start_index + len(replacement_masks) * step if end_index > len(original_masks): raise ValueError("ReplaceImagesInBatch: End index is out of range") @@ -2205,7 +2206,7 @@ with the replacement images. if original_masks_copy.shape[1] != replacement_masks.shape[1] or original_masks_copy.shape[2] != replacement_masks.shape[2]: replacement_masks = common_upscale(replacement_masks.unsqueeze(1), original_masks_copy.shape[1], original_masks_copy.shape[2], "nearest-exact", "center").squeeze(0) - original_masks_copy[start_index:end_index] = replacement_masks + original_masks_copy[start_index:end_index:step] = replacement_masks masks = original_masks_copy else: masks = torch.zeros((1, 64, 64))