diff --git a/__init__.py b/__init__.py index 2c23dd4..80a5386 100644 --- a/__init__.py +++ b/__init__.py @@ -41,7 +41,7 @@ NODE_CONFIG = { "ResizeMask": {"class": ResizeMask, "name": "Resize Mask"}, "RoundMask": {"class": RoundMask, "name": "Round Mask"}, "SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"}, - "KJConsolidateMasks": {"class": KJConsolidateMasks, "name": "Consolidate Masks"}, + "ConsolidateMasksKJ": {"class": ConsolidateMasksKJ, "name": "Consolidate Masks"}, #images "AddLabel": {"class": AddLabel, "name": "Add Label"}, "ColorMatch": {"class": ColorMatch, "name": "Color Match"}, diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 8268f18..c519ee0 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1426,12 +1426,12 @@ class SeparateMasks: return torch.empty((1, 64, 64), device=mask.device), -class KJConsolidateMasks: +class ConsolidateMasksKJ: @classmethod def INPUT_TYPES(s): return { "required": { - "mask": ("MASK",), + "masks": ("MASK",), "width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), "height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}), "padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), @@ -1444,8 +1444,8 @@ class KJConsolidateMasks: CATEGORY = "KJNodes/masking" DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined." - def consolidate(self, mask, width=512, height=512, padding=0): - B, H, W = mask.shape + def consolidate(self, masks, width=512, height=512, padding=0): + B, H, W = masks.shape def mask_fits(coords, candidate_coords): x_min, y_min, x_max, y_max = coords @@ -1459,7 +1459,7 @@ class KJConsolidateMasks: separated = [] final_masks = [] for b in range(B): - m = mask[b] + m = masks[b] rows, cols = m.any(dim=1), m.any(dim=0) y_min, y_max = torch.where(rows)[0][[0, -1]] x_min, x_max = torch.where(cols)[0][[0, -1]] @@ -1469,8 +1469,8 @@ class KJConsolidateMasks: separated.sort(key=lambda x: x[0]) fits = [] - for i, mask in enumerate(separated): - coord = mask[0] + for i, masks in enumerate(separated): + coord = masks[0] fits_in_box = [] for j, cand_mask in enumerate(separated): if i == j: