diff --git a/__init__.py b/__init__.py index 7881729..7aa99e5 100644 --- a/__init__.py +++ b/__init__.py @@ -56,6 +56,7 @@ NODE_CONFIG = { "ImageGrabPIL": {"class": ImageGrabPIL, "name": "Image Grab PIL"}, "ImageGridComposite2x2": {"class": ImageGridComposite2x2, "name": "Image Grid Composite 2x2"}, "ImageGridComposite3x3": {"class": ImageGridComposite3x3, "name": "Image Grid Composite 3x3"}, + "ImageGridtoBatch": {"class": ImageGridtoBatch, "name": "Image Grid To Batch"}, "ImageNormalize_Neg1_To_1": {"class": ImageNormalize_Neg1_To_1, "name": "Image Normalize -1 to 1"}, "ImagePass": {"class": ImagePass}, "ImagePadForOutpaintMasked": {"class": ImagePadForOutpaintMasked, "name": "Image Pad For Outpaint Masked"}, diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 06dec4b..1701752 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -1910,3 +1910,32 @@ class LoadImagesFromFolderKJ: mask1 = torch.cat((mask1, mask2), dim=0) return (image1, mask1, len(images), image_path_list) + +class ImageGridtoBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE", ), + "columns": ("INT", {"default": 3, "min": 2, "max": 8, "tooltip": "The number of columns in the grid."}), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "decompose" + CATEGORY = "KJNodes/image" + DESCRIPTION = "Converts a grid of images to a batch of images." + + def decompose(self, image, columns): + B, H, W, C = image.shape + orig_h = H // columns + orig_w = W // columns + + # Reshape and permute the image to get the grid + image = image.view(B, columns, orig_h, columns, orig_w, C) + image = image.permute(0, 1, 3, 2, 4, 5).contiguous() + image = image.view(B, columns * columns, orig_h, orig_w, C) + + # Reshape to the final batch tensor + img_tensor = image.view(-1, orig_h, orig_w, C) + + return img_tensor, \ No newline at end of file