mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-04-05 15:26:59 +08:00
optimize blockify mask node
This commit is contained in:
parent
bb205d809b
commit
b9abf5df31
@ -6,6 +6,7 @@ import scipy.ndimage
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.utils import ProgressBar
|
||||
@ -17,7 +18,7 @@ import folder_paths
|
||||
from ..utility.utility import tensor2pil, pil2tensor
|
||||
|
||||
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
device = model_management.get_torch_device()
|
||||
main_device = model_management.get_torch_device()
|
||||
offload_device = model_management.unet_offload_device()
|
||||
|
||||
class BatchCLIPSeg:
|
||||
@ -1013,7 +1014,7 @@ class GrowMaskWithBlur:
|
||||
previous_output = None
|
||||
current_expand = expand
|
||||
for m in growmask:
|
||||
output = m.unsqueeze(0).unsqueeze(0).to(device) # Add batch and channel dims for kornia
|
||||
output = m.unsqueeze(0).unsqueeze(0).to(main_device) # Add batch and channel dims for kornia
|
||||
if abs(round(current_expand)) > 0:
|
||||
# Create kernel - kornia expects kernel on same device as input
|
||||
if tapered_corners:
|
||||
@ -1586,8 +1587,11 @@ class BlockifyMask:
|
||||
return {"required": {
|
||||
"masks": ("MASK",),
|
||||
"block_size": ("INT", {"default": 32, "min": 8, "max": 512, "step": 1, "tooltip": "Size of blocks in pixels (smaller = smaller blocks)"}),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MASK", )
|
||||
RETURN_NAMES = ("mask",)
|
||||
@ -1595,49 +1599,70 @@ class BlockifyMask:
|
||||
CATEGORY = "KJNodes/masking"
|
||||
DESCRIPTION = "Creates a block mask by dividing the bounding box of each mask into blocks of the specified size and filling in blocks that contain any part of the original mask."
|
||||
|
||||
def process(self, masks, block_size):
|
||||
batch_size = masks.shape[0]
|
||||
result_masks = []
|
||||
def process(self, masks, block_size, device="cpu"):
|
||||
processing_device = main_device if device == "gpu" else torch.device("cpu")
|
||||
print("mask.dtype:", masks.dtype, "device:", processing_device)
|
||||
|
||||
for i in range(batch_size):
|
||||
masks = masks.to(processing_device)
|
||||
batch_size, height, width = masks.shape
|
||||
|
||||
result_masks = torch.zeros_like(masks)
|
||||
|
||||
for i in tqdm(range(batch_size), desc="BlockifyMask batch"):
|
||||
mask = masks[i]
|
||||
|
||||
# Find bounding box using tensor operations
|
||||
nonzero_coords = torch.nonzero(mask, as_tuple=True)
|
||||
if len(nonzero_coords[0]) == 0: # Empty mask
|
||||
result_masks.append(mask)
|
||||
# Find bounding box efficiently
|
||||
mask_bool = mask > 0
|
||||
if not mask_bool.any():
|
||||
continue
|
||||
|
||||
y_coords, x_coords = nonzero_coords
|
||||
y_min, y_max = y_coords.min(), y_coords.max()
|
||||
x_min, x_max = x_coords.min(), x_coords.max()
|
||||
y_indices = torch.nonzero(mask_bool.any(dim=1), as_tuple=True)[0]
|
||||
x_indices = torch.nonzero(mask_bool.any(dim=0), as_tuple=True)[0]
|
||||
|
||||
if len(y_indices) == 0 or len(x_indices) == 0:
|
||||
continue
|
||||
|
||||
y_min, y_max = y_indices[0], y_indices[-1]
|
||||
x_min, x_max = x_indices[0], x_indices[-1]
|
||||
|
||||
bbox_width = x_max - x_min + 1
|
||||
bbox_height = y_max - y_min + 1
|
||||
|
||||
# Calculate number of blocks that fit
|
||||
# Calculate block grid
|
||||
w_divisions = max(1, bbox_width // block_size)
|
||||
h_divisions = max(1, bbox_height // block_size)
|
||||
|
||||
# Calculate actual block sizes (might be slightly larger than block_size)
|
||||
w_slice = bbox_width // w_divisions
|
||||
h_slice = bbox_height // h_divisions
|
||||
|
||||
# Create output mask (copy of input)
|
||||
output_mask = mask.clone()
|
||||
# Create coordinate grids only for bbox region
|
||||
y_coords = torch.arange(y_min, y_max + 1, device=processing_device).view(-1, 1)
|
||||
x_coords = torch.arange(x_min, x_max + 1, device=processing_device).view(1, -1)
|
||||
|
||||
# Process grid cells
|
||||
for w_start in range(x_min, x_max + 1, w_slice):
|
||||
w_end = min(w_start + w_slice, x_max + 1)
|
||||
for h_start in range(y_min, y_max + 1, h_slice):
|
||||
h_end = min(h_start + h_slice, y_max + 1)
|
||||
|
||||
# Check if this cell contains any mask content
|
||||
cell_region = mask[h_start:h_end, w_start:w_end]
|
||||
if cell_region.sum() > 0:
|
||||
# Fill the entire cell
|
||||
output_mask[h_start:h_end, w_start:w_end] = 1.0
|
||||
# Calculate block indices for bbox region
|
||||
w_block_indices = (x_coords - x_min) // w_slice
|
||||
h_block_indices = (y_coords - y_min) // h_slice
|
||||
|
||||
result_masks.append(output_mask)
|
||||
# Clamp to valid range
|
||||
w_block_indices = w_block_indices.clamp(0, w_divisions - 1)
|
||||
h_block_indices = h_block_indices.clamp(0, h_divisions - 1)
|
||||
|
||||
# Create unique block IDs by combining h and w indices
|
||||
block_ids = h_block_indices * w_divisions + w_block_indices
|
||||
|
||||
# Get mask region within bbox
|
||||
mask_region = mask[y_min:y_max+1, x_min:x_max+1]
|
||||
|
||||
# Find which blocks have content using scatter_add
|
||||
max_blocks = h_divisions * w_divisions
|
||||
block_content = torch.zeros(max_blocks, device=processing_device)
|
||||
block_content.scatter_add_(0, block_ids.flatten(), mask_region.flatten())
|
||||
|
||||
# Create result for blocks that have content
|
||||
has_content = block_content > 0
|
||||
block_mask = has_content[block_ids]
|
||||
|
||||
# Fill the result
|
||||
result_masks[i, y_min:y_max+1, x_min:x_max+1] = block_mask.float()
|
||||
|
||||
return torch.stack(result_masks, dim=0),
|
||||
return (result_masks.clamp(0, 1),)
|
||||
Loading…
x
Reference in New Issue
Block a user