From d399bc559d78b19fcde88de91b646931dc53d97e Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 20 Apr 2025 14:17:44 +0300 Subject: [PATCH] Add ImageBatchFilter Node that removes "empty" frames from a batch, empty being a single color with threshold --- __init__.py | 1 + nodes/image_nodes.py | 45 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/__init__.py b/__init__.py index 9dd5ba8..92fd497 100644 --- a/__init__.py +++ b/__init__.py @@ -52,6 +52,7 @@ NODE_CONFIG = { "GetLatentRangeFromBatch": {"class": GetLatentRangeFromBatch, "name": "Get Latent Range From Batch"}, "GetImageSizeAndCount": {"class": GetImageSizeAndCount, "name": "Get Image Size & Count"}, "FastPreview": {"class": FastPreview, "name": "Fast Preview"}, + "ImageBatchFilter": {"class": ImageBatchFilter, "name": "Image Batch Filter"}, "ImageAndMaskPreview": {"class": ImageAndMaskPreview}, "ImageAddMulti": {"class": ImageAddMulti, "name": "Image Add Multi"}, "ImageBatchMulti": {"class": ImageBatchMulti, "name": "Image Batch Multi"}, diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index 7b987c7..1350fc4 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -1790,6 +1790,51 @@ Inserts a latent at the specified index into the original latent batch. ], dim=2) return ({"samples": joined_latents,},) + +class ImageBatchFilter: + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "filter" + CATEGORY = "KJNodes/image" + DESCRIPTION = "Removes empty images from a batch" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ("IMAGE",), + "empty_color": ("STRING", {"default": "0, 0, 0"}), + "empty_threshold": ("FLOAT", {"default": 0.01,"min": 0.0, "max": 1.0, "step": 0.01}), + }, + "optional": { + "replacement_image": ("IMAGE",), + } + } + + def filter(self, images, empty_color, empty_threshold, replacement_image=None): + B, H, W, C = images.shape + + input_images = images.clone() + + empty_color_list = [int(color.strip()) for color in empty_color.split(',')] + empty_color_tensor = torch.tensor(empty_color_list, dtype=torch.float32).to(input_images.device) + + color_diff = torch.abs(input_images - empty_color_tensor) + mean_diff = color_diff.mean(dim=(1, 2, 3)) + + empty_indices = mean_diff <= empty_threshold + + if replacement_image is not None: + B_rep, H_rep, W_rep, C_rep = replacement_image.shape + replacement = replacement_image.clone() + if (H_rep != images.shape[1]) or (W_rep != images.shape[2]) or (C_rep != images.shape[3]): + replacement = common_upscale(replacement.movedim(-1, 1), W, H, "lanczos", "center").movedim(1, -1) + input_images[empty_indices] = replacement[0] + + return (input_images,) + else: + non_empty_images = input_images[~empty_indices] + return (non_empty_images,) class GetImagesFromBatchIndexed: