Better image concat size matching

This commit is contained in:
kijai 2024-07-07 15:10:02 +03:00
parent 3d7577f316
commit 296a1beb66

View File

@ -211,8 +211,8 @@ Concatenates the image2 to image1 in the specified direction.
def concanate(self, image1, image2, direction, match_image_size, first_image_shape=None):
# Check if the batch sizes are different
batch_size1 = image1.size(0)
batch_size2 = image2.size(0)
batch_size1 = image1.shape[0]
batch_size2 = image2.shape[0]
if batch_size1 != batch_size2:
# Calculate the number of repetitions needed
@ -223,21 +223,45 @@ Concatenates the image2 to image1 in the specified direction.
# Repeat the images to match the largest batch size
image1 = image1.repeat(repeats1, 1, 1, 1)
image2 = image2.repeat(repeats2, 1, 1, 1)
if match_image_size:
first_image_shape = first_image_shape if first_image_shape is not None else image1.shape
image2_resized = image2.movedim(-1,1)
image2_resized = common_upscale(image2_resized, first_image_shape[2], first_image_shape[1], "lanczos", "disabled").movedim(1,-1)
# Use first_image_shape if provided; otherwise, default to image1's shape
target_shape = first_image_shape if first_image_shape is not None else image1.shape
original_height = image2.shape[1]
original_width = image2.shape[2]
original_aspect_ratio = original_width / original_height
if direction in ['left', 'right']:
# Match the height and adjust the width to preserve aspect ratio
target_height = target_shape[1] # B, H, W, C format
target_width = int(target_height * original_aspect_ratio)
elif direction in ['up', 'down']:
# Match the width and adjust the height to preserve aspect ratio
target_width = target_shape[2] # B, H, W, C format
target_height = int(target_width / original_aspect_ratio)
# Adjust image2 to the expected format for common_upscale
image2_for_upscale = image2.movedim(-1, 1) # Move C to the second position (B, C, H, W)
# Resize image2 to match the target size while preserving aspect ratio
image2_resized = common_upscale(image2_for_upscale, target_width, target_height, "lanczos", "disabled")
# Adjust image2 back to the original format (B, H, W, C) after resizing
image2_resized = image2_resized.movedim(1, -1)
else:
image2_resized = image2
# Concatenate based on the specified direction
if direction == 'right':
row = torch.cat((image1, image2_resized), dim=2)
concatenated_image = torch.cat((image1, image2_resized), dim=2) # Concatenate along width
elif direction == 'down':
row = torch.cat((image1, image2_resized), dim=1)
concatenated_image = torch.cat((image1, image2_resized), dim=1) # Concatenate along height
elif direction == 'left':
row = torch.cat((image2_resized, image1), dim=2)
concatenated_image = torch.cat((image2_resized, image1), dim=2) # Concatenate along width
elif direction == 'up':
row = torch.cat((image2_resized, image1), dim=1)
return (row,)
concatenated_image = torch.cat((image2_resized, image1), dim=1) # Concatenate along height
return concatenated_image,
class ImageGridComposite2x2:
@classmethod