diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 12d6990..e3a1809 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -787,7 +787,7 @@ and passes it through unchanged. class ImageBatchRepeatInterleaving: - RETURN_TYPES = ("IMAGE",) + RETURN_TYPES = ("IMAGE", "MASK",) FUNCTION = "repeat" CATEGORY = "KJNodes/image" DESCRIPTION = """ @@ -802,13 +802,20 @@ with repeats 2 becomes batch of 10 images: 0, 0, 1, 1, 2, 2, 3, 3, 4, 4 "required": { "images": ("IMAGE",), "repeats": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - } + }, + "optional": { + "mask": ("MASK",), + } + } - def repeat(self, images, repeats): + def repeat(self, images, repeats, mask=None): repeated_images = torch.repeat_interleave(images, repeats=repeats, dim=0) - return (repeated_images, ) + if mask is not None: + mask = torch.repeat_interleave(mask, repeats=repeats, dim=0) + else: + mask = torch.zeros_like(repeated_images[:, 0:1, :, :]) + return (repeated_images, mask) class ImageUpscaleWithModelBatched: @classmethod