mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-02 04:20:07 +08:00
Merge branch 'pr/331'
This commit is contained in:
commit
ffd4d1c908
@ -42,6 +42,7 @@ NODE_CONFIG = {
|
|||||||
"ResizeMask": {"class": ResizeMask, "name": "Resize Mask"},
|
"ResizeMask": {"class": ResizeMask, "name": "Resize Mask"},
|
||||||
"RoundMask": {"class": RoundMask, "name": "Round Mask"},
|
"RoundMask": {"class": RoundMask, "name": "Round Mask"},
|
||||||
"SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"},
|
"SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"},
|
||||||
|
"ConsolidateMasksKJ": {"class": ConsolidateMasksKJ, "name": "Consolidate Masks"},
|
||||||
#images
|
#images
|
||||||
"AddLabel": {"class": AddLabel, "name": "Add Label"},
|
"AddLabel": {"class": AddLabel, "name": "Add Label"},
|
||||||
"ColorMatch": {"class": ColorMatch, "name": "Color Match"},
|
"ColorMatch": {"class": ColorMatch, "name": "Color Match"},
|
||||||
@ -243,4 +244,4 @@ if hasattr(PromptServer, "instance"):
|
|||||||
[web.static("/kjweb_async", (Path(__file__).parent.absolute() / "kjweb_async").as_posix())]
|
[web.static("/kjweb_async", (Path(__file__).parent.absolute() / "kjweb_async").as_posix())]
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -1424,4 +1424,81 @@ class SeparateMasks:
|
|||||||
return out_masks,
|
return out_masks,
|
||||||
else:
|
else:
|
||||||
return torch.empty((1, 64, 64), device=mask.device),
|
return torch.empty((1, 64, 64), device=mask.device),
|
||||||
|
|
||||||
|
|
||||||
|
class ConsolidateMasksKJ:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"masks": ("MASK",),
|
||||||
|
"width": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
|
||||||
|
"height": ("INT", {"default": 512, "min": 0, "max": 4096, "step": 64}),
|
||||||
|
"padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MASK",)
|
||||||
|
FUNCTION = "consolidate"
|
||||||
|
|
||||||
|
CATEGORY = "KJNodes/masking"
|
||||||
|
DESCRIPTION = "Consolidates a batch of separate masks by finding the largest group of masks that fit inside a tile of the given width and height (including the padding), and repeating until no more masks can be combined."
|
||||||
|
|
||||||
|
def consolidate(self, masks, width=512, height=512, padding=0):
|
||||||
|
B, H, W = masks.shape
|
||||||
|
|
||||||
|
def mask_fits(coords, candidate_coords):
|
||||||
|
x_min, y_min, x_max, y_max = coords
|
||||||
|
cx_min, cy_min, cx_max, cy_max = candidate_coords
|
||||||
|
nx_min, ny_min = min(x_min, cx_min), min(y_min, cy_min)
|
||||||
|
nx_max, ny_max = max(x_max, cx_max), max(y_max, cy_max)
|
||||||
|
if nx_min + width < nx_max + padding or ny_min + height < ny_max + padding:
|
||||||
|
return False, coords
|
||||||
|
return True, (nx_min, ny_min, nx_max, ny_max)
|
||||||
|
|
||||||
|
separated = []
|
||||||
|
final_masks = []
|
||||||
|
for b in range(B):
|
||||||
|
m = masks[b]
|
||||||
|
rows, cols = m.any(dim=1), m.any(dim=0)
|
||||||
|
y_min, y_max = torch.where(rows)[0][[0, -1]]
|
||||||
|
x_min, x_max = torch.where(cols)[0][[0, -1]]
|
||||||
|
w = x_max - x_min + 1
|
||||||
|
h = y_max - y_min + 1
|
||||||
|
separated.append(((x_min.item(), y_min.item(), x_max.item(), y_max.item()), m))
|
||||||
|
|
||||||
|
separated.sort(key=lambda x: x[0])
|
||||||
|
fits = []
|
||||||
|
for i, masks in enumerate(separated):
|
||||||
|
coord = masks[0]
|
||||||
|
fits_in_box = []
|
||||||
|
for j, cand_mask in enumerate(separated):
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
r, coord = mask_fits(coord, cand_mask[0])
|
||||||
|
if r:
|
||||||
|
fits_in_box.append(j)
|
||||||
|
fits.append((i, fits_in_box))
|
||||||
|
fits.sort(key=lambda x: -len(x[1]))
|
||||||
|
seen = []
|
||||||
|
unique_fits = []
|
||||||
|
for idx, fs in fits:
|
||||||
|
uniq = [i for i in fs if i not in seen]
|
||||||
|
unique_fits.append((idx, fs, uniq))
|
||||||
|
seen.extend(uniq)
|
||||||
|
unique_fits.sort(key=lambda x: (-len(x[1]), -len(x[2])))
|
||||||
|
merged = []
|
||||||
|
for mask_idx, fitting_masks, _ in unique_fits:
|
||||||
|
if mask_idx in merged:
|
||||||
|
continue
|
||||||
|
fitting_masks = [i for i in fitting_masks if i not in merged]
|
||||||
|
combined_mask = separated[mask_idx][1].clone()
|
||||||
|
for i in fitting_masks:
|
||||||
|
combined_mask += separated[i][1]
|
||||||
|
merged.append(i)
|
||||||
|
merged.append(mask_idx)
|
||||||
|
final_masks.append(combined_mask)
|
||||||
|
|
||||||
|
print(f"Consolidated {B} masks into {len(final_masks)}")
|
||||||
|
return (torch.stack(final_masks, dim=0),)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user