Merge branch 'pr/331'

This commit is contained in:
kijai 2025-09-16 18:23:24 +03:00
commit ffd4d1c908
2 changed files with 80 additions and 2 deletions

View File

@ -42,6 +42,7 @@ NODE_CONFIG = {
"ResizeMask": {"class": ResizeMask, "name": "Resize Mask"},
"RoundMask": {"class": RoundMask, "name": "Round Mask"},
"SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"},
"ConsolidateMasksKJ": {"class": ConsolidateMasksKJ, "name": "Consolidate Masks"},
#images
"AddLabel": {"class": AddLabel, "name": "Add Label"},
"ColorMatch": {"class": ColorMatch, "name": "Color Match"},
@ -243,4 +244,4 @@ if hasattr(PromptServer, "instance"):
[web.static("/kjweb_async", (Path(__file__).parent.absolute() / "kjweb_async").as_posix())]
)
except:
pass
pass

View File

@ -1424,4 +1424,81 @@ class SeparateMasks:
return out_masks,
else:
return torch.empty((1, 64, 64), device=mask.device),
class ConsolidateMasksKJ:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"masks": ("MASK",),
"width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
"height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
"padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
},
}
RETURN_TYPES = ("MASK",)
FUNCTION = "consolidate"
CATEGORY = "KJNodes/masking"
DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined."
def consolidate(self, masks, width=512, height=512, padding=0):
B, H, W = masks.shape
def mask_fits(coords, candidate_coords):
x_min, y_min, x_max, y_max = coords
cx_min, cy_min, cx_max, cy_max = candidate_coords
nx_min, ny_min = min(x_min, cx_min), min(y_min, cy_min)
nx_max, ny_max = max(x_max, cx_max), max(y_max, cy_max)
if nx_min + width < nx_max + padding or ny_min + height < ny_max + padding:
return False, coords
return True, (nx_min, ny_min, nx_max, ny_max)
separated = []
final_masks = []
for b in range(B):
m = masks[b]
rows, cols = m.any(dim=1), m.any(dim=0)
y_min, y_max = torch.where(rows)[0][[0, -1]]
x_min, x_max = torch.where(cols)[0][[0, -1]]
w = x_max - x_min + 1
h = y_max - y_min + 1
separated.append(((x_min.item(), y_min.item(), x_max.item(), y_max.item()), m))
separated.sort(key=lambda x: x[0])
fits = []
for i, masks in enumerate(separated):
coord = masks[0]
fits_in_box = []
for j, cand_mask in enumerate(separated):
if i == j:
continue
r, coord = mask_fits(coord, cand_mask[0])
if r:
fits_in_box.append(j)
fits.append((i, fits_in_box))
fits.sort(key=lambda x: -len(x[1]))
seen = []
unique_fits = []
for idx, fs in fits:
uniq = [i for i in fs if i not in seen]
unique_fits.append((idx, fs, uniq))
seen.extend(uniq)
unique_fits.sort(key=lambda x: (-len(x[1]), -len(x[2])))
merged = []
for mask_idx, fitting_masks, _ in unique_fits:
if mask_idx in merged:
continue
fitting_masks = [i for i in fitting_masks if i not in merged]
combined_mask = separated[mask_idx][1].clone()
for i in fitting_masks:
combined_mask += separated[i][1]
merged.append(i)
merged.append(mask_idx)
final_masks.append(combined_mask)
print(f"Consolidated {B} masks into {len(final_masks)}")
return (torch.stack(final_masks, dim=0),)