This commit is contained in:
kijai 2025-09-16 18:23:09 +03:00
parent ed0ab5231f
commit 1fff7fed6b
2 changed files with 8 additions and 8 deletions

View File

@ -41,7 +41,7 @@ NODE_CONFIG = {
"ResizeMask": {"class": ResizeMask, "name": "Resize Mask"},
"RoundMask": {"class": RoundMask, "name": "Round Mask"},
"SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"},
"KJConsolidateMasks": {"class": KJConsolidateMasks, "name": "Consolidate Masks"},
"ConsolidateMasksKJ": {"class": ConsolidateMasksKJ, "name": "Consolidate Masks"},
#images
"AddLabel": {"class": AddLabel, "name": "Add Label"},
"ColorMatch": {"class": ColorMatch, "name": "Color Match"},

View File

@ -1426,12 +1426,12 @@ class SeparateMasks:
return torch.empty((1, 64, 64), device=mask.device),
class KJConsolidateMasks:
class ConsolidateMasksKJ:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mask": ("MASK",),
"masks": ("MASK",),
"width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
"height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
"padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
@ -1444,8 +1444,8 @@ class KJConsolidateMasks:
CATEGORY = "KJNodes/masking"
DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined."
def consolidate(self, mask, width=512, height=512, padding=0):
B, H, W = mask.shape
def consolidate(self, masks, width=512, height=512, padding=0):
B, H, W = masks.shape
def mask_fits(coords, candidate_coords):
x_min, y_min, x_max, y_max = coords
@ -1459,7 +1459,7 @@ class KJConsolidateMasks:
separated = []
final_masks = []
for b in range(B):
m = mask[b]
m = masks[b]
rows, cols = m.any(dim=1), m.any(dim=0)
y_min, y_max = torch.where(rows)[0][[0, -1]]
x_min, x_max = torch.where(cols)[0][[0, -1]]
@ -1469,8 +1469,8 @@ class KJConsolidateMasks:
separated.sort(key=lambda x: x[0])
fits = []
for i, mask in enumerate(separated):
coord = mask[0]
for i, masks in enumerate(separated):
coord = masks[0]
fits_in_box = []
for j, cand_mask in enumerate(separated):
if i == j: