Allow images with alpha to be concatenated, improve ImageGridtoBatch

This commit is contained in:
kijai 2024-09-29 17:17:18 +03:00
parent c31fa9f438
commit b5419c853c

View File

@ -264,6 +264,21 @@ Concatenates the image2 to image1 in the specified direction.
else:
image2_resized = image2
# Ensure both images have the same number of channels
channels_image1 = image1.shape[-1]
channels_image2 = image2_resized.shape[-1]
if channels_image1 != channels_image2:
if channels_image1 < channels_image2:
# Add alpha channel to image1 if image2 has it
alpha_channel = torch.ones((*image1.shape[:-1], channels_image2 - channels_image1), device=image1.device)
image1 = torch.cat((image1, alpha_channel), dim=-1)
else:
# Add alpha channel to image2 if image1 has it
alpha_channel = torch.ones((*image2_resized.shape[:-1], channels_image1 - channels_image2), device=image2_resized.device)
image2_resized = torch.cat((image2_resized, alpha_channel), dim=-1)
# Concatenate based on the specified direction
if direction == 'right':
concatenated_image = torch.cat((image1, image2_resized), dim=2) # Concatenate along width
@ -1915,8 +1930,9 @@ class ImageGridtoBatch:
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE", ),
"columns": ("INT", {"default": 3, "min": 2, "max": 8, "tooltip": "The number of columns in the grid."}),
}
"columns": ("INT", {"default": 3, "min": 1, "max": 8, "tooltip": "The number of columns in the grid."}),
"rows": ("INT", {"default": 0, "min": 1, "max": 8, "tooltip": "The number of rows in the grid. Set to 0 for automatic calculation."}),
}
}
RETURN_TYPES = ("IMAGE",)
@ -1924,20 +1940,32 @@ class ImageGridtoBatch:
CATEGORY = "KJNodes/image"
DESCRIPTION = "Converts a grid of images to a batch of images."
def decompose(self, image, columns):
B, H, W, C = image.shape
orig_h = H // columns
orig_w = W // columns
# Reshape and permute the image to get the grid
image = image.view(B, columns, orig_h, columns, orig_w, C)
image = image.permute(0, 1, 3, 2, 4, 5).contiguous()
image = image.view(B, columns * columns, orig_h, orig_w, C)
# Reshape to the final batch tensor
img_tensor = image.view(-1, orig_h, orig_w, C)
return img_tensor,
def decompose(self, image, columns, rows):
B, H, W, C = image.shape
print("input size: ", image.shape)
# Calculate cell width, rounding down
cell_width = W // columns
if rows == 0:
# If rows is 0, calculate number of full rows
rows = H // cell_height
else:
# If rows is specified, adjust cell_height
cell_height = H // rows
# Crop the image to fit full cells
image = image[:, :rows*cell_height, :columns*cell_width, :]
# Reshape and permute the image to get the grid
image = image.view(B, rows, cell_height, columns, cell_width, C)
image = image.permute(0, 1, 3, 2, 4, 5).contiguous()
image = image.view(B, rows * columns, cell_height, cell_width, C)
# Reshape to the final batch tensor
img_tensor = image.view(-1, cell_height, cell_width, C)
return (img_tensor,)
class SaveImageKJ:
def __init__(self):