Add BlockifyMask

This commit is contained in:
kijai 2025-09-18 21:55:59 +03:00
parent edd7994b74
commit 00da191063
2 changed files with 66 additions and 2 deletions

View File

@ -25,6 +25,7 @@ NODE_CONFIG = {
"DrawMaskOnImage": {"class": DrawMaskOnImage, "name": "Draw Mask On Image"},
"DownloadAndLoadCLIPSeg": {"class": DownloadAndLoadCLIPSeg, "name": "(Down)load CLIPSeg"},
"BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"},
"BlockifyMask": {"class": BlockifyMask, "name": "Create Block Mask"},
"ColorToMask": {"class": ColorToMask, "name": "Color To Mask"},
"CreateGradientMask": {"class": CreateGradientMask, "name": "Create Gradient Mask"},
"CreateTextMask": {"class": CreateTextMask, "name": "Create Text Mask"},

View File

@ -1516,7 +1516,7 @@ class DrawMaskOnImage:
RETURN_TYPES = ("IMAGE", )
RETURN_NAMES = ("images",)
FUNCTION = "apply"
CATEGORY = "KJNodes/image"
CATEGORY = "KJNodes/masking"
DESCRIPTION = "Applies the provided masks to the input images."
def apply(self, image, mask, color):
@ -1559,4 +1559,67 @@ class DrawMaskOnImage:
out_rgb = torch.stack(output_images, dim=0)
return (out_rgb, )
return (out_rgb, )
class BlockifyMask:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"masks": ("MASK",),
"block_size": ("INT", {"default": 32, "min": 8, "max": 512, "step": 1, "tooltip": "Size of blocks in pixels (smaller = smaller blocks)"}),
}
}
RETURN_TYPES = ("MASK", )
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "KJNodes/masking"
DESCRIPTION = "Creates a block mask by dividing the bounding box of each mask into blocks of the specified size and filling in blocks that contain any part of the original mask."
def process(self, masks, block_size):
batch_size = masks.shape[0]
result_masks = []
for i in range(batch_size):
mask = masks[i]
# Find bounding box using tensor operations
nonzero_coords = torch.nonzero(mask, as_tuple=True)
if len(nonzero_coords[0]) == 0: # Empty mask
result_masks.append(mask)
continue
y_coords, x_coords = nonzero_coords
y_min, y_max = y_coords.min(), y_coords.max()
x_min, x_max = x_coords.min(), x_coords.max()
bbox_width = x_max - x_min + 1
bbox_height = y_max - y_min + 1
# Calculate number of blocks that fit
w_divisions = max(1, bbox_width // block_size)
h_divisions = max(1, bbox_height // block_size)
# Calculate actual block sizes (might be slightly larger than block_size)
w_slice = bbox_width // w_divisions
h_slice = bbox_height // h_divisions
# Create output mask (copy of input)
output_mask = mask.clone()
# Process grid cells
for w_start in range(x_min, x_max + 1, w_slice):
w_end = min(w_start + w_slice, x_max + 1)
for h_start in range(y_min, y_max + 1, h_slice):
h_end = min(h_start + h_slice, y_max + 1)
# Check if this cell contains any mask content
cell_region = mask[h_start:h_end, w_start:w_end]
if cell_region.sum() > 0:
# Fill the entire cell
output_mask[h_start:h_end, w_start:w_end] = 1.0
result_masks.append(output_mask)
return torch.stack(result_masks, dim=0),