diff --git a/nodes.py b/nodes.py
index 3d76f12..612d0bd 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1903,14 +1903,8 @@ class BatchCropFromMaskAdvanced:
 
             _mask = tensor2pil(mask)[0]
             non_zero_indices = np.nonzero(np.array(_mask))
-            # handle empty masks
-            if len(non_zero_indices[0]) == 0 or len(non_zero_indices[1]) == 0:
-                bounding_boxes.append((0, 0, img.shape[1], img.shape[0]))
-                cropped_images.append(img)
-                cropped_masks.append(mask)
-                combined_cropped_images.append(img)
-                combined_cropped_masks.append(mask)
-            else:
+            # check for empty masks
+            if len(non_zero_indices[0]) > 0 and len(non_zero_indices[1]) > 0:
                 min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
                 min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
 
@@ -1966,6 +1960,12 @@ class BatchCropFromMaskAdvanced:
 
                 combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
                 combined_cropped_masks.append(combined_cropped_mask)
+            else:
+                bounding_boxes.append((0, 0, img.shape[1], img.shape[0]))
+                cropped_images.append(img)
+                cropped_masks.append(mask)
+                combined_cropped_images.append(img)
+                combined_cropped_masks.append(mask)
 
         cropped_out = torch.stack(cropped_images, dim=0)
         combined_crop_out = torch.stack(combined_cropped_images, dim=0)