diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 7024ac5..06d546d 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -2172,10 +2172,14 @@ class LoadImagesFromFolderKJ: return { "required": { "folder": ("STRING", {"default": ""}), + "width": ("INT", {"default": 1024, "min": 64, "step": 1}), + "height": ("INT", {"default": 1024, "min": 64, "step": 1}), + "keep_aspect_ratio": (["crop", "pad", "stretch",],), }, "optional": { "image_load_cap": ("INT", {"default": 0, "min": 0, "step": 1}), "start_index": ("INT", {"default": 0, "min": 0, "step": 1}), + "include_subfolders": ("BOOLEAN", {"default": False}), } } @@ -2183,36 +2187,32 @@ class LoadImagesFromFolderKJ: RETURN_NAMES = ("image", "mask", "count", "image_path",) FUNCTION = "load_images" CATEGORY = "KJNodes/image" - DESCRIPTION = """Loads images from a folder into a batch, images are resized to match the largest image in the folder.""" + DESCRIPTION = """Loads images from a folder into a batch, images are resized and loaded into a batch.""" - def load_images(self, folder, image_load_cap, start_index): + def load_images(self, folder, width, height, image_load_cap, start_index, keep_aspect_ratio, include_subfolders=False): if not os.path.isdir(folder): raise FileNotFoundError(f"Folder '{folder} cannot be found.'") - dir_files = os.listdir(folder) + + valid_extensions = ['.jpg', '.jpeg', '.png', '.webp'] + image_paths = [] + if include_subfolders: + for root, _, files in os.walk(folder): + for file in files: + if any(file.lower().endswith(ext) for ext in valid_extensions): + image_paths.append(os.path.join(root, file)) + else: + for file in os.listdir(folder): + if any(file.lower().endswith(ext) for ext in valid_extensions): + image_paths.append(os.path.join(folder, file)) + + dir_files = sorted(image_paths) + if len(dir_files) == 0: raise FileNotFoundError(f"No files in directory '{folder}'.") - # Filter files by extension - valid_extensions = ['.jpg', '.jpeg', '.png', '.webp'] - dir_files = [f for f in dir_files if any(f.lower().endswith(ext) for ext in valid_extensions)] - - dir_files = sorted(dir_files) - dir_files = [os.path.join(folder, x) for x in dir_files] - # start at start_index dir_files = dir_files[start_index:] - # First pass - find maximum dimensions - max_width = 0 - max_height = 0 - for image_path in dir_files: - if os.path.isdir(image_path): - continue - with Image.open(image_path) as img: - width, height = img.size - max_width = max(max_width, width) - max_height = max(max_height, height) - images = [] masks = [] image_path_list = [] @@ -2231,8 +2231,9 @@ class LoadImagesFromFolderKJ: i = ImageOps.exif_transpose(i) # Resize image to maximum dimensions - if i.size != (max_width, max_height): - i = i.resize((max_width, max_height), Image.Resampling.LANCZOS) + if i.size != (width, height): + i = self.resize_with_aspect_ratio(i, width, height, keep_aspect_ratio) + image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -2241,13 +2242,13 @@ class LoadImagesFromFolderKJ: if 'A' in i.getbands(): mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) - if mask.shape != (max_height, max_width): + if mask.shape != (height, width): mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), - size=(max_height, max_width), + size=(height, width), mode='bilinear', align_corners=False).squeeze() else: - mask = torch.zeros((max_height, max_width), dtype=torch.float32, device="cpu") + mask = torch.zeros((height, width), dtype=torch.float32, device="cpu") images.append(image) masks.append(mask) @@ -2268,6 +2269,72 @@ class LoadImagesFromFolderKJ: mask1 = torch.cat((mask1, mask2.unsqueeze(0)), dim=0) return (image1, mask1, len(images), image_path_list) + def resize_with_aspect_ratio(self, img, width, height, mode): + if mode == "stretch": + return img.resize((width, height), Image.Resampling.LANCZOS) + + img_width, img_height = img.size + aspect_ratio = img_width / img_height + target_ratio = width / height + + if mode == "crop": + # Calculate dimensions for center crop + if aspect_ratio > target_ratio: + # Image is wider - crop width + new_width = int(height * aspect_ratio) + img = img.resize((new_width, height), Image.Resampling.LANCZOS) + left = (new_width - width) // 2 + return img.crop((left, 0, left + width, height)) + else: + # Image is taller - crop height + new_height = int(width / aspect_ratio) + img = img.resize((width, new_height), Image.Resampling.LANCZOS) + top = (new_height - height) // 2 + return img.crop((0, top, width, top + height)) + + elif mode == "pad": + pad_color = self.get_edge_color(img) + # Calculate dimensions for padding + if aspect_ratio > target_ratio: + # Image is wider - pad height + new_height = int(width / aspect_ratio) + img = img.resize((width, new_height), Image.Resampling.LANCZOS) + padding = (height - new_height) // 2 + padded = Image.new('RGBA', (width, height), pad_color) + padded.paste(img, (0, padding)) + return padded + else: + # Image is taller - pad width + new_width = int(height * aspect_ratio) + img = img.resize((new_width, height), Image.Resampling.LANCZOS) + padding = (width - new_width) // 2 + padded = Image.new('RGBA', (width, height), pad_color) + padded.paste(img, (padding, 0)) + return padded + def get_edge_color(self, img): + from PIL import ImageStat + """Sample edges and return dominant color""" + width, height = img.size + img = img.convert('RGBA') + + # Create 1-pixel high/wide images from edges + top = img.crop((0, 0, width, 1)) + bottom = img.crop((0, height-1, width, height)) + left = img.crop((0, 0, 1, height)) + right = img.crop((width-1, 0, width, height)) + + # Combine edges into single image + edges = Image.new('RGBA', (width*2 + height*2, 1)) + edges.paste(top, (0, 0)) + edges.paste(bottom, (width, 0)) + edges.paste(left.resize((height, 1)), (width*2, 0)) + edges.paste(right.resize((height, 1)), (width*2 + height, 0)) + + # Get median color + stat = ImageStat.Stat(edges) + median = tuple(map(int, stat.median)) + return median + class ImageGridtoBatch: @classmethod