diff --git a/nodes.py b/nodes.py index 599b306..fd79ec3 100644 --- a/nodes.py +++ b/nodes.py @@ -1282,6 +1282,7 @@ class BatchCropFromMask: "original_images": ("IMAGE",), "masks": ("MASK",), "crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), }, } @@ -1302,12 +1303,48 @@ class BatchCropFromMask: FUNCTION = "crop" CATEGORY = "KJNodes/masking" - def crop(self, masks, original_images, crop_size_mult): + def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha): + """ + Smooth the bounding box size using exponential smoothing. + + Args: + prev_bbox_size (int): The bounding box size of the previous frame. + curr_bbox_size (int): The bounding box size of the current frame. + alpha (float): The smoothing factor, between 0 and 1. + A larger alpha places more weight on the current frame's size. + + Returns: + int: The smoothed bounding box size. + """ + return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size) + def smooth_center(self, prev_center, curr_center, alpha=0.5): + """ + Smooth the center coordinates using exponential smoothing. + + Args: + prev_center (tuple): The center coordinates of the previous frame. + curr_center (tuple): The center coordinates of the current frame. + alpha (float): The smoothing factor, between 0 and 1. + A larger alpha places more weight on the current frame's center. + + Returns: + tuple: The smoothed center coordinates. + """ + return (int(alpha * curr_center[0] + (1 - alpha) * prev_center[0]), + int(alpha * curr_center[1] + (1 - alpha) * prev_center[1])) + + def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha): + + bounding_boxes = [] cropped_images = [] + # Initialize max_bbox_size for the first frame + if not hasattr(self, 'max_bbox_size'): + self.max_bbox_size = 0 + # First, calculate the maximum bounding box size across all masks - max_bbox_size = 0 + curr_max_bbox_size = 0 for mask in masks: _mask = tensor2pil(mask)[0] non_zero_indices = np.nonzero(np.array(_mask)) @@ -1316,31 +1353,49 @@ class BatchCropFromMask: width = max_x - min_x height = max_y - min_y bbox_size = max(width, height) - max_bbox_size = max(max_bbox_size, bbox_size) + curr_max_bbox_size = max(curr_max_bbox_size, bbox_size) - # Make sure max_bbox_size is divisible by 32, if not, round it upwards so it is - max_bbox_size = math.ceil(max_bbox_size / 32) * 32 + # Smooth the changes in the bounding box size + self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha) # Apply the crop size multiplier - max_bbox_size = int(max_bbox_size * crop_size_mult) + self.max_bbox_size = int(self.max_bbox_size * crop_size_mult) + + # Make sure max_bbox_size is divisible by 32, if not, round it upwards so it is + self.max_bbox_size = math.ceil(self.max_bbox_size / 32) * 32 # Then, for each mask and corresponding image... - for mask, img in zip(masks, original_images): + for i, (mask, img) in enumerate(zip(masks, original_images)): _mask = tensor2pil(mask)[0] non_zero_indices = np.nonzero(np.array(_mask)) min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1]) min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0]) # Calculate center of bounding box - center_x = (max_x + min_x) // 2 - center_y = (max_y + min_y) // 2 + center_x = np.mean(non_zero_indices[1]) + center_y = np.mean(non_zero_indices[0]) + curr_center = (int(center_x), int(center_y)) + + # If this is the first frame, initialize prev_center with curr_center + if not hasattr(self, 'prev_center'): + self.prev_center = curr_center + + # Smooth the changes in the center coordinates from the second frame onwards + if i > 0: + center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha) + else: + center = curr_center + + # Update prev_center for the next frame + self.prev_center = center # Create bounding box using max_bbox_size - half_box_size = max_bbox_size // 2 - min_x = max(0, center_x - half_box_size) - max_x = min(img.shape[1], center_x + half_box_size) - min_y = max(0, center_y - half_box_size) - max_y = min(img.shape[0], center_y + half_box_size) + half_box_size = self.max_bbox_size // 2 + half_box_size = self.max_bbox_size // 2 + min_x = max(0, center[0] - half_box_size) + max_x = min(img.shape[1], center[0] + half_box_size) + min_y = max(0, center[1] - half_box_size) + max_y = min(img.shape[0], center[1] + half_box_size) # Append bounding box coordinates bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y)) @@ -1351,7 +1406,7 @@ class BatchCropFromMask: cropped_out = torch.stack(cropped_images, dim=0) - return (original_images, cropped_out, bounding_boxes, max_bbox_size, max_bbox_size, ) + return (original_images, cropped_out, bounding_boxes, self.max_bbox_size, self.max_bbox_size, ) def bbox_to_region(bbox, target_size=None):