From 87084633bef7c494705b101dd83956f95eb41430 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 21 Aug 2024 01:58:31 +0300 Subject: [PATCH] Add ImageConcatFromBatch --- __init__.py | 1 + nodes/image_nodes.py | 98 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/__init__.py b/__init__.py index 0609ca7..7881729 100644 --- a/__init__.py +++ b/__init__.py @@ -51,6 +51,7 @@ NODE_CONFIG = { "ImageBatchRepeatInterleaving": {"class": ImageBatchRepeatInterleaving}, "ImageBatchTestPattern": {"class": ImageBatchTestPattern, "name": "Image Batch Test Pattern"}, "ImageConcanate": {"class": ImageConcanate, "name": "Image Concatenate"}, + "ImageConcatFromBatch": {"class": ImageConcatFromBatch, "name": "Image Concatenate From Batch"}, "ImageConcatMulti": {"class": ImageConcatMulti, "name": "Image Concatenate Multi"}, "ImageGrabPIL": {"class": ImageGrabPIL, "name": "Image Grab PIL"}, "ImageGridComposite2x2": {"class": ImageGridComposite2x2, "name": "Image Grid Composite 2x2"}, diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index a3f660f..77344d0 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -211,7 +211,7 @@ class ImageConcanate: { "default": 'right' }), - "match_image_size": ("BOOLEAN", {"default": False}), + "match_image_size": ("BOOLEAN", {"default": True}), }} RETURN_TYPES = ("IMAGE",) @@ -274,6 +274,100 @@ Concatenates the image2 to image1 in the specified direction. elif direction == 'up': concatenated_image = torch.cat((image2_resized, image1), dim=1) # Concatenate along height return concatenated_image, + +import torch # Make sure you have PyTorch installed + +class ImageConcatFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "images": ("IMAGE",), + "num_columns": ("INT", {"default": 3, "min": 1, "max": 255, "step": 1}), + "match_image_size": ("BOOLEAN", {"default": False}), + "max_resolution": ("INT", {"default": 4096}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "concat" + CATEGORY = "KJNodes/image" + DESCRIPTION = """ + Concatenates images from a batch into a grid with a specified number of columns. + """ + + def concat(self, images, num_columns, match_image_size, max_resolution): + # Assuming images is a batch of images (B, H, W, C) + batch_size, height, width, channels = images.shape + num_rows = (batch_size + num_columns - 1) // num_columns # Calculate number of rows + + print(f"Initial dimensions: batch_size={batch_size}, height={height}, width={width}, channels={channels}") + print(f"num_rows={num_rows}, num_columns={num_columns}") + + if match_image_size: + target_shape = images[0].shape + + resized_images = [] + for image in images: + original_height = image.shape[0] + original_width = image.shape[1] + original_aspect_ratio = original_width / original_height + + if original_aspect_ratio > 1: + target_height = target_shape[0] + target_width = int(target_height * original_aspect_ratio) + else: + target_width = target_shape[1] + target_height = int(target_width / original_aspect_ratio) + + print(f"Resizing image from ({original_height}, {original_width}) to ({target_height}, {target_width})") + + # Resize the image to match the target size while preserving aspect ratio + resized_image = common_upscale(image.movedim(-1, 0), target_width, target_height, "lanczos", "disabled") + resized_image = resized_image.movedim(0, -1) # Move channels back to the last dimension + resized_images.append(resized_image) + + # Convert the list of resized images back to a tensor + images = torch.stack(resized_images) + + height, width = target_shape[:2] # Update height and width + + # Initialize an empty grid + grid_height = num_rows * height + grid_width = num_columns * width + + print(f"Grid dimensions before scaling: grid_height={grid_height}, grid_width={grid_width}") + + # Original scale factor calculation remains unchanged + scale_factor = min(max_resolution / grid_height, max_resolution / grid_width, 1.0) + + # Apply scale factor to height and width + scaled_height = height * scale_factor + scaled_width = width * scale_factor + + # Round scaled dimensions to the nearest number divisible by 8 + height = max(1, int(round(scaled_height / 8) * 8)) + width = max(1, int(round(scaled_width / 8) * 8)) + + if abs(scaled_height - height) > 4: + height = max(1, int(round((scaled_height + 4) / 8) * 8)) + if abs(scaled_width - width) > 4: + width = max(1, int(round((scaled_width + 4) / 8) * 8)) + + # Recalculate grid dimensions with adjusted height and width + grid_height = num_rows * height + grid_width = num_columns * width + print(f"Grid dimensions after scaling: grid_height={grid_height}, grid_width={grid_width}") + print(f"Final image dimensions: height={height}, width={width}") + + grid = torch.zeros((grid_height, grid_width, channels), dtype=images.dtype) + + for idx, image in enumerate(images): + resized_image = torch.nn.functional.interpolate(image.unsqueeze(0).permute(0, 3, 1, 2), size=(height, width), mode="bilinear").squeeze().permute(1, 2, 0) + row = idx // num_columns + col = idx % num_columns + grid[row*height:(row+1)*height, col*width:(col+1)*width, :] = resized_image + + return grid.unsqueeze(0), class ImageGridComposite2x2: @classmethod @@ -1778,4 +1872,4 @@ class LoadImagesFromFolderKJ: else: mask1 = torch.cat((mask1, mask2), dim=0) - return (image1, mask1, len(images), image_path_list) \ No newline at end of file + return (image1, mask1, len(images), image_path_list)