diff --git a/__init__.py b/__init__.py index c3bb6c7..9846844 100644 --- a/__init__.py +++ b/__init__.py @@ -42,6 +42,7 @@ NODE_CONFIG = { "ResizeMask": {"class": ResizeMask, "name": "Resize Mask"}, "RoundMask": {"class": RoundMask, "name": "Round Mask"}, "SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"}, + "ConsolidateMasksKJ": {"class": ConsolidateMasksKJ, "name": "Consolidate Masks"}, #images "AddLabel": {"class": AddLabel, "name": "Add Label"}, "ColorMatch": {"class": ColorMatch, "name": "Color Match"}, @@ -243,4 +244,4 @@ if hasattr(PromptServer, "instance"): [web.static("/kjweb_async", (Path(__file__).parent.absolute() / "kjweb_async").as_posix())] ) except: - pass \ No newline at end of file + pass diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 9a84a12..c519ee0 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1424,4 +1424,81 @@ class SeparateMasks: return out_masks, else: return torch.empty((1, 64, 64), device=mask.device), - \ No newline at end of file + + +class ConsolidateMasksKJ: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "masks": ("MASK",), + "width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), + "height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), + "padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), + }, + } + + RETURN_TYPES = ("MASK",) + FUNCTION = "consolidate" + + CATEGORY = "KJNodes/masking" + DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined." + + def consolidate(self, masks, width=512, height=512, padding=0): + B, H, W = masks.shape + + def mask_fits(coords, candidate_coords): + x_min, y_min, x_max, y_max = coords + cx_min, cy_min, cx_max, cy_max = candidate_coords + nx_min, ny_min = min(x_min, cx_min), min(y_min, cy_min) + nx_max, ny_max = max(x_max, cx_max), max(y_max, cy_max) + if nx_min + width < nx_max + padding or ny_min + height < ny_max + padding: + return False, coords + return True, (nx_min, ny_min, nx_max, ny_max) + + separated = [] + final_masks = [] + for b in range(B): + m = masks[b] + rows, cols = m.any(dim=1), m.any(dim=0) + y_min, y_max = torch.where(rows)[0][[0, -1]] + x_min, x_max = torch.where(cols)[0][[0, -1]] + w = x_max - x_min + 1 + h = y_max - y_min + 1 + separated.append(((x_min.item(), y_min.item(), x_max.item(), y_max.item()), m)) + + separated.sort(key=lambda x: x[0]) + fits = [] + for i, masks in enumerate(separated): + coord = masks[0] + fits_in_box = [] + for j, cand_mask in enumerate(separated): + if i == j: + continue + r, coord = mask_fits(coord, cand_mask[0]) + if r: + fits_in_box.append(j) + fits.append((i, fits_in_box)) + fits.sort(key=lambda x: -len(x[1])) + seen = [] + unique_fits = [] + for idx, fs in fits: + uniq = [i for i in fs if i not in seen] + unique_fits.append((idx, fs, uniq)) + seen.extend(uniq) + unique_fits.sort(key=lambda x: (-len(x[1]), -len(x[2]))) + merged = [] + for mask_idx, fitting_masks, _ in unique_fits: + if mask_idx in merged: + continue + fitting_masks = [i for i in fitting_masks if i not in merged] + combined_mask = separated[mask_idx][1].clone() + for i in fitting_masks: + combined_mask += separated[i][1] + merged.append(i) + merged.append(mask_idx) + final_masks.append(combined_mask) + + print(f"Consolidated {B} masks into {len(final_masks)}") + return (torch.stack(final_masks, dim=0),) +