mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-04-05 05:46:58 +08:00
Add ImageConcatFromBatch
This commit is contained in:
parent 11c2155138
commit 87084633be
@@ -51,6 +51,7 @@ NODE_CONFIG = {
     "ImageBatchRepeatInterleaving": {"class": ImageBatchRepeatInterleaving},
     "ImageBatchTestPattern": {"class": ImageBatchTestPattern, "name": "Image Batch Test Pattern"},
     "ImageConcanate": {"class": ImageConcanate, "name": "Image Concatenate"},
+    "ImageConcatFromBatch": {"class": ImageConcatFromBatch, "name": "Image Concatenate From Batch"},
     "ImageConcatMulti": {"class": ImageConcatMulti, "name": "Image Concatenate Multi"},
     "ImageGrabPIL": {"class": ImageGrabPIL, "name": "Image Grab PIL"},
     "ImageGridComposite2x2": {"class": ImageGridComposite2x2, "name": "Image Grid Composite 2x2"},
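
Note: ComfyUI discovers custom nodes through module-level NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS dictionaries, so the NODE_CONFIG table above is presumably expanded into those two dicts elsewhere in the pack. A minimal sketch of that expansion, with generate_node_mappings as a hypothetical helper name (not shown in this diff):

def generate_node_mappings(node_config):
    # Hypothetical helper: split the single NODE_CONFIG table into the two
    # dictionaries ComfyUI actually imports from a custom node package.
    node_class_mappings = {}
    node_display_name_mappings = {}
    for node_name, node_info in node_config.items():
        node_class_mappings[node_name] = node_info["class"]
        # Fall back to the registry key when no display name is given,
        # as with "ImageBatchRepeatInterleaving" above.
        node_display_name_mappings[node_name] = node_info.get("name", node_name)
    return node_class_mappings, node_display_name_mappings

NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)
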
@@ -211,7 +211,7 @@ class ImageConcanate:
             {
             "default": 'right'
             }),
-            "match_image_size": ("BOOLEAN", {"default": False}),
+            "match_image_size": ("BOOLEAN", {"default": True}),
         }}
 
     RETURN_TYPES = ("IMAGE",)
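
Note: the only change to the existing Image Concatenate node is the match_image_size widget default flipping from False to True. A quick way to confirm the default a node advertises is to read INPUT_TYPES directly; a minimal sketch, assuming the pack's nodes.py is importable (module path is an assumption):

from nodes import ImageConcanate  # assumption: the node classes live in the pack's nodes.py

spec = ImageConcanate.INPUT_TYPES()
widget_type, widget_opts = spec["required"]["match_image_size"]
print(widget_type, widget_opts)  # BOOLEAN {'default': True} after this commit
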
@@ -274,6 +274,100 @@ Concatenates the image2 to image1 in the specified direction.
         elif direction == 'up':
             concatenated_image = torch.cat((image2_resized, image1), dim=1) # Concatenate along height
         return concatenated_image,
+
+import torch # Make sure you have PyTorch installed
+
+class ImageConcatFromBatch:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "images": ("IMAGE",),
+            "num_columns": ("INT", {"default": 3, "min": 1, "max": 255, "step": 1}),
+            "match_image_size": ("BOOLEAN", {"default": False}),
+            "max_resolution": ("INT", {"default": 4096}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "concat"
+    CATEGORY = "KJNodes/image"
+    DESCRIPTION = """
+Concatenates images from a batch into a grid with a specified number of columns.
+"""
+
+    def concat(self, images, num_columns, match_image_size, max_resolution):
+        # Assuming images is a batch of images (B, H, W, C)
+        batch_size, height, width, channels = images.shape
+        num_rows = (batch_size + num_columns - 1) // num_columns # Calculate number of rows
+
+        print(f"Initial dimensions: batch_size={batch_size}, height={height}, width={width}, channels={channels}")
+        print(f"num_rows={num_rows}, num_columns={num_columns}")
+
+        if match_image_size:
+            target_shape = images[0].shape
+
+            resized_images = []
+            for image in images:
+                original_height = image.shape[0]
+                original_width = image.shape[1]
+                original_aspect_ratio = original_width / original_height
+
+                if original_aspect_ratio > 1:
+                    target_height = target_shape[0]
+                    target_width = int(target_height * original_aspect_ratio)
+                else:
+                    target_width = target_shape[1]
+                    target_height = int(target_width / original_aspect_ratio)
+
+                print(f"Resizing image from ({original_height}, {original_width}) to ({target_height}, {target_width})")
+
+                # Resize the image to match the target size while preserving aspect ratio
+                resized_image = common_upscale(image.movedim(-1, 0), target_width, target_height, "lanczos", "disabled")
+                resized_image = resized_image.movedim(0, -1) # Move channels back to the last dimension
+                resized_images.append(resized_image)
+
+            # Convert the list of resized images back to a tensor
+            images = torch.stack(resized_images)
+
+            height, width = target_shape[:2] # Update height and width
+
+        # Initialize an empty grid
+        grid_height = num_rows * height
+        grid_width = num_columns * width
+
+        print(f"Grid dimensions before scaling: grid_height={grid_height}, grid_width={grid_width}")
+
+        # Original scale factor calculation remains unchanged
+        scale_factor = min(max_resolution / grid_height, max_resolution / grid_width, 1.0)
+
+        # Apply scale factor to height and width
+        scaled_height = height * scale_factor
+        scaled_width = width * scale_factor
+
+        # Round scaled dimensions to the nearest number divisible by 8
+        height = max(1, int(round(scaled_height / 8) * 8))
+        width = max(1, int(round(scaled_width / 8) * 8))
+
+        if abs(scaled_height - height) > 4:
+            height = max(1, int(round((scaled_height + 4) / 8) * 8))
+        if abs(scaled_width - width) > 4:
+            width = max(1, int(round((scaled_width + 4) / 8) * 8))
+
+        # Recalculate grid dimensions with adjusted height and width
+        grid_height = num_rows * height
+        grid_width = num_columns * width
+        print(f"Grid dimensions after scaling: grid_height={grid_height}, grid_width={grid_width}")
+        print(f"Final image dimensions: height={height}, width={width}")
+
+        grid = torch.zeros((grid_height, grid_width, channels), dtype=images.dtype)
+
+        for idx, image in enumerate(images):
+            resized_image = torch.nn.functional.interpolate(image.unsqueeze(0).permute(0, 3, 1, 2), size=(height, width), mode="bilinear").squeeze().permute(1, 2, 0)
+            row = idx // num_columns
+            col = idx % num_columns
+            grid[row*height:(row+1)*height, col*width:(col+1)*width, :] = resized_image
+
+        return grid.unsqueeze(0),
 
 class ImageGridComposite2x2:
     @classmethod
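
Note: a quick way to sanity-check the new node outside of a workflow is to call concat directly on a dummy batch. With match_image_size left False the common_upscale path is skipped, so only torch is needed. A minimal sketch (the tensor size and the import path are illustrative assumptions):

import torch
from nodes import ImageConcatFromBatch  # assumption: run with the pack's nodes.py on the path

node = ImageConcatFromBatch()
batch = torch.rand(5, 256, 320, 3)  # (B, H, W, C), the layout ComfyUI uses for IMAGE tensors
grid, = node.concat(batch, num_columns=3, match_image_size=False, max_resolution=4096)

# 5 images in 3 columns -> ceil(5/3) = 2 rows. 256 and 320 are already multiples of 8
# and the 512 x 960 grid fits under max_resolution, so no rescaling happens here:
print(grid.shape)  # torch.Size([1, 512, 960, 3])

Snapping each cell to a multiple of 8 keeps the final grid friendly to the 8x downsampling factor of Stable Diffusion VAEs, which is presumably why the scaled height and width are rounded before the cells are pasted in; cells left empty by a partial last row simply stay black because the grid is pre-filled with zeros.
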
@@ -1778,4 +1872,4 @@ class LoadImagesFromFolderKJ:
             else:
                 mask1 = torch.cat((mask1, mask2), dim=0)
 
-            return (image1, mask1, len(images), image_path_list)
+        return (image1, mask1, len(images), image_path_list)