diff --git a/nodes.py b/nodes.py
index 9b95279..ac56557 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1865,8 +1865,13 @@ class BatchCropFromMaskAdvanced:
 
         def calculate_bbox(mask):
             non_zero_indices = np.nonzero(np.array(mask))
-            min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
-            min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+
+            # handle empty masks
+            min_x, max_x, min_y, max_y = 0, 0, 0, 0
+            if len(non_zero_indices[1]) > 0 and len(non_zero_indices[0]) > 0:
+                min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+                min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+
             width = max_x - min_x
             height = max_y - min_y
             bbox_size = max(width, height)
@@ -1896,66 +1901,79 @@ class BatchCropFromMaskAdvanced:
         # Make sure max_bbox_size is divisible by 16, if not, round it upwards so it is
         self.max_bbox_size = math.ceil(self.max_bbox_size / 16) * 16
 
+        if self.max_bbox_size > original_images[0].shape[0] or self.max_bbox_size > original_images[0].shape[1]:
+            # max_bbox_size can only be as big as our input's width or height, and it has to be even
+            self.max_bbox_size = math.floor(min(original_images[0].shape[0], original_images[0].shape[1]) / 2) * 2
+
         # Then, for each mask and corresponding image...
         for i, (mask, img) in enumerate(zip(masks, original_images)):
             _mask = tensor2pil(mask)[0]
             non_zero_indices = np.nonzero(np.array(_mask))
-            min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
-            min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
 
-            # Calculate center of bounding box
-            center_x = np.mean(non_zero_indices[1])
-            center_y = np.mean(non_zero_indices[0])
-            curr_center = (round(center_x), round(center_y))
+            # check for empty masks
+            if len(non_zero_indices[0]) > 0 and len(non_zero_indices[1]) > 0:
+                min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+                min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
 
-            # If this is the first frame, initialize prev_center with curr_center
-            if not hasattr(self, 'prev_center'):
-                self.prev_center = curr_center
+                # Calculate center of bounding box
+                center_x = np.mean(non_zero_indices[1])
+                center_y = np.mean(non_zero_indices[0])
+                curr_center = (round(center_x), round(center_y))
 
-            # Smooth the changes in the center coordinates from the second frame onwards
-            if i > 0:
-                center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
+                # If this is the first frame, initialize prev_center with curr_center
+                if not hasattr(self, 'prev_center'):
+                    self.prev_center = curr_center
+
+                # Smooth the changes in the center coordinates from the second frame onwards
+                if i > 0:
+                    center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
+                else:
+                    center = curr_center
+
+                # Update prev_center for the next frame
+                self.prev_center = center
+
+                # Create bounding box using max_bbox_size
+                half_box_size = self.max_bbox_size // 2
+                min_x = max(0, center[0] - half_box_size)
+                max_x = min(img.shape[1], center[0] + half_box_size)
+                min_y = max(0, center[1] - half_box_size)
+                max_y = min(img.shape[0], center[1] + half_box_size)
+
+                # Append bounding box coordinates
+                bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
+
+                # Crop the image from the bounding box
+                cropped_img = img[min_y:max_y, min_x:max_x, :]
+                cropped_mask = mask[min_y:max_y, min_x:max_x]
+
+                # Resize the cropped image to a fixed size
+                new_size = max(cropped_img.shape[0], cropped_img.shape[1])
+                resize_transform = Resize(new_size, interpolation=InterpolationMode.NEAREST, max_size=max(img.shape[0], img.shape[1]))
+                resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
+                resized_img = resize_transform(cropped_img.permute(2, 0, 1))
+                # Perform the center crop to the desired size
+                # Constrain the crop to the smaller of our bbox or our image so we don't expand past the image dimensions.
+                crop_transform = CenterCrop((min(self.max_bbox_size, resized_img.shape[1]), min(self.max_bbox_size, resized_img.shape[2])))
+
+                cropped_resized_img = crop_transform(resized_img)
+                cropped_images.append(cropped_resized_img.permute(1, 2, 0))
+
+                cropped_resized_mask = crop_transform(resized_mask)
+                cropped_masks.append(cropped_resized_mask)
+
+                combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
+                combined_cropped_images.append(combined_cropped_img)
+
+                combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
+                combined_cropped_masks.append(combined_cropped_mask)
             else:
-                center = curr_center
+                bounding_boxes.append((0, 0, img.shape[1], img.shape[0]))
+                cropped_images.append(img)
+                cropped_masks.append(mask)
+                combined_cropped_images.append(img)
+                combined_cropped_masks.append(mask)
 
-            # Update prev_center for the next frame
-            self.prev_center = center
-
-            # Create bounding box using max_bbox_size
-            half_box_size = self.max_bbox_size // 2
-            half_box_size = self.max_bbox_size // 2
-            min_x = max(0, center[0] - half_box_size)
-            max_x = min(img.shape[1], center[0] + half_box_size)
-            min_y = max(0, center[1] - half_box_size)
-            max_y = min(img.shape[0], center[1] + half_box_size)
-
-            # Append bounding box coordinates
-            bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
-
-            # Crop the image from the bounding box
-            cropped_img = img[min_y:max_y, min_x:max_x, :]
-            cropped_mask = mask[min_y:max_y, min_x:max_x]
-
-            # Resize the cropped image to a fixed size
-            new_size = max(cropped_img.shape[0], cropped_img.shape[1])
-            resize_transform = Resize(new_size, interpolation = InterpolationMode.NEAREST)
-            resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
-            resized_img = resize_transform(cropped_img.permute(2, 0, 1))
-            # Perform the center crop to the desired size
-            crop_transform = CenterCrop((self.max_bbox_size, self.max_bbox_size))
-
-            cropped_resized_img = crop_transform(resized_img)
-            cropped_images.append(cropped_resized_img.permute(1, 2, 0))
-
-            cropped_resized_mask = crop_transform(resized_mask)
-            cropped_masks.append(cropped_resized_mask)
-
-            combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
-            combined_cropped_images.append(combined_cropped_img)
-
-            combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
-            combined_cropped_masks.append(combined_cropped_mask)
-
         cropped_out = torch.stack(cropped_images, dim=0)
         combined_crop_out = torch.stack(combined_cropped_images, dim=0)
         cropped_masks_out = torch.stack(cropped_masks, dim=0)