diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index f2ea3ec..7024ac5 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -2182,8 +2182,8 @@ class LoadImagesFromFolderKJ: RETURN_TYPES = ("IMAGE", "MASK", "INT", "STRING",) RETURN_NAMES = ("image", "mask", "count", "image_path",) FUNCTION = "load_images" - - CATEGORY = "image" + CATEGORY = "KJNodes/image" + DESCRIPTION = """Loads images from a folder into a batch, images are resized to match the largest image in the folder.""" def load_images(self, folder, image_load_cap, start_index): if not os.path.isdir(folder): @@ -2202,6 +2202,17 @@ class LoadImagesFromFolderKJ: # start at start_index dir_files = dir_files[start_index:] + # First pass - find maximum dimensions + max_width = 0 + max_height = 0 + for image_path in dir_files: + if os.path.isdir(image_path): + continue + with Image.open(image_path) as img: + width, height = img.size + max_width = max(max_width, width) + max_height = max(max_height, height) + images = [] masks = [] image_path_list = [] @@ -2211,24 +2222,33 @@ class LoadImagesFromFolderKJ: limit_images = True image_count = 0 - has_non_empty_mask = False - for image_path in dir_files: - if os.path.isdir(image_path) and os.path.ex: + if os.path.isdir(image_path): continue if limit_images and image_count >= image_load_cap: break i = Image.open(image_path) i = ImageOps.exif_transpose(i) + + # Resize image to maximum dimensions + if i.size != (max_width, max_height): + i = i.resize((max_width, max_height), Image.Resampling.LANCZOS) + image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) - has_non_empty_mask = True + if mask.shape != (max_height, max_width): + mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), + size=(max_height, max_width), + mode='bilinear', + align_corners=False).squeeze() else: - mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") + mask = torch.zeros((max_height, max_width), dtype=torch.float32, device="cpu") + images.append(image) masks.append(mask) image_path_list.append(image_path) @@ -2236,30 +2256,16 @@ class LoadImagesFromFolderKJ: if len(images) == 1: return (images[0], masks[0], 1, image_path_list) - + elif len(images) > 1: image1 = images[0] - mask1 = None + mask1 = masks[0].unsqueeze(0) for image2 in images[1:]: - if image1.shape[1:] != image2.shape[1:]: - image2 = common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1, -1) image1 = torch.cat((image1, image2), dim=0) for mask2 in masks[1:]: - if has_non_empty_mask: - if image1.shape[1:3] != mask2.shape: - mask2 = torch.nn.functional.interpolate(mask2.unsqueeze(0).unsqueeze(0), size=(image1.shape[2], image1.shape[1]), mode='bilinear', align_corners=False) - mask2 = mask2.squeeze(0) - else: - mask2 = mask2.unsqueeze(0) - else: - mask2 = mask2.unsqueeze(0) - - if mask1 is None: - mask1 = mask2 - else: - mask1 = torch.cat((mask1, mask2), dim=0) + mask1 = torch.cat((mask1, mask2.unsqueeze(0)), dim=0) return (image1, mask1, len(images), image_path_list) diff --git a/nodes/nodes.py b/nodes/nodes.py index 4087aae..0eab9e9 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -166,10 +166,11 @@ Selects and returns the latents at the specified indices as an latent batch. "required": { "latents": ("LATENT",), "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}), + "latent_format": (["BCHW", "BTCHW", "BCTHW"], {"default": "BCHW"}), }, } - def indexedlatentsfrombatch(self, latents, indexes): + def indexedlatentsfrombatch(self, latents, indexes, latent_format): samples = latents.copy() latent_samples = samples["samples"] @@ -181,7 +182,12 @@ Selects and returns the latents at the specified indices as an latent batch. indices_tensor = torch.tensor(index_list, dtype=torch.long) # Select the latents at the specified indices - chosen_latents = latent_samples[indices_tensor] + if latent_format == "BCHW": + chosen_latents = latent_samples[indices_tensor] + elif latent_format == "BTCHW": + chosen_latents = latent_samples[:, indices_tensor] + elif latent_format == "BCTHW": + chosen_latents = latent_samples[:, :, indices_tensor] samples["samples"] = chosen_latents return (samples,) @@ -2216,4 +2222,5 @@ Concatenates the audio1 to audio2 in the specified direction. concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2) # Concatenate along width elif direction == 'left': concatenated_audio= torch.cat((waveform_2, waveform_1), dim=2) # Concatenate along width - return ({"waveform": concatenated_audio, "sample_rate": sample_rate_1},) \ No newline at end of file + return ({"waveform": concatenated_audio, "sample_rate": sample_rate_1},) + diff --git a/web/js/contextmenu.js b/web/js/contextmenu.js index 6e39c01..3b1670e 100644 --- a/web/js/contextmenu.js +++ b/web/js/contextmenu.js @@ -95,7 +95,6 @@ app.registerExtension({ app.ui.settings.addSetting({ id: "KJNodes.nodeAutoColor", name: "KJNodes: Automatically set node colors", - defaultValue: true, type: "boolean", defaultValue: true, });