From b5419c853c738a11a86da80a52f41a3efb759daa Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 29 Sep 2024 17:17:18 +0300 Subject: [PATCH] Allow images with alpha to be concatenated, improve ImageGridtoBatch --- nodes/image_nodes.py | 60 ++++++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index de18ea8..00e8bbd 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -264,6 +264,21 @@ Concatenates the image2 to image1 in the specified direction. else: image2_resized = image2 + # Ensure both images have the same number of channels + channels_image1 = image1.shape[-1] + channels_image2 = image2_resized.shape[-1] + + if channels_image1 != channels_image2: + if channels_image1 < channels_image2: + # Add alpha channel to image1 if image2 has it + alpha_channel = torch.ones((*image1.shape[:-1], channels_image2 - channels_image1), device=image1.device) + image1 = torch.cat((image1, alpha_channel), dim=-1) + else: + # Add alpha channel to image2 if image1 has it + alpha_channel = torch.ones((*image2_resized.shape[:-1], channels_image1 - channels_image2), device=image2_resized.device) + image2_resized = torch.cat((image2_resized, alpha_channel), dim=-1) + + # Concatenate based on the specified direction if direction == 'right': concatenated_image = torch.cat((image1, image2_resized), dim=2) # Concatenate along width @@ -1915,8 +1930,9 @@ class ImageGridtoBatch: def INPUT_TYPES(s): return {"required": { "image": ("IMAGE", ), - "columns": ("INT", {"default": 3, "min": 2, "max": 8, "tooltip": "The number of columns in the grid."}), - } + "columns": ("INT", {"default": 3, "min": 1, "max": 8, "tooltip": "The number of columns in the grid."}), + "rows": ("INT", {"default": 0, "min": 1, "max": 8, "tooltip": "The number of rows in the grid. Set to 0 for automatic calculation."}), + } } RETURN_TYPES = ("IMAGE",) @@ -1924,20 +1940,32 @@ class ImageGridtoBatch: CATEGORY = "KJNodes/image" DESCRIPTION = "Converts a grid of images to a batch of images." - def decompose(self, image, columns): - B, H, W, C = image.shape - orig_h = H // columns - orig_w = W // columns - - # Reshape and permute the image to get the grid - image = image.view(B, columns, orig_h, columns, orig_w, C) - image = image.permute(0, 1, 3, 2, 4, 5).contiguous() - image = image.view(B, columns * columns, orig_h, orig_w, C) - - # Reshape to the final batch tensor - img_tensor = image.view(-1, orig_h, orig_w, C) - - return img_tensor, + def decompose(self, image, columns, rows): + B, H, W, C = image.shape + print("input size: ", image.shape) + + # Calculate cell width, rounding down + cell_width = W // columns + + if rows == 0: + # If rows is 0, calculate number of full rows + rows = H // cell_height + else: + # If rows is specified, adjust cell_height + cell_height = H // rows + + # Crop the image to fit full cells + image = image[:, :rows*cell_height, :columns*cell_width, :] + + # Reshape and permute the image to get the grid + image = image.view(B, rows, cell_height, columns, cell_width, C) + image = image.permute(0, 1, 3, 2, 4, 5).contiguous() + image = image.view(B, rows * columns, cell_height, cell_width, C) + + # Reshape to the final batch tensor + img_tensor = image.view(-1, cell_height, cell_width, C) + + return (img_tensor,) class SaveImageKJ: def __init__(self):