Update image_nodes.py

This commit is contained in:
Kijai 2024-05-31 14:09:36 +03:00
parent 2c7e8613e0
commit b6193451b3

View File

@ -210,6 +210,19 @@ Concatenates the image2 to image1 in the specified direction.
"""
def concanate(self, image1, image2, direction, match_image_size):
# Check if the batch sizes are different
batch_size1 = image1.size(0)
batch_size2 = image2.size(0)
if batch_size1 != batch_size2:
# Calculate the number of repetitions needed
max_batch_size = max(batch_size1, batch_size2)
repeats1 = max_batch_size // batch_size1
repeats2 = max_batch_size // batch_size2
# Repeat the images to match the largest batch size
image1 = image1.repeat(repeats1, 1, 1, 1)
image2 = image2.repeat(repeats2, 1, 1, 1)
if match_image_size:
image2 = torch.nn.functional.interpolate(image2, size=(image1.shape[2], image1.shape[3]), mode="bilinear")
if direction == 'right':