mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-06-04 10:11:20 +08:00
Use kornia for GPU-accelerated mask dilation
This commit is contained in:
parent
eeb797736b
commit
9d7af919b9
@ -710,14 +710,19 @@ Visualizes the specified bbox on the image.
|
|||||||
def visualizebbox(self, bboxes, images, line_width, bbox_format):
|
def visualizebbox(self, bboxes, images, line_width, bbox_format):
|
||||||
image_list = []
|
image_list = []
|
||||||
for image, bbox in zip(images, bboxes):
|
for image, bbox in zip(images, bboxes):
|
||||||
if bbox_format == "xywh":
|
# Ensure bbox is a sequence of 4 values
|
||||||
x_min, y_min, width, height = bbox
|
if isinstance(bbox, (list, tuple, np.ndarray)) and len(bbox) == 4:
|
||||||
elif bbox_format == "xyxy":
|
if bbox_format == "xywh":
|
||||||
x_min, y_min, x_max, y_max = bbox
|
x_min, y_min, width, height = bbox
|
||||||
width = x_max - x_min
|
elif bbox_format == "xyxy":
|
||||||
height = y_max - y_min
|
x_min, y_min, x_max, y_max = bbox
|
||||||
|
width = x_max - x_min
|
||||||
|
height = y_max - y_min
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown bbox_format: {bbox_format}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown bbox_format: {bbox_format}")
|
print("Invalid bbox:", bbox)
|
||||||
|
continue
|
||||||
|
|
||||||
# Ensure bbox coordinates are integers
|
# Ensure bbox coordinates are integers
|
||||||
x_min = int(x_min)
|
x_min = int(x_min)
|
||||||
|
|||||||
@ -17,6 +17,8 @@ import folder_paths
|
|||||||
from ..utility.utility import tensor2pil, pil2tensor
|
from ..utility.utility import tensor2pil, pil2tensor
|
||||||
|
|
||||||
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
device = model_management.get_torch_device()
|
||||||
|
offload_device = model_management.unet_offload_device()
|
||||||
|
|
||||||
class BatchCLIPSeg:
|
class BatchCLIPSeg:
|
||||||
|
|
||||||
@ -997,6 +999,7 @@ class GrowMaskWithBlur:
|
|||||||
- fill_holes: fill holes in the mask (slow)"""
|
- fill_holes: fill holes in the mask (slow)"""
|
||||||
|
|
||||||
def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False):
|
def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False):
|
||||||
|
import kornia.morphology as morph
|
||||||
alpha = lerp_alpha
|
alpha = lerp_alpha
|
||||||
decay = decay_factor
|
decay = decay_factor
|
||||||
if flip_input:
|
if flip_input:
|
||||||
@ -1010,30 +1013,45 @@ class GrowMaskWithBlur:
|
|||||||
previous_output = None
|
previous_output = None
|
||||||
current_expand = expand
|
current_expand = expand
|
||||||
for m in growmask:
|
for m in growmask:
|
||||||
output = m.numpy().astype(np.float32)
|
output = m.unsqueeze(0).unsqueeze(0).to(device) # Add batch and channel dims for kornia
|
||||||
for _ in range(abs(round(current_expand))):
|
if abs(round(current_expand)) > 0:
|
||||||
if current_expand < 0:
|
# Create kernel - kornia expects kernel on same device as input
|
||||||
output = scipy.ndimage.grey_erosion(output, footprint=kernel)
|
if tapered_corners:
|
||||||
|
kernel = torch.tensor([[0, 1, 0],
|
||||||
|
[1, 1, 1],
|
||||||
|
[0, 1, 0]], dtype=torch.float32, device=output.device)
|
||||||
else:
|
else:
|
||||||
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
|
kernel = torch.tensor([[1, 1, 1],
|
||||||
|
[1, 1, 1],
|
||||||
|
[1, 1, 1]], dtype=torch.float32, device=output.device)
|
||||||
|
|
||||||
|
for _ in range(abs(round(current_expand))):
|
||||||
|
if current_expand < 0:
|
||||||
|
output = morph.erosion(output, kernel)
|
||||||
|
else:
|
||||||
|
output = morph.dilation(output, kernel)
|
||||||
|
|
||||||
|
output = output.squeeze(0).squeeze(0) # Remove batch and channel dims
|
||||||
|
|
||||||
if current_expand < 0:
|
if current_expand < 0:
|
||||||
current_expand -= abs(incremental_expandrate)
|
current_expand -= abs(incremental_expandrate)
|
||||||
else:
|
else:
|
||||||
current_expand += abs(incremental_expandrate)
|
current_expand += abs(incremental_expandrate)
|
||||||
|
|
||||||
if fill_holes:
|
if fill_holes:
|
||||||
|
# For fill_holes, you might need to keep using scipy or implement GPU version
|
||||||
binary_mask = output > 0
|
binary_mask = output > 0
|
||||||
output = scipy.ndimage.binary_fill_holes(binary_mask)
|
output_np = binary_mask.cpu().numpy()
|
||||||
output = output.astype(np.float32) * 255
|
filled = scipy.ndimage.binary_fill_holes(output_np)
|
||||||
output = torch.from_numpy(output)
|
output = torch.from_numpy(filled.astype(np.float32)).to(output.device)
|
||||||
|
|
||||||
if alpha < 1.0 and previous_output is not None:
|
if alpha < 1.0 and previous_output is not None:
|
||||||
# Interpolate between the previous and current frame
|
|
||||||
output = alpha * output + (1 - alpha) * previous_output
|
output = alpha * output + (1 - alpha) * previous_output
|
||||||
if decay < 1.0 and previous_output is not None:
|
if decay < 1.0 and previous_output is not None:
|
||||||
# Add the decayed previous output to the current frame
|
|
||||||
output += decay * previous_output
|
output += decay * previous_output
|
||||||
output = output / output.max()
|
output = output / output.max()
|
||||||
previous_output = output
|
previous_output = output
|
||||||
out.append(output)
|
out.append(output.cpu())
|
||||||
|
|
||||||
if blur_radius != 0:
|
if blur_radius != 0:
|
||||||
# Convert the tensor list to PIL images, apply blur, and convert back
|
# Convert the tensor list to PIL images, apply blur, and convert back
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user