From ed0ab5231fc322c5c4ef637f5810041b4f3854c7 Mon Sep 17 00:00:00 2001 From: asagi4 <130366179+asagi4@users.noreply.github.com> Date: Tue, 8 Jul 2025 00:19:03 +0300 Subject: [PATCH 1/2] Mask consolidation node --- __init__.py | 3 +- nodes/mask_nodes.py | 79 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/__init__.py b/__init__.py index fce7a50..2c23dd4 100644 --- a/__init__.py +++ b/__init__.py @@ -41,6 +41,7 @@ NODE_CONFIG = { "ResizeMask": {"class": ResizeMask, "name": "Resize Mask"}, "RoundMask": {"class": RoundMask, "name": "Round Mask"}, "SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"}, + "KJConsolidateMasks": {"class": KJConsolidateMasks, "name": "Consolidate Masks"}, #images "AddLabel": {"class": AddLabel, "name": "Add Label"}, "ColorMatch": {"class": ColorMatch, "name": "Color Match"}, @@ -232,4 +233,4 @@ if hasattr(PromptServer, "instance"): [web.static("/kjweb_async", (Path(__file__).parent.absolute() / "kjweb_async").as_posix())] ) except: - pass \ No newline at end of file + pass diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 9a84a12..8268f18 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1424,4 +1424,81 @@ class SeparateMasks: return out_masks, else: return torch.empty((1, 64, 64), device=mask.device), - \ No newline at end of file + + +class KJConsolidateMasks: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + "width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), + "height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), + "padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), + }, + } + + RETURN_TYPES = ("MASK",) + FUNCTION = "consolidate" + + CATEGORY = "KJNodes/masking" + DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined." + + def consolidate(self, mask, width=512, height=512, padding=0): + B, H, W = mask.shape + + def mask_fits(coords, candidate_coords): + x_min, y_min, x_max, y_max = coords + cx_min, cy_min, cx_max, cy_max = candidate_coords + nx_min, ny_min = min(x_min, cx_min), min(y_min, cy_min) + nx_max, ny_max = max(x_max, cx_max), max(y_max, cy_max) + if nx_min + width < nx_max + padding or ny_min + height < ny_max + padding: + return False, coords + return True, (nx_min, ny_min, nx_max, ny_max) + + separated = [] + final_masks = [] + for b in range(B): + m = mask[b] + rows, cols = m.any(dim=1), m.any(dim=0) + y_min, y_max = torch.where(rows)[0][[0, -1]] + x_min, x_max = torch.where(cols)[0][[0, -1]] + w = x_max - x_min + 1 + h = y_max - y_min + 1 + separated.append(((x_min.item(), y_min.item(), x_max.item(), y_max.item()), m)) + + separated.sort(key=lambda x: x[0]) + fits = [] + for i, mask in enumerate(separated): + coord = mask[0] + fits_in_box = [] + for j, cand_mask in enumerate(separated): + if i == j: + continue + r, coord = mask_fits(coord, cand_mask[0]) + if r: + fits_in_box.append(j) + fits.append((i, fits_in_box)) + fits.sort(key=lambda x: -len(x[1])) + seen = [] + unique_fits = [] + for idx, fs in fits: + uniq = [i for i in fs if i not in seen] + unique_fits.append((idx, fs, uniq)) + seen.extend(uniq) + unique_fits.sort(key=lambda x: (-len(x[1]), -len(x[2]))) + merged = [] + for mask_idx, fitting_masks, _ in unique_fits: + if mask_idx in merged: + continue + fitting_masks = [i for i in fitting_masks if i not in merged] + combined_mask = separated[mask_idx][1].clone() + for i in fitting_masks: + combined_mask += separated[i][1] + merged.append(i) + merged.append(mask_idx) + final_masks.append(combined_mask) + + print(f"Consolidated {B} masks into {len(final_masks)}") + return (torch.stack(final_masks, dim=0),) + From 1fff7fed6be83f56b8fc9c626e7b46f4e8c67082 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 16 Sep 2025 18:23:09 +0300 Subject: [PATCH 2/2] Rename --- __init__.py | 2 +- nodes/mask_nodes.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/__init__.py b/__init__.py index 2c23dd4..80a5386 100644 --- a/__init__.py +++ b/__init__.py @@ -41,7 +41,7 @@ NODE_CONFIG = { "ResizeMask": {"class": ResizeMask, "name": "Resize Mask"}, "RoundMask": {"class": RoundMask, "name": "Round Mask"}, "SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"}, - "KJConsolidateMasks": {"class": KJConsolidateMasks, "name": "Consolidate Masks"}, + "ConsolidateMasksKJ": {"class": ConsolidateMasksKJ, "name": "Consolidate Masks"}, #images "AddLabel": {"class": AddLabel, "name": "Add Label"}, "ColorMatch": {"class": ColorMatch, "name": "Color Match"}, diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 8268f18..c519ee0 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1426,12 +1426,12 @@ class SeparateMasks: return torch.empty((1, 64, 64), device=mask.device), -class KJConsolidateMasks: +class ConsolidateMasksKJ: @classmethod def INPUT_TYPES(s): return { "required": { - "mask": ("MASK",), + "masks": ("MASK",), "width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), "height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), "padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), @@ -1444,8 +1444,8 @@ class KJConsolidateMasks: CATEGORY = "KJNodes/masking" DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined." - def consolidate(self, mask, width=512, height=512, padding=0): - B, H, W = mask.shape + def consolidate(self, masks, width=512, height=512, padding=0): + B, H, W = masks.shape def mask_fits(coords, candidate_coords): x_min, y_min, x_max, y_max = coords @@ -1459,7 +1459,7 @@ class KJConsolidateMasks: separated = [] final_masks = [] for b in range(B): - m = mask[b] + m = masks[b] rows, cols = m.any(dim=1), m.any(dim=0) y_min, y_max = torch.where(rows)[0][[0, -1]] x_min, x_max = torch.where(cols)[0][[0, -1]] @@ -1469,8 +1469,8 @@ class KJConsolidateMasks: separated.sort(key=lambda x: x[0]) fits = [] - for i, mask in enumerate(separated): - coord = mask[0] + for i, masks in enumerate(separated): + coord = masks[0] fits_in_box = [] for j, cand_mask in enumerate(separated): if i == j: