diff --git a/nodes.py b/nodes.py index db3c10e..cfce1d8 100644 --- a/nodes.py +++ b/nodes.py @@ -726,7 +726,7 @@ class GrowMaskWithBlur: "required": { "mask": ("MASK",), "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), - "incremental_expandrate": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1}), + "incremental_expandrate": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}), "tapered_corners": ("BOOLEAN", {"default": True}), "flip_input": ("BOOLEAN", {"default": False}), "blur_radius": ("FLOAT", { @@ -739,7 +739,7 @@ class GrowMaskWithBlur: "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, } - + CATEGORY = "KJNodes/masking" RETURN_TYPES = ("MASK", "MASK",) @@ -749,7 +749,7 @@ class GrowMaskWithBlur: def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor): alpha = lerp_alpha decay = decay_factor - if( flip_input ): + if flip_input: mask = 1.0 - mask c = 0 if tapered_corners else 1 kernel = np.array([[c, 1, c], @@ -758,44 +758,37 @@ class GrowMaskWithBlur: growmask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) out = [] previous_output = None + current_expand = expand for m in growmask: output = m.numpy() - for _ in range(abs(expand)): - if expand < 0: + for _ in range(abs(round(current_expand))): + if current_expand < 0: output = scipy.ndimage.grey_erosion(output, footprint=kernel) else: output = scipy.ndimage.grey_dilation(output, footprint=kernel) - if expand < 0: - expand -= abs(incremental_expandrate) # Use abs(growrate) to ensure positive change + if current_expand < 0: + current_expand -= abs(incremental_expandrate) else: - expand += abs(incremental_expandrate) # Use abs(growrate) to ensure positive change + current_expand += abs(incremental_expandrate) output = torch.from_numpy(output) if alpha < 1.0 and previous_output is not None: - # Interpolate between the previous and current frame output = alpha * output + (1 - alpha) * previous_output if decay < 1.0 and previous_output is not None: - # Add the decayed previous output to the current frame output += decay * previous_output output = output / output.max() previous_output = output out.append(output) if blur_radius != 0: - # Convert the tensor list to PIL images, apply blur, and convert back for idx, tensor in enumerate(out): - # Convert tensor to PIL image pil_image = tensor2pil(tensor.cpu().detach())[0] - # Apply Gaussian blur pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius)) - # Convert back to tensor out[idx] = pil2tensor(pil_image) blurred = torch.cat(out, dim=0) return (blurred, 1.0 - blurred) else: return (torch.stack(out, dim=0), 1.0 - torch.stack(out, dim=0),) - - - + class PlotNode: @classmethod def INPUT_TYPES(s):