diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index b7f1161..8852d06 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1347,7 +1347,7 @@ class SeparateMasks: return best_approx.squeeze() if best_approx is not None else hull.squeeze() def separate(self, mask: torch.Tensor, size_threshold_width: int, size_threshold_height: int, max_poly_points: int, mode: str): - from scipy.ndimage import label + from scipy.ndimage import label, center_of_mass import numpy as np B, H, W = mask.shape @@ -1360,6 +1360,7 @@ class SeparateMasks: structure = np.ones((3, 3), dtype=np.int8) labeled, ncomponents = label(mask_np, structure=structure) pbar = ProgressBar(ncomponents) + for component in range(1, ncomponents + 1): component_mask_np = (labeled == component).astype(np.uint8) @@ -1370,20 +1371,25 @@ class SeparateMasks: width = x_max - x_min + 1 height = y_max - y_min + 1 - print(f"Component {component}: width={width}, height={height}") + centroid_x = (x_min + x_max) / 2 # Calculate x centroid + print(f"Component {component}: width={width}, height={height}, x_pos={centroid_x}") if width >= size_threshold_width and height >= size_threshold_height: - polygon = self.get_mask_polygon(component_mask_np, max_poly_points) - if mode != "area" and polygon is not None: - poly_mask = self.polygon_to_mask(polygon, (H, W)) - poly_mask = torch.tensor(poly_mask, device=mask.device) - separated.append(poly_mask) - elif mode == "area": + if mode != "area": + polygon = self.get_mask_polygon(component_mask_np, max_poly_points) + if polygon is not None: + poly_mask = self.polygon_to_mask(polygon, (H, W)) + poly_mask = torch.tensor(poly_mask, device=mask.device) + separated.append((centroid_x, poly_mask)) + else: area_mask = torch.tensor(component_mask_np, device=mask.device) - separated.append(area_mask) + separated.append((centroid_x, area_mask)) pbar.update(1) if len(separated) > 0: + # Sort by x position and extract only the masks + separated.sort(key=lambda x: x[0]) + separated = [x[1] for x in separated] out_masks = torch.stack(separated, dim=0) return out_masks, else: