diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 08e12ef..1e45b7e 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -217,13 +217,13 @@ class ImageConcanate: }} RETURN_TYPES = ("IMAGE",) - FUNCTION = "concanate" + FUNCTION = "concatenate" CATEGORY = "KJNodes/image" DESCRIPTION = """ Concatenates the image2 to image1 in the specified direction. """ - def concanate(self, image1, image2, direction, match_image_size, first_image_shape=None): + def concatenate(self, image1, image2, direction, match_image_size, first_image_shape=None): # Check if the batch sizes are different batch_size1 = image1.shape[0] batch_size2 = image2.shape[0] @@ -231,12 +231,16 @@ Concatenates the image2 to image1 in the specified direction. if batch_size1 != batch_size2: # Calculate the number of repetitions needed max_batch_size = max(batch_size1, batch_size2) - repeats1 = max_batch_size // batch_size1 - repeats2 = max_batch_size // batch_size2 + repeats1 = max_batch_size - batch_size1 + repeats2 = max_batch_size - batch_size2 - # Repeat the images to match the largest batch size - image1 = image1.repeat(repeats1, 1, 1, 1) - image2 = image2.repeat(repeats2, 1, 1, 1) + # Repeat the last image to match the largest batch size + if repeats1 > 0: + last_image1 = image1[-1].unsqueeze(0).repeat(repeats1, 1, 1, 1) + image1 = torch.cat([image1, last_image1], dim=0) + if repeats2 > 0: + last_image2 = image2[-1].unsqueeze(0).repeat(repeats2, 1, 1, 1) + image2 = torch.cat([image2, last_image2], dim=0) if match_image_size: # Use first_image_shape if provided; otherwise, default to image1's shape @@ -1845,7 +1849,7 @@ with the **inputcount** and clicking update. first_image_shape = image.shape for c in range(1, inputcount): new_image = kwargs[f"image_{c + 1}"] - image, = ImageConcanate.concanate(self, image, new_image, direction, match_image_size, first_image_shape=first_image_shape) + image, = ImageConcanate.concatenate(self, image, new_image, direction, match_image_size, first_image_shape=first_image_shape) first_image_shape = None return (image,) diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 43fb7de..410268e 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -475,6 +475,7 @@ class TorchCompileCosmosModel: "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), + "dynamo_cache_size_limit": ("INT", {"default": 64, "tooltip": "Set the dynamo cache size limit"}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" @@ -482,10 +483,11 @@ class TorchCompileCosmosModel: CATEGORY = "KJNodes/experimental" EXPERIMENTAL = True - def patch(self, model, backend, mode, fullgraph, dynamic): + def patch(self, model, backend, mode, fullgraph, dynamic, dynamo_cache_size_limit): m = model.clone() diffusion_model = m.get_model_object("diffusion_model") + torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit if not self._compiled: try: