Crop/Uncrop fixes

This commit is contained in:
kijai 2023-11-09 16:56:42 +02:00
parent fcf4b9c235
commit 9af6f33160

View File

@ -1272,6 +1272,7 @@ class ImageBatchTestPattern:
#based on nodes from mtb https://github.com/melMass/comfy_mtb
from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
from torchvision.transforms import Resize
class BatchCropFromMask:
@ -1304,38 +1305,14 @@ class BatchCropFromMask:
CATEGORY = "KJNodes/masking"
def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
"""
Smooth the bounding box size using exponential smoothing.
return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
Args:
prev_bbox_size (int): The bounding box size of the previous frame.
curr_bbox_size (int): The bounding box size of the current frame.
alpha (float): The smoothing factor, between 0 and 1.
A larger alpha places more weight on the current frame's size.
Returns:
int: The smoothed bounding box size.
"""
return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
def smooth_center(self, prev_center, curr_center, alpha=0.5):
"""
Smooth the center coordinates using exponential smoothing.
Args:
prev_center (tuple): The center coordinates of the previous frame.
curr_center (tuple): The center coordinates of the current frame.
alpha (float): The smoothing factor, between 0 and 1.
A larger alpha places more weight on the current frame's center.
Returns:
tuple: The smoothed center coordinates.
"""
return (int(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
int(alpha * curr_center[1] + (1 - alpha) * prev_center[1]))
def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
bounding_boxes = []
cropped_images = []
@ -1402,9 +1379,15 @@ class BatchCropFromMask:
# Crop the image from the bounding box
cropped_img = img[min_y:max_y, min_x:max_x, :]
cropped_images.append(cropped_img)
# Resize the cropped image to a fixed size
resize_transform = Resize((self.max_bbox_size, self.max_bbox_size))
resized_img = resize_transform(cropped_img.permute(2, 0, 1)).permute(1, 2, 0)
cropped_images.append(resized_img)
cropped_out = torch.stack(cropped_images, dim=0)
return (original_images, cropped_out, bounding_boxes, self.max_bbox_size, self.max_bbox_size, )
@ -1434,10 +1417,8 @@ class BatchUncrop:
"original_images": ("IMAGE",),
"cropped_images": ("IMAGE",),
"bboxes": ("BBOX",),
"border_blending": (
"FLOAT",
{"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01},
),
"border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
"crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}
}
@ -1446,7 +1427,7 @@ class BatchUncrop:
CATEGORY = "KJNodes/masking"
def uncrop(self, original_images, cropped_images, bboxes, border_blending):
def uncrop(self, original_images, cropped_images, bboxes, border_blending, crop_rescale):
def inset_border(image, border_width=20, border_color=(0)):
width, height = image.size
bordered_image = Image.new(image.mode, (width, height), border_color)
@ -1462,6 +1443,7 @@ class BatchUncrop:
input_images = tensor2pil(original_images)
crop_imgs = tensor2pil(cropped_images)
out_images = []
for i in range(len(input_images)):
img = input_images[i]
@ -1472,9 +1454,18 @@ class BatchUncrop:
bb_x, bb_y, bb_width, bb_height = bbox
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
# scale factors
scale_x = crop_rescale
scale_y = crop_rescale
# scaled paste_region
paste_region = (int(paste_region[0]*scale_x), int(paste_region[1]*scale_y), int(paste_region[2]*scale_x), int(paste_region[3]*scale_y))
# rescale the crop image to fit the paste_region
crop = crop.resize((int(paste_region[2]-paste_region[0]), int(paste_region[3]-paste_region[1])))
crop_img = crop.convert("RGB")
if border_blending > 1.0:
border_blending = 1.0
elif border_blending < 0.0:
@ -1485,9 +1476,9 @@ class BatchUncrop:
blend = img.convert("RGBA")
mask = Image.new("L", img.size, 0)
mask_block = Image.new("L", (bb_width, bb_height), 255)
mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
mask_block = inset_border(mask_block, int(blend_ratio / 2), (0))
mask.paste(mask_block, paste_region)
blend.paste(crop_img, paste_region)