Use kornia for GPU accelerated mask dilation

This commit is contained in:
kijai 2025-09-23 20:15:13 +03:00
parent eeb797736b
commit 9d7af919b9
2 changed files with 41 additions and 18 deletions

View File

@ -710,14 +710,19 @@ Visualizes the specified bbox on the image.
def visualizebbox(self, bboxes, images, line_width, bbox_format): def visualizebbox(self, bboxes, images, line_width, bbox_format):
image_list = [] image_list = []
for image, bbox in zip(images, bboxes): for image, bbox in zip(images, bboxes):
if bbox_format == "xywh": # Ensure bbox is a sequence of 4 values
x_min, y_min, width, height = bbox if isinstance(bbox, (list, tuple, np.ndarray)) and len(bbox) == 4:
elif bbox_format == "xyxy": if bbox_format == "xywh":
x_min, y_min, x_max, y_max = bbox x_min, y_min, width, height = bbox
width = x_max - x_min elif bbox_format == "xyxy":
height = y_max - y_min x_min, y_min, x_max, y_max = bbox
width = x_max - x_min
height = y_max - y_min
else:
raise ValueError(f"Unknown bbox_format: {bbox_format}")
else: else:
raise ValueError(f"Unknown bbox_format: {bbox_format}") print("Invalid bbox:", bbox)
continue
# Ensure bbox coordinates are integers # Ensure bbox coordinates are integers
x_min = int(x_min) x_min = int(x_min)

View File

@ -17,6 +17,8 @@ import folder_paths
from ..utility.utility import tensor2pil, pil2tensor from ..utility.utility import tensor2pil, pil2tensor
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
device = model_management.get_torch_device()
offload_device = model_management.unet_offload_device()
class BatchCLIPSeg: class BatchCLIPSeg:
@ -997,6 +999,7 @@ class GrowMaskWithBlur:
- fill_holes: fill holes in the mask (slow)""" - fill_holes: fill holes in the mask (slow)"""
def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False): def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False):
import kornia.morphology as morph
alpha = lerp_alpha alpha = lerp_alpha
decay = decay_factor decay = decay_factor
if flip_input: if flip_input:
@ -1010,30 +1013,45 @@ class GrowMaskWithBlur:
previous_output = None previous_output = None
current_expand = expand current_expand = expand
for m in growmask: for m in growmask:
output = m.numpy().astype(np.float32) output = m.unsqueeze(0).unsqueeze(0).to(device) # Add batch and channel dims for kornia
for _ in range(abs(round(current_expand))): if abs(round(current_expand)) > 0:
if current_expand < 0: # Create kernel - kornia expects kernel on same device as input
output = scipy.ndimage.grey_erosion(output, footprint=kernel) if tapered_corners:
kernel = torch.tensor([[0, 1, 0],
[1, 1, 1],
[0, 1, 0]], dtype=torch.float32, device=output.device)
else: else:
output = scipy.ndimage.grey_dilation(output, footprint=kernel) kernel = torch.tensor([[1, 1, 1],
[1, 1, 1],
[1, 1, 1]], dtype=torch.float32, device=output.device)
for _ in range(abs(round(current_expand))):
if current_expand < 0:
output = morph.erosion(output, kernel)
else:
output = morph.dilation(output, kernel)
output = output.squeeze(0).squeeze(0) # Remove batch and channel dims
if current_expand < 0: if current_expand < 0:
current_expand -= abs(incremental_expandrate) current_expand -= abs(incremental_expandrate)
else: else:
current_expand += abs(incremental_expandrate) current_expand += abs(incremental_expandrate)
if fill_holes: if fill_holes:
# For fill_holes, you might need to keep using scipy or implement GPU version
binary_mask = output > 0 binary_mask = output > 0
output = scipy.ndimage.binary_fill_holes(binary_mask) output_np = binary_mask.cpu().numpy()
output = output.astype(np.float32) * 255 filled = scipy.ndimage.binary_fill_holes(output_np)
output = torch.from_numpy(output) output = torch.from_numpy(filled.astype(np.float32)).to(output.device)
if alpha < 1.0 and previous_output is not None: if alpha < 1.0 and previous_output is not None:
# Interpolate between the previous and current frame
output = alpha * output + (1 - alpha) * previous_output output = alpha * output + (1 - alpha) * previous_output
if decay < 1.0 and previous_output is not None: if decay < 1.0 and previous_output is not None:
# Add the decayed previous output to the current frame
output += decay * previous_output output += decay * previous_output
output = output / output.max() output = output / output.max()
previous_output = output previous_output = output
out.append(output) out.append(output.cpu())
if blur_radius != 0: if blur_radius != 0:
# Convert the tensor list to PIL images, apply blur, and convert back # Convert the tensor list to PIL images, apply blur, and convert back