diff --git a/nodes.py b/nodes.py index 80dd68f..9b95279 100644 --- a/nodes.py +++ b/nodes.py @@ -9,7 +9,7 @@ import scipy.ndimage from scipy.spatial import Voronoi import matplotlib.pyplot as plt import numpy as np -from PIL import ImageFilter, Image, ImageDraw, ImageFont +from PIL import ImageFilter, Image, ImageDraw, ImageFont, ImageOps from PIL.PngImagePlugin import PngInfo import json import re @@ -773,6 +773,9 @@ class GrowMaskWithBlur: "lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, + "optional": { + "fill_holes": ("BOOLEAN", {"default": False}), + }, } CATEGORY = "KJNodes/masking" @@ -781,7 +784,7 @@ class GrowMaskWithBlur: RETURN_NAMES = ("mask", "mask_inverted",) FUNCTION = "expand_mask" - def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor): + def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False): alpha = lerp_alpha decay = decay_factor if flip_input: @@ -805,6 +808,10 @@ class GrowMaskWithBlur: current_expand -= abs(incremental_expandrate) else: current_expand += abs(incremental_expandrate) + if fill_holes: + binary_mask = output > 0 + output = scipy.ndimage.binary_fill_holes(binary_mask) + output = output.astype(np.uint8) * 255 output = torch.from_numpy(output) if alpha < 1.0 and previous_output is not None: # Interpolate between the previous and current frame