From b6193451b3afb61a80f045c6b5022faa91c4b7f1 Mon Sep 17 00:00:00 2001 From: Kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 31 May 2024 14:09:36 +0300 Subject: [PATCH] Update image_nodes.py --- nodes/image_nodes.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 01ffbe2..f25e796 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -210,6 +210,19 @@ Concatenates the image2 to image1 in the specified direction. """ def concanate(self, image1, image2, direction, match_image_size): + # Check if the batch sizes are different + batch_size1 = image1.size(0) + batch_size2 = image2.size(0) + + if batch_size1 != batch_size2: + # Calculate the number of repetitions needed + max_batch_size = max(batch_size1, batch_size2) + repeats1 = max_batch_size // batch_size1 + repeats2 = max_batch_size // batch_size2 + + # Repeat the images to match the largest batch size + image1 = image1.repeat(repeats1, 1, 1, 1) + image2 = image2.repeat(repeats2, 1, 1, 1) if match_image_size: image2 = torch.nn.functional.interpolate(image2, size=(image1.shape[2], image1.shape[3]), mode="bilinear") if direction == 'right':