Add ImageConcatFromBatch

This commit is contained in:
kijai 2024-08-21 01:58:31 +03:00
parent 11c2155138
commit 87084633be
2 changed files with 97 additions and 2 deletions

View File

@ -51,6 +51,7 @@ NODE_CONFIG = {
"ImageBatchRepeatInterleaving": {"class": ImageBatchRepeatInterleaving},
"ImageBatchTestPattern": {"class": ImageBatchTestPattern, "name": "Image Batch Test Pattern"},
"ImageConcanate": {"class": ImageConcanate, "name": "Image Concatenate"},
"ImageConcatFromBatch": {"class": ImageConcatFromBatch, "name": "Image Concatenate From Batch"},
"ImageConcatMulti": {"class": ImageConcatMulti, "name": "Image Concatenate Multi"},
"ImageGrabPIL": {"class": ImageGrabPIL, "name": "Image Grab PIL"},
"ImageGridComposite2x2": {"class": ImageGridComposite2x2, "name": "Image Grid Composite 2x2"},

View File

@ -211,7 +211,7 @@ class ImageConcanate:
{
"default": 'right'
}),
"match_image_size": ("BOOLEAN", {"default": False}),
"match_image_size": ("BOOLEAN", {"default": True}),
}}
RETURN_TYPES = ("IMAGE",)
@ -274,6 +274,100 @@ Concatenates the image2 to image1 in the specified direction.
elif direction == 'up':
concatenated_image = torch.cat((image2_resized, image1), dim=1) # Concatenate along height
return concatenated_image,
import torch # Make sure you have PyTorch installed
class ImageConcatFromBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"images": ("IMAGE",),
"num_columns": ("INT", {"default": 3, "min": 1, "max": 255, "step": 1}),
"match_image_size": ("BOOLEAN", {"default": False}),
"max_resolution": ("INT", {"default": 4096}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "concat"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Concatenates images from a batch into a grid with a specified number of columns.
"""
def concat(self, images, num_columns, match_image_size, max_resolution):
# Assuming images is a batch of images (B, H, W, C)
batch_size, height, width, channels = images.shape
num_rows = (batch_size + num_columns - 1) // num_columns # Calculate number of rows
print(f"Initial dimensions: batch_size={batch_size}, height={height}, width={width}, channels={channels}")
print(f"num_rows={num_rows}, num_columns={num_columns}")
if match_image_size:
target_shape = images[0].shape
resized_images = []
for image in images:
original_height = image.shape[0]
original_width = image.shape[1]
original_aspect_ratio = original_width / original_height
if original_aspect_ratio > 1:
target_height = target_shape[0]
target_width = int(target_height * original_aspect_ratio)
else:
target_width = target_shape[1]
target_height = int(target_width / original_aspect_ratio)
print(f"Resizing image from ({original_height}, {original_width}) to ({target_height}, {target_width})")
# Resize the image to match the target size while preserving aspect ratio
resized_image = common_upscale(image.movedim(-1, 0), target_width, target_height, "lanczos", "disabled")
resized_image = resized_image.movedim(0, -1) # Move channels back to the last dimension
resized_images.append(resized_image)
# Convert the list of resized images back to a tensor
images = torch.stack(resized_images)
height, width = target_shape[:2] # Update height and width
# Initialize an empty grid
grid_height = num_rows * height
grid_width = num_columns * width
print(f"Grid dimensions before scaling: grid_height={grid_height}, grid_width={grid_width}")
# Original scale factor calculation remains unchanged
scale_factor = min(max_resolution / grid_height, max_resolution / grid_width, 1.0)
# Apply scale factor to height and width
scaled_height = height * scale_factor
scaled_width = width * scale_factor
# Round scaled dimensions to the nearest number divisible by 8
height = max(1, int(round(scaled_height / 8) * 8))
width = max(1, int(round(scaled_width / 8) * 8))
if abs(scaled_height - height) > 4:
height = max(1, int(round((scaled_height + 4) / 8) * 8))
if abs(scaled_width - width) > 4:
width = max(1, int(round((scaled_width + 4) / 8) * 8))
# Recalculate grid dimensions with adjusted height and width
grid_height = num_rows * height
grid_width = num_columns * width
print(f"Grid dimensions after scaling: grid_height={grid_height}, grid_width={grid_width}")
print(f"Final image dimensions: height={height}, width={width}")
grid = torch.zeros((grid_height, grid_width, channels), dtype=images.dtype)
for idx, image in enumerate(images):
resized_image = torch.nn.functional.interpolate(image.unsqueeze(0).permute(0, 3, 1, 2), size=(height, width), mode="bilinear").squeeze().permute(1, 2, 0)
row = idx // num_columns
col = idx % num_columns
grid[row*height:(row+1)*height, col*width:(col+1)*width, :] = resized_image
return grid.unsqueeze(0),
class ImageGridComposite2x2:
@classmethod
@ -1778,4 +1872,4 @@ class LoadImagesFromFolderKJ:
else:
mask1 = torch.cat((mask1, mask2), dim=0)
return (image1, mask1, len(images), image_path_list)
return (image1, mask1, len(images), image_path_list)