mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-08 02:37:51 +08:00
order the separated masks
This commit is contained in:
parent
89fb17ae84
commit
62ee13ef76
@ -1347,7 +1347,7 @@ class SeparateMasks:
|
|||||||
return best_approx.squeeze() if best_approx is not None else hull.squeeze()
|
return best_approx.squeeze() if best_approx is not None else hull.squeeze()
|
||||||
|
|
||||||
def separate(self, mask: torch.Tensor, size_threshold_width: int, size_threshold_height: int, max_poly_points: int, mode: str):
|
def separate(self, mask: torch.Tensor, size_threshold_width: int, size_threshold_height: int, max_poly_points: int, mode: str):
|
||||||
from scipy.ndimage import label
|
from scipy.ndimage import label, center_of_mass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
B, H, W = mask.shape
|
B, H, W = mask.shape
|
||||||
@ -1360,6 +1360,7 @@ class SeparateMasks:
|
|||||||
structure = np.ones((3, 3), dtype=np.int8)
|
structure = np.ones((3, 3), dtype=np.int8)
|
||||||
labeled, ncomponents = label(mask_np, structure=structure)
|
labeled, ncomponents = label(mask_np, structure=structure)
|
||||||
pbar = ProgressBar(ncomponents)
|
pbar = ProgressBar(ncomponents)
|
||||||
|
|
||||||
for component in range(1, ncomponents + 1):
|
for component in range(1, ncomponents + 1):
|
||||||
component_mask_np = (labeled == component).astype(np.uint8)
|
component_mask_np = (labeled == component).astype(np.uint8)
|
||||||
|
|
||||||
@ -1370,20 +1371,25 @@ class SeparateMasks:
|
|||||||
|
|
||||||
width = x_max - x_min + 1
|
width = x_max - x_min + 1
|
||||||
height = y_max - y_min + 1
|
height = y_max - y_min + 1
|
||||||
print(f"Component {component}: width={width}, height={height}")
|
centroid_x = (x_min + x_max) / 2 # Calculate x centroid
|
||||||
|
print(f"Component {component}: width={width}, height={height}, x_pos={centroid_x}")
|
||||||
|
|
||||||
if width >= size_threshold_width and height >= size_threshold_height:
|
if width >= size_threshold_width and height >= size_threshold_height:
|
||||||
polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
|
if mode != "area":
|
||||||
if mode != "area" and polygon is not None:
|
polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
|
||||||
poly_mask = self.polygon_to_mask(polygon, (H, W))
|
if polygon is not None:
|
||||||
poly_mask = torch.tensor(poly_mask, device=mask.device)
|
poly_mask = self.polygon_to_mask(polygon, (H, W))
|
||||||
separated.append(poly_mask)
|
poly_mask = torch.tensor(poly_mask, device=mask.device)
|
||||||
elif mode == "area":
|
separated.append((centroid_x, poly_mask))
|
||||||
|
else:
|
||||||
area_mask = torch.tensor(component_mask_np, device=mask.device)
|
area_mask = torch.tensor(component_mask_np, device=mask.device)
|
||||||
separated.append(area_mask)
|
separated.append((centroid_x, area_mask))
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
if len(separated) > 0:
|
if len(separated) > 0:
|
||||||
|
# Sort by x position and extract only the masks
|
||||||
|
separated.sort(key=lambda x: x[0])
|
||||||
|
separated = [x[1] for x in separated]
|
||||||
out_masks = torch.stack(separated, dim=0)
|
out_masks = torch.stack(separated, dim=0)
|
||||||
return out_masks,
|
return out_masks,
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user