Add ImageBatchFilter

Node that removes "empty" frames from a batch, empty being a single color with threshold
This commit is contained in:
kijai 2025-04-20 14:17:44 +03:00
parent d41ad755ef
commit d399bc559d
2 changed files with 46 additions and 0 deletions

View File

@ -52,6 +52,7 @@ NODE_CONFIG = {
"GetLatentRangeFromBatch": {"class": GetLatentRangeFromBatch, "name": "Get Latent Range From Batch"},
"GetImageSizeAndCount": {"class": GetImageSizeAndCount, "name": "Get Image Size & Count"},
"FastPreview": {"class": FastPreview, "name": "Fast Preview"},
"ImageBatchFilter": {"class": ImageBatchFilter, "name": "Image Batch Filter"},
"ImageAndMaskPreview": {"class": ImageAndMaskPreview},
"ImageAddMulti": {"class": ImageAddMulti, "name": "Image Add Multi"},
"ImageBatchMulti": {"class": ImageBatchMulti, "name": "Image Batch Multi"},

View File

@ -1791,6 +1791,51 @@ Inserts a latent at the specified index into the original latent batch.
return ({"samples": joined_latents,},)
class ImageBatchFilter:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "filter"
CATEGORY = "KJNodes/image"
DESCRIPTION = "Removes empty images from a batch"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": ("IMAGE",),
"empty_color": ("STRING", {"default": "0, 0, 0"}),
"empty_threshold": ("FLOAT", {"default": 0.01,"min": 0.0, "max": 1.0, "step": 0.01}),
},
"optional": {
"replacement_image": ("IMAGE",),
}
}
def filter(self, images, empty_color, empty_threshold, replacement_image=None):
B, H, W, C = images.shape
input_images = images.clone()
empty_color_list = [int(color.strip()) for color in empty_color.split(',')]
empty_color_tensor = torch.tensor(empty_color_list, dtype=torch.float32).to(input_images.device)
color_diff = torch.abs(input_images - empty_color_tensor)
mean_diff = color_diff.mean(dim=(1, 2, 3))
empty_indices = mean_diff <= empty_threshold
if replacement_image is not None:
B_rep, H_rep, W_rep, C_rep = replacement_image.shape
replacement = replacement_image.clone()
if (H_rep != images.shape[1]) or (W_rep != images.shape[2]) or (C_rep != images.shape[3]):
replacement = common_upscale(replacement.movedim(-1, 1), W, H, "lanczos", "center").movedim(1, -1)
input_images[empty_indices] = replacement[0]
return (input_images,)
else:
non_empty_images = input_images[~empty_indices]
return (non_empty_images,)
class GetImagesFromBatchIndexed:
RETURN_TYPES = ("IMAGE",)