mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-23 20:44:33 +08:00
Merge branch 'pr/331'
This commit is contained in:
commit
ffd4d1c908
@ -42,6 +42,7 @@ NODE_CONFIG = {
|
||||
"ResizeMask": {"class": ResizeMask, "name": "Resize Mask"},
|
||||
"RoundMask": {"class": RoundMask, "name": "Round Mask"},
|
||||
"SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"},
|
||||
"ConsolidateMasksKJ": {"class": ConsolidateMasksKJ, "name": "Consolidate Masks"},
|
||||
#images
|
||||
"AddLabel": {"class": AddLabel, "name": "Add Label"},
|
||||
"ColorMatch": {"class": ColorMatch, "name": "Color Match"},
|
||||
@ -243,4 +244,4 @@ if hasattr(PromptServer, "instance"):
|
||||
[web.static("/kjweb_async", (Path(__file__).parent.absolute() / "kjweb_async").as_posix())]
|
||||
)
|
||||
except:
|
||||
pass
|
||||
pass
|
||||
|
||||
@ -1424,4 +1424,81 @@ class SeparateMasks:
|
||||
return out_masks,
|
||||
else:
|
||||
return torch.empty((1, 64, 64), device=mask.device),
|
||||
|
||||
|
||||
|
||||
class ConsolidateMasksKJ:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"masks": ("MASK",),
|
||||
"width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
|
||||
"height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
|
||||
"padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "consolidate"
|
||||
|
||||
CATEGORY = "KJNodes/masking"
|
||||
DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined."
|
||||
|
||||
def consolidate(self, masks, width=512, height=512, padding=0):
|
||||
B, H, W = masks.shape
|
||||
|
||||
def mask_fits(coords, candidate_coords):
|
||||
x_min, y_min, x_max, y_max = coords
|
||||
cx_min, cy_min, cx_max, cy_max = candidate_coords
|
||||
nx_min, ny_min = min(x_min, cx_min), min(y_min, cy_min)
|
||||
nx_max, ny_max = max(x_max, cx_max), max(y_max, cy_max)
|
||||
if nx_min + width < nx_max + padding or ny_min + height < ny_max + padding:
|
||||
return False, coords
|
||||
return True, (nx_min, ny_min, nx_max, ny_max)
|
||||
|
||||
separated = []
|
||||
final_masks = []
|
||||
for b in range(B):
|
||||
m = masks[b]
|
||||
rows, cols = m.any(dim=1), m.any(dim=0)
|
||||
y_min, y_max = torch.where(rows)[0][[0, -1]]
|
||||
x_min, x_max = torch.where(cols)[0][[0, -1]]
|
||||
w = x_max - x_min + 1
|
||||
h = y_max - y_min + 1
|
||||
separated.append(((x_min.item(), y_min.item(), x_max.item(), y_max.item()), m))
|
||||
|
||||
separated.sort(key=lambda x: x[0])
|
||||
fits = []
|
||||
for i, masks in enumerate(separated):
|
||||
coord = masks[0]
|
||||
fits_in_box = []
|
||||
for j, cand_mask in enumerate(separated):
|
||||
if i == j:
|
||||
continue
|
||||
r, coord = mask_fits(coord, cand_mask[0])
|
||||
if r:
|
||||
fits_in_box.append(j)
|
||||
fits.append((i, fits_in_box))
|
||||
fits.sort(key=lambda x: -len(x[1]))
|
||||
seen = []
|
||||
unique_fits = []
|
||||
for idx, fs in fits:
|
||||
uniq = [i for i in fs if i not in seen]
|
||||
unique_fits.append((idx, fs, uniq))
|
||||
seen.extend(uniq)
|
||||
unique_fits.sort(key=lambda x: (-len(x[1]), -len(x[2])))
|
||||
merged = []
|
||||
for mask_idx, fitting_masks, _ in unique_fits:
|
||||
if mask_idx in merged:
|
||||
continue
|
||||
fitting_masks = [i for i in fitting_masks if i not in merged]
|
||||
combined_mask = separated[mask_idx][1].clone()
|
||||
for i in fitting_masks:
|
||||
combined_mask += separated[i][1]
|
||||
merged.append(i)
|
||||
merged.append(mask_idx)
|
||||
final_masks.append(combined_mask)
|
||||
|
||||
print(f"Consolidated {B} masks into {len(final_masks)}")
|
||||
return (torch.stack(final_masks, dim=0),)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user