Optimize BlockifyMask node: vectorize the per-block fill and add an optional processing device

This commit is contained in:
kijai 2025-10-03 19:01:46 +03:00
parent bb205d809b
commit b9abf5df31

View File

@ -6,6 +6,7 @@ import scipy.ndimage
import numpy as np
from contextlib import nullcontext
import os
from tqdm import tqdm
from comfy import model_management
from comfy.utils import ProgressBar
@ -17,7 +18,7 @@ import folder_paths
from ..utility.utility import tensor2pil, pil2tensor
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
device = model_management.get_torch_device()
main_device = model_management.get_torch_device()
offload_device = model_management.unet_offload_device()
class BatchCLIPSeg:
@ -1013,7 +1014,7 @@ class GrowMaskWithBlur:
previous_output = None
current_expand = expand
for m in growmask:
output = m.unsqueeze(0).unsqueeze(0).to(device) # Add batch and channel dims for kornia
output = m.unsqueeze(0).unsqueeze(0).to(main_device) # Add batch and channel dims for kornia
if abs(round(current_expand)) > 0:
# Create kernel - kornia expects kernel on same device as input
if tapered_corners:
@ -1586,8 +1587,11 @@ class BlockifyMask:
return {"required": {
"masks": ("MASK",),
"block_size": ("INT", {"default": 32, "min": 8, "max": 512, "step": 1, "tooltip": "Size of blocks in pixels (smaller = smaller blocks)"}),
},
"optional": {
"device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}),
}
}
}
RETURN_TYPES = ("MASK", )
RETURN_NAMES = ("mask",)
@ -1595,49 +1599,70 @@ class BlockifyMask:
CATEGORY = "KJNodes/masking"
DESCRIPTION = "Creates a block mask by dividing the bounding box of each mask into blocks of the specified size and filling in blocks that contain any part of the original mask."
def process(self, masks, block_size, device="cpu"):
    """Snap each mask to a grid of blocks laid over its bounding box.

    For every mask in the batch, the bounding box of the pixels that are
    greater than zero is divided into a grid of roughly ``block_size``
    pixel cells. Every cell containing at least one such pixel is filled
    with 1, every other pixel of the output is 0. When the bounding box is
    not an exact multiple of the cell size, the remaining rows/columns are
    merged into the last cell of that axis.

    Args:
        masks: Tensor unpacked as (batch, height, width).
        block_size: Target cell size in pixels; must be at least 1.
        device: "gpu" to run on ``main_device``, anything else runs on CPU.

    Returns:
        A one-element tuple holding the blockified masks, with the same
        shape and dtype as ``masks``, located on the processing device.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.
    """
    if block_size < 1:
        raise ValueError("block_size must be a positive integer")

    processing_device = main_device if device == "gpu" else torch.device("cpu")
    masks = masks.to(processing_device)
    batch_size, _, _ = masks.shape  # also rejects inputs that are not 3-D
    result_masks = torch.zeros_like(masks)

    for i in tqdm(range(batch_size), desc="BlockifyMask batch"):
        mask_bool = masks[i] > 0

        # Rows that contain content; no such row means the mask is empty
        # and the corresponding output stays all zeros.
        rows = torch.nonzero(mask_bool.any(dim=1), as_tuple=True)[0]
        if rows.numel() == 0:
            continue
        cols = torch.nonzero(mask_bool.any(dim=0), as_tuple=True)[0]

        # Convert to Python ints once so the arithmetic, slicing and
        # allocation below do not operate on 0-dim tensors.
        y_min, y_max = int(rows[0]), int(rows[-1])
        x_min, x_max = int(cols[0]), int(cols[-1])
        bbox_height = y_max - y_min + 1
        bbox_width = x_max - x_min + 1

        # Number of cells per axis and the resulting cell size (which may
        # be slightly larger than block_size).
        h_divisions = max(1, bbox_height // block_size)
        w_divisions = max(1, bbox_width // block_size)
        h_slice = bbox_height // h_divisions
        w_slice = bbox_width // w_divisions

        # Per-row / per-column cell index inside the bounding box. The
        # clamp folds leftover rows/columns into the last cell.
        row_block = torch.arange(bbox_height, device=processing_device) // h_slice
        col_block = torch.arange(bbox_width, device=processing_device) // w_slice
        row_block = row_block.clamp(max=h_divisions - 1)
        col_block = col_block.clamp(max=w_divisions - 1)

        # Unique cell id for every pixel of the bounding box (broadcast to
        # shape (bbox_height, bbox_width)).
        block_ids = row_block.view(-1, 1) * w_divisions + col_block.view(1, -1)

        # Mark the cells that contain at least one positive pixel. A bool
        # buffer keeps this independent of the mask dtype and uses the same
        # "> 0" criterion as the bounding box.
        region = mask_bool[y_min:y_max + 1, x_min:x_max + 1]
        occupied = torch.zeros(
            h_divisions * w_divisions, dtype=torch.bool, device=processing_device
        )
        occupied[block_ids[region]] = True

        # Expand the per-cell flags back to pixels and write them out.
        result_masks[i, y_min:y_max + 1, x_min:x_max + 1] = occupied[block_ids].to(
            result_masks.dtype
        )

    return (result_masks,)