mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-20 18:24:33 +08:00
Add BlockifyMask
This commit is contained in:
parent
edd7994b74
commit
00da191063
@ -25,6 +25,7 @@ NODE_CONFIG = {
|
|||||||
"DrawMaskOnImage": {"class": DrawMaskOnImage, "name": "Draw Mask On Image"},
|
"DrawMaskOnImage": {"class": DrawMaskOnImage, "name": "Draw Mask On Image"},
|
||||||
"DownloadAndLoadCLIPSeg": {"class": DownloadAndLoadCLIPSeg, "name": "(Down)load CLIPSeg"},
|
"DownloadAndLoadCLIPSeg": {"class": DownloadAndLoadCLIPSeg, "name": "(Down)load CLIPSeg"},
|
||||||
"BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"},
|
"BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"},
|
||||||
|
"BlockifyMask": {"class": BlockifyMask, "name": "Create Block Mask"},
|
||||||
"ColorToMask": {"class": ColorToMask, "name": "Color To Mask"},
|
"ColorToMask": {"class": ColorToMask, "name": "Color To Mask"},
|
||||||
"CreateGradientMask": {"class": CreateGradientMask, "name": "Create Gradient Mask"},
|
"CreateGradientMask": {"class": CreateGradientMask, "name": "Create Gradient Mask"},
|
||||||
"CreateTextMask": {"class": CreateTextMask, "name": "Create Text Mask"},
|
"CreateTextMask": {"class": CreateTextMask, "name": "Create Text Mask"},
|
||||||
|
|||||||
@ -1516,7 +1516,7 @@ class DrawMaskOnImage:
|
|||||||
RETURN_TYPES = ("IMAGE", )
|
RETURN_TYPES = ("IMAGE", )
|
||||||
RETURN_NAMES = ("images",)
|
RETURN_NAMES = ("images",)
|
||||||
FUNCTION = "apply"
|
FUNCTION = "apply"
|
||||||
CATEGORY = "KJNodes/image"
|
CATEGORY = "KJNodes/masking"
|
||||||
DESCRIPTION = "Applies the provided masks to the input images."
|
DESCRIPTION = "Applies the provided masks to the input images."
|
||||||
|
|
||||||
def apply(self, image, mask, color):
|
def apply(self, image, mask, color):
|
||||||
@ -1560,3 +1560,66 @@ class DrawMaskOnImage:
|
|||||||
out_rgb = torch.stack(output_images, dim=0)
|
out_rgb = torch.stack(output_images, dim=0)
|
||||||
|
|
||||||
return (out_rgb, )
|
return (out_rgb, )
|
||||||
|
|
||||||
|
|
||||||
|
class BlockifyMask:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"masks": ("MASK",),
|
||||||
|
"block_size": ("INT", {"default": 32, "min": 8, "max": 512, "step": 1, "tooltip": "Size of blocks in pixels (smaller = smaller blocks)"}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MASK", )
|
||||||
|
RETURN_NAMES = ("mask",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "KJNodes/masking"
|
||||||
|
DESCRIPTION = "Creates a block mask by dividing the bounding box of each mask into blocks of the specified size and filling in blocks that contain any part of the original mask."
|
||||||
|
|
||||||
|
def process(self, masks, block_size):
|
||||||
|
batch_size = masks.shape[0]
|
||||||
|
result_masks = []
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
mask = masks[i]
|
||||||
|
|
||||||
|
# Find bounding box using tensor operations
|
||||||
|
nonzero_coords = torch.nonzero(mask, as_tuple=True)
|
||||||
|
if len(nonzero_coords[0]) == 0: # Empty mask
|
||||||
|
result_masks.append(mask)
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_coords, x_coords = nonzero_coords
|
||||||
|
y_min, y_max = y_coords.min(), y_coords.max()
|
||||||
|
x_min, x_max = x_coords.min(), x_coords.max()
|
||||||
|
|
||||||
|
bbox_width = x_max - x_min + 1
|
||||||
|
bbox_height = y_max - y_min + 1
|
||||||
|
|
||||||
|
# Calculate number of blocks that fit
|
||||||
|
w_divisions = max(1, bbox_width // block_size)
|
||||||
|
h_divisions = max(1, bbox_height // block_size)
|
||||||
|
|
||||||
|
# Calculate actual block sizes (might be slightly larger than block_size)
|
||||||
|
w_slice = bbox_width // w_divisions
|
||||||
|
h_slice = bbox_height // h_divisions
|
||||||
|
|
||||||
|
# Create output mask (copy of input)
|
||||||
|
output_mask = mask.clone()
|
||||||
|
|
||||||
|
# Process grid cells
|
||||||
|
for w_start in range(x_min, x_max + 1, w_slice):
|
||||||
|
w_end = min(w_start + w_slice, x_max + 1)
|
||||||
|
for h_start in range(y_min, y_max + 1, h_slice):
|
||||||
|
h_end = min(h_start + h_slice, y_max + 1)
|
||||||
|
|
||||||
|
# Check if this cell contains any mask content
|
||||||
|
cell_region = mask[h_start:h_end, w_start:w_end]
|
||||||
|
if cell_region.sum() > 0:
|
||||||
|
# Fill the entire cell
|
||||||
|
output_mask[h_start:h_end, w_start:w_end] = 1.0
|
||||||
|
|
||||||
|
result_masks.append(output_mask)
|
||||||
|
|
||||||
|
return torch.stack(result_masks, dim=0),
|
||||||
Loading…
x
Reference in New Issue
Block a user