From 9d7af919b91838fb22e31ad0107a6ddcf8bd7f3f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 23 Sep 2025 20:15:13 +0300 Subject: [PATCH] Use kornia for GPU accelerated mask dilation --- nodes/batchcrop_nodes.py | 19 ++++++++++++------- nodes/mask_nodes.py | 40 +++++++++++++++++++++++++++++----------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/nodes/batchcrop_nodes.py b/nodes/batchcrop_nodes.py index 3b8cd3a..304ff12 100644 --- a/nodes/batchcrop_nodes.py +++ b/nodes/batchcrop_nodes.py @@ -710,14 +710,19 @@ Visualizes the specified bbox on the image. def visualizebbox(self, bboxes, images, line_width, bbox_format): image_list = [] for image, bbox in zip(images, bboxes): - if bbox_format == "xywh": - x_min, y_min, width, height = bbox - elif bbox_format == "xyxy": - x_min, y_min, x_max, y_max = bbox - width = x_max - x_min - height = y_max - y_min + # Ensure bbox is a sequence of 4 values + if isinstance(bbox, (list, tuple, np.ndarray)) and len(bbox) == 4: + if bbox_format == "xywh": + x_min, y_min, width, height = bbox + elif bbox_format == "xyxy": + x_min, y_min, x_max, y_max = bbox + width = x_max - x_min + height = y_max - y_min + else: + raise ValueError(f"Unknown bbox_format: {bbox_format}") else: - raise ValueError(f"Unknown bbox_format: {bbox_format}") + print("Invalid bbox:", bbox) + continue # Ensure bbox coordinates are integers x_min = int(x_min) diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 69bbfa4..5397e36 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -17,6 +17,8 @@ import folder_paths from ..utility.utility import tensor2pil, pil2tensor script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +device = model_management.get_torch_device() +offload_device = model_management.unet_offload_device() class BatchCLIPSeg: @@ -997,6 +999,7 @@ class GrowMaskWithBlur: - fill_holes: fill holes in the mask (slow)""" def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False): + import kornia.morphology as morph alpha = lerp_alpha decay = decay_factor if flip_input: @@ -1010,30 +1013,45 @@ class GrowMaskWithBlur: previous_output = None current_expand = expand for m in growmask: - output = m.numpy().astype(np.float32) - for _ in range(abs(round(current_expand))): - if current_expand < 0: - output = scipy.ndimage.grey_erosion(output, footprint=kernel) + output = m.unsqueeze(0).unsqueeze(0).to(device) # Add batch and channel dims for kornia + if abs(round(current_expand)) > 0: + # Create kernel - kornia expects kernel on same device as input + if tapered_corners: + kernel = torch.tensor([[0, 1, 0], + [1, 1, 1], + [0, 1, 0]], dtype=torch.float32, device=output.device) else: - output = scipy.ndimage.grey_dilation(output, footprint=kernel) + kernel = torch.tensor([[1, 1, 1], + [1, 1, 1], + [1, 1, 1]], dtype=torch.float32, device=output.device) + + for _ in range(abs(round(current_expand))): + if current_expand < 0: + output = morph.erosion(output, kernel) + else: + output = morph.dilation(output, kernel) + + output = output.squeeze(0).squeeze(0) # Remove batch and channel dims + if current_expand < 0: current_expand -= abs(incremental_expandrate) else: current_expand += abs(incremental_expandrate) + if fill_holes: + # For fill_holes, you might need to keep using scipy or implement GPU version binary_mask = output > 0 - output = scipy.ndimage.binary_fill_holes(binary_mask) - output = output.astype(np.float32) * 255 - output = torch.from_numpy(output) + output_np = binary_mask.cpu().numpy() + filled = scipy.ndimage.binary_fill_holes(output_np) + output = torch.from_numpy(filled.astype(np.float32)).to(output.device) + if alpha < 1.0 and previous_output is not None: - # Interpolate between the previous and current frame output = alpha * output + (1 - alpha) * previous_output if decay < 1.0 and previous_output is not None: - # Add the decayed previous output to the current frame output += decay * previous_output output = output / output.max() previous_output = output - out.append(output) + out.append(output.cpu()) if blur_radius != 0: # Convert the tensor list to PIL images, apply blur, and convert back