This commit is contained in:
kijai 2025-09-16 18:23:09 +03:00
parent ed0ab5231f
commit 1fff7fed6b
2 changed files with 8 additions and 8 deletions

View File

@ -41,7 +41,7 @@ NODE_CONFIG = {
"ResizeMask": {"class": ResizeMask, "name": "Resize Mask"}, "ResizeMask": {"class": ResizeMask, "name": "Resize Mask"},
"RoundMask": {"class": RoundMask, "name": "Round Mask"}, "RoundMask": {"class": RoundMask, "name": "Round Mask"},
"SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"}, "SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"},
"KJConsolidateMasks": {"class": KJConsolidateMasks, "name": "Consolidate Masks"}, "ConsolidateMasksKJ": {"class": ConsolidateMasksKJ, "name": "Consolidate Masks"},
#images #images
"AddLabel": {"class": AddLabel, "name": "Add Label"}, "AddLabel": {"class": AddLabel, "name": "Add Label"},
"ColorMatch": {"class": ColorMatch, "name": "Color Match"}, "ColorMatch": {"class": ColorMatch, "name": "Color Match"},

View File

@ -1426,12 +1426,12 @@ class SeparateMasks:
return torch.empty((1, 64, 64), device=mask.device), return torch.empty((1, 64, 64), device=mask.device),
class KJConsolidateMasks: class ConsolidateMasksKJ:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"mask": ("MASK",), "masks": ("MASK",),
"width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), "width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
"height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), "height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
"padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), "padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
@ -1444,8 +1444,8 @@ class KJConsolidateMasks:
CATEGORY = "KJNodes/masking" CATEGORY = "KJNodes/masking"
DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined." DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined."
def consolidate(self, mask, width=512, height=512, padding=0): def consolidate(self, masks, width=512, height=512, padding=0):
B, H, W = mask.shape B, H, W = masks.shape
def mask_fits(coords, candidate_coords): def mask_fits(coords, candidate_coords):
x_min, y_min, x_max, y_max = coords x_min, y_min, x_max, y_max = coords
@ -1459,7 +1459,7 @@ class KJConsolidateMasks:
separated = [] separated = []
final_masks = [] final_masks = []
for b in range(B): for b in range(B):
m = mask[b] m = masks[b]
rows, cols = m.any(dim=1), m.any(dim=0) rows, cols = m.any(dim=1), m.any(dim=0)
y_min, y_max = torch.where(rows)[0][[0, -1]] y_min, y_max = torch.where(rows)[0][[0, -1]]
x_min, x_max = torch.where(cols)[0][[0, -1]] x_min, x_max = torch.where(cols)[0][[0, -1]]
@ -1469,8 +1469,8 @@ class KJConsolidateMasks:
separated.sort(key=lambda x: x[0]) separated.sort(key=lambda x: x[0])
fits = [] fits = []
for i, mask in enumerate(separated): for i, masks in enumerate(separated):
coord = mask[0] coord = masks[0]
fits_in_box = [] fits_in_box = []
for j, cand_mask in enumerate(separated): for j, cand_mask in enumerate(separated):
if i == j: if i == j: