From 9af6f331602e7a2a112f749832644a1eca0467e5 Mon Sep 17 00:00:00 2001 From: kijai Date: Thu, 9 Nov 2023 16:56:42 +0200 Subject: [PATCH] Crop/Uncrop fixes --- nodes.py | 61 ++++++++++++++++++++++++-------------------------------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/nodes.py b/nodes.py index fd79ec3..50a68fd 100644 --- a/nodes.py +++ b/nodes.py @@ -1272,6 +1272,7 @@ class ImageBatchTestPattern: #based on nodes from mtb https://github.com/melMass/comfy_mtb from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor +from torchvision.transforms import Resize class BatchCropFromMask: @@ -1304,38 +1305,14 @@ class BatchCropFromMask: CATEGORY = "KJNodes/masking" def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha): - """ - Smooth the bounding box size using exponential smoothing. + return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size) - Args: - prev_bbox_size (int): The bounding box size of the previous frame. - curr_bbox_size (int): The bounding box size of the current frame. - alpha (float): The smoothing factor, between 0 and 1. - A larger alpha places more weight on the current frame's size. - - Returns: - int: The smoothed bounding box size. - """ - return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size) def smooth_center(self, prev_center, curr_center, alpha=0.5): - """ - Smooth the center coordinates using exponential smoothing. - - Args: - prev_center (tuple): The center coordinates of the previous frame. - curr_center (tuple): The center coordinates of the current frame. - alpha (float): The smoothing factor, between 0 and 1. - A larger alpha places more weight on the current frame's center. - - Returns: - tuple: The smoothed center coordinates. - """ return (int(alpha * curr_center[0] + (1 - alpha) * prev_center[0]), int(alpha * curr_center[1] + (1 - alpha) * prev_center[1])) def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha): - - + bounding_boxes = [] cropped_images = [] @@ -1402,9 +1379,15 @@ class BatchCropFromMask: # Crop the image from the bounding box cropped_img = img[min_y:max_y, min_x:max_x, :] - cropped_images.append(cropped_img) + + # Resize the cropped image to a fixed size + resize_transform = Resize((self.max_bbox_size, self.max_bbox_size)) + resized_img = resize_transform(cropped_img.permute(2, 0, 1)).permute(1, 2, 0) + + cropped_images.append(resized_img) cropped_out = torch.stack(cropped_images, dim=0) + return (original_images, cropped_out, bounding_boxes, self.max_bbox_size, self.max_bbox_size, ) @@ -1434,10 +1417,8 @@ class BatchUncrop: "original_images": ("IMAGE",), "cropped_images": ("IMAGE",), "bboxes": ("BBOX",), - "border_blending": ( - "FLOAT", - {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, - ), + "border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ), + "crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), } } @@ -1446,7 +1427,7 @@ class BatchUncrop: CATEGORY = "KJNodes/masking" - def uncrop(self, original_images, cropped_images, bboxes, border_blending): + def uncrop(self, original_images, cropped_images, bboxes, border_blending, crop_rescale): def inset_border(image, border_width=20, border_color=(0)): width, height = image.size bordered_image = Image.new(image.mode, (width, height), border_color) @@ -1462,6 +1443,7 @@ class BatchUncrop: input_images = tensor2pil(original_images) crop_imgs = tensor2pil(cropped_images) + out_images = [] for i in range(len(input_images)): img = input_images[i] @@ -1472,9 +1454,18 @@ class BatchUncrop: bb_x, bb_y, bb_width, bb_height = bbox paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size) + + # scale factors + scale_x = crop_rescale + scale_y = crop_rescale + # scaled paste_region + paste_region = (int(paste_region[0]*scale_x), int(paste_region[1]*scale_y), int(paste_region[2]*scale_x), int(paste_region[3]*scale_y)) + + # rescale the crop image to fit the paste_region + crop = crop.resize((int(paste_region[2]-paste_region[0]), int(paste_region[3]-paste_region[1]))) crop_img = crop.convert("RGB") - + if border_blending > 1.0: border_blending = 1.0 elif border_blending < 0.0: @@ -1485,9 +1476,9 @@ class BatchUncrop: blend = img.convert("RGBA") mask = Image.new("L", img.size, 0) - mask_block = Image.new("L", (bb_width, bb_height), 255) + mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255) mask_block = inset_border(mask_block, int(blend_ratio / 2), (0)) - + mask.paste(mask_block, paste_region) blend.paste(crop_img, paste_region)