diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 1ccdbf5..a16a741 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -211,8 +211,8 @@ Concatenates the image2 to image1 in the specified direction. def concanate(self, image1, image2, direction, match_image_size, first_image_shape=None): # Check if the batch sizes are different - batch_size1 = image1.size(0) - batch_size2 = image2.size(0) + batch_size1 = image1.shape[0] + batch_size2 = image2.shape[0] if batch_size1 != batch_size2: # Calculate the number of repetitions needed @@ -223,21 +223,45 @@ Concatenates the image2 to image1 in the specified direction. # Repeat the images to match the largest batch size image1 = image1.repeat(repeats1, 1, 1, 1) image2 = image2.repeat(repeats2, 1, 1, 1) + if match_image_size: - first_image_shape = first_image_shape if first_image_shape is not None else image1.shape - image2_resized = image2.movedim(-1,1) - image2_resized = common_upscale(image2_resized, first_image_shape[2], first_image_shape[1], "lanczos", "disabled").movedim(1,-1) + # Use first_image_shape if provided; otherwise, default to image1's shape + target_shape = first_image_shape if first_image_shape is not None else image1.shape + + original_height = image2.shape[1] + original_width = image2.shape[2] + original_aspect_ratio = original_width / original_height + + if direction in ['left', 'right']: + # Match the height and adjust the width to preserve aspect ratio + target_height = target_shape[1] # B, H, W, C format + target_width = int(target_height * original_aspect_ratio) + elif direction in ['up', 'down']: + # Match the width and adjust the height to preserve aspect ratio + target_width = target_shape[2] # B, H, W, C format + target_height = int(target_width / original_aspect_ratio) + + # Adjust image2 to the expected format for common_upscale + image2_for_upscale = image2.movedim(-1, 1) # Move C to the second position (B, C, H, W) + + # Resize image2 to match the target size while preserving aspect ratio + image2_resized = common_upscale(image2_for_upscale, target_width, target_height, "lanczos", "disabled") + + # Adjust image2 back to the original format (B, H, W, C) after resizing + image2_resized = image2_resized.movedim(1, -1) else: image2_resized = image2 + + # Concatenate based on the specified direction if direction == 'right': - row = torch.cat((image1, image2_resized), dim=2) + concatenated_image = torch.cat((image1, image2_resized), dim=2) # Concatenate along width elif direction == 'down': - row = torch.cat((image1, image2_resized), dim=1) + concatenated_image = torch.cat((image1, image2_resized), dim=1) # Concatenate along height elif direction == 'left': - row = torch.cat((image2_resized, image1), dim=2) + concatenated_image = torch.cat((image2_resized, image1), dim=2) # Concatenate along width elif direction == 'up': - row = torch.cat((image2_resized, image1), dim=1) - return (row,) + concatenated_image = torch.cat((image2_resized, image1), dim=1) # Concatenate along height + return concatenated_image, class ImageGridComposite2x2: @classmethod