diff --git a/__init__.py b/__init__.py index 6d891b6..9dd5ba8 100644 --- a/__init__.py +++ b/__init__.py @@ -82,6 +82,7 @@ NODE_CONFIG = { "LoadAndResizeImage": {"class": LoadAndResizeImage, "name": "Load & Resize Image"}, "LoadImagesFromFolderKJ": {"class": LoadImagesFromFolderKJ, "name": "Load Images From Folder (KJ)"}, "MergeImageChannels": {"class": MergeImageChannels, "name": "Merge Image Channels"}, + "PadImageBatchInterleaved": {"class": PadImageBatchInterleaved, "name": "Pad Image Batch Interleaved"}, "PreviewAnimation": {"class": PreviewAnimation, "name": "Preview Animation"}, "RemapImageRange": {"class": RemapImageRange, "name": "Remap Image Range"}, "ReverseImageBatch": {"class": ReverseImageBatch, "name": "Reverse Image Batch"}, diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index a87a698..9df8d02 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -1122,6 +1122,9 @@ class ImagePrepForICLora: print("Warning: The incoming mask is fully black. Handling it as None.") reference_mask = None image = reference_image + if latent_image is not None: + if image.shape[0] != latent_image.shape[0]: + image = image.repeat(latent_image.shape[0], 1, 1, 1) B, H, W, C = image.size() # Handle mask @@ -1849,6 +1852,52 @@ Inserts images at the specified indices into the original image batch. return (original_images,) +class PadImageBatchInterleaved: + + RETURN_TYPES = ("IMAGE", "MASK",) + RETURN_NAMES = ("images", "masks",) + FUNCTION = "pad" + CATEGORY = "KJNodes/image" + DESCRIPTION = """ +Inserts empty frames between the images in a batch. +""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ("IMAGE",), + "empty_frames_per_image": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}), + }, + } + + def pad(self, images, empty_frames_per_image): + B, H, W, C = images.shape + + if B == 1 or empty_frames_per_image == 0: + return (images,) + + # Original B images + (B-1) sets of empty frames between them + total_frames = B + (B-1) * empty_frames_per_image + + # Create new tensor with zeros (empty frames) + padded_batch = torch.zeros((total_frames, H, W, C), + dtype=images.dtype, + device=images.device) + # Create mask tensor (1 for original frames, 0 for empty frames) + mask = torch.zeros((total_frames, H, W), + dtype=images.dtype, + device=images.device) + + # Fill in original images at their new positions + for i in range(B): + # Each image is separated by empty_frames_per_image blank frames + new_pos = i * (empty_frames_per_image + 1) + padded_batch[new_pos] = images[i] + mask[new_pos] = 1.0 # Mark this as an original frame + + return (padded_batch, mask) + class ReplaceImagesInBatch: RETURN_TYPES = ("IMAGE",) @@ -3147,11 +3196,14 @@ class ImagePadKJ: out_image[b, :, :, :] = bg_color.unsqueeze(0).unsqueeze(0) # Expand for H and W dimensions out_image[b, pad_top:pad_top+H, pad_left:pad_left+W, :] = image[b] + if mask is not None: out_masks = torch.zeros((BM, padded_height, padded_width), dtype=mask.dtype, device=mask.device) for m in range(BM): out_masks[m, pad_top:pad_top+H, pad_left:pad_left+W] = mask[m] else: - out_masks = torch.zeros((1, padded_height, padded_width), dtype=image.dtype, device=image.device) + out_masks = torch.ones((B, padded_height, padded_width), dtype=image.dtype, device=image.device) + for m in range(B): + out_masks[m, pad_top:pad_top+H, pad_left:pad_left+W] = 0.0 return (out_image, out_masks)