mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-26 12:51:50 +08:00
order the separated masks
This commit is contained in:
parent
89fb17ae84
commit
62ee13ef76
@ -1347,7 +1347,7 @@ class SeparateMasks:
|
||||
return best_approx.squeeze() if best_approx is not None else hull.squeeze()
|
||||
|
||||
def separate(self, mask: torch.Tensor, size_threshold_width: int, size_threshold_height: int, max_poly_points: int, mode: str):
|
||||
from scipy.ndimage import label
|
||||
from scipy.ndimage import label, center_of_mass
|
||||
import numpy as np
|
||||
|
||||
B, H, W = mask.shape
|
||||
@ -1360,6 +1360,7 @@ class SeparateMasks:
|
||||
structure = np.ones((3, 3), dtype=np.int8)
|
||||
labeled, ncomponents = label(mask_np, structure=structure)
|
||||
pbar = ProgressBar(ncomponents)
|
||||
|
||||
for component in range(1, ncomponents + 1):
|
||||
component_mask_np = (labeled == component).astype(np.uint8)
|
||||
|
||||
@ -1370,20 +1371,25 @@ class SeparateMasks:
|
||||
|
||||
width = x_max - x_min + 1
|
||||
height = y_max - y_min + 1
|
||||
print(f"Component {component}: width={width}, height={height}")
|
||||
centroid_x = (x_min + x_max) / 2 # Calculate x centroid
|
||||
print(f"Component {component}: width={width}, height={height}, x_pos={centroid_x}")
|
||||
|
||||
if width >= size_threshold_width and height >= size_threshold_height:
|
||||
polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
|
||||
if mode != "area" and polygon is not None:
|
||||
poly_mask = self.polygon_to_mask(polygon, (H, W))
|
||||
poly_mask = torch.tensor(poly_mask, device=mask.device)
|
||||
separated.append(poly_mask)
|
||||
elif mode == "area":
|
||||
if mode != "area":
|
||||
polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
|
||||
if polygon is not None:
|
||||
poly_mask = self.polygon_to_mask(polygon, (H, W))
|
||||
poly_mask = torch.tensor(poly_mask, device=mask.device)
|
||||
separated.append((centroid_x, poly_mask))
|
||||
else:
|
||||
area_mask = torch.tensor(component_mask_np, device=mask.device)
|
||||
separated.append(area_mask)
|
||||
separated.append((centroid_x, area_mask))
|
||||
pbar.update(1)
|
||||
|
||||
if len(separated) > 0:
|
||||
# Sort by x position and extract only the masks
|
||||
separated.sort(key=lambda x: x[0])
|
||||
separated = [x[1] for x in separated]
|
||||
out_masks = torch.stack(separated, dim=0)
|
||||
return out_masks,
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user