order the separated masks

This commit is contained in:
kijai 2025-03-19 00:31:18 +02:00
parent 89fb17ae84
commit 62ee13ef76

View File

@ -1347,7 +1347,7 @@ class SeparateMasks:
return best_approx.squeeze() if best_approx is not None else hull.squeeze()
def separate(self, mask: torch.Tensor, size_threshold_width: int, size_threshold_height: int, max_poly_points: int, mode: str):
from scipy.ndimage import label
from scipy.ndimage import label, center_of_mass
import numpy as np
B, H, W = mask.shape
@ -1360,6 +1360,7 @@ class SeparateMasks:
structure = np.ones((3, 3), dtype=np.int8)
labeled, ncomponents = label(mask_np, structure=structure)
pbar = ProgressBar(ncomponents)
for component in range(1, ncomponents + 1):
component_mask_np = (labeled == component).astype(np.uint8)
@ -1370,20 +1371,25 @@ class SeparateMasks:
width = x_max - x_min + 1
height = y_max - y_min + 1
print(f"Component {component}: width={width}, height={height}")
centroid_x = (x_min + x_max) / 2 # Calculate x centroid
print(f"Component {component}: width={width}, height={height}, x_pos={centroid_x}")
if width >= size_threshold_width and height >= size_threshold_height:
polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
if mode != "area" and polygon is not None:
poly_mask = self.polygon_to_mask(polygon, (H, W))
poly_mask = torch.tensor(poly_mask, device=mask.device)
separated.append(poly_mask)
elif mode == "area":
if mode != "area":
polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
if polygon is not None:
poly_mask = self.polygon_to_mask(polygon, (H, W))
poly_mask = torch.tensor(poly_mask, device=mask.device)
separated.append((centroid_x, poly_mask))
else:
area_mask = torch.tensor(component_mask_np, device=mask.device)
separated.append(area_mask)
separated.append((centroid_x, area_mask))
pbar.update(1)
if len(separated) > 0:
# Sort by x position and extract only the masks
separated.sort(key=lambda x: x[0])
separated = [x[1] for x in separated]
out_masks = torch.stack(separated, dim=0)
return out_masks,
else: