Update nodes.py

kijai 2023-11-17 21:56:00 +02:00
parent 5970b59de1
commit 57e95ab9d2


@@ -1324,21 +1324,29 @@ class BatchCropFromMask:
     CATEGORY = "KJNodes/masking"

     def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
-        return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
+        if alpha == 0:
+            return prev_bbox_size
+        return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)

     def smooth_center(self, prev_center, curr_center, alpha=0.5):
-        return (int(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
-                int(alpha * curr_center[1] + (1 - alpha) * prev_center[1]))
+        if alpha == 0:
+            return prev_center
+        return (
+            int(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
+            int(alpha * curr_center[1] + (1 - alpha) * prev_center[1])
+        )

     def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
         bounding_boxes = []
         cropped_images = []
-        self.max_bbox_size = 0
+        self.max_bbox_width = 0
+        self.max_bbox_height = 0

         # First, calculate the maximum bounding box size across all masks
-        curr_max_bbox_size = 0
+        curr_max_bbox_width = 0
+        curr_max_bbox_height = 0
         for mask in masks:
             _mask = tensor2pil(mask)[0]
             non_zero_indices = np.nonzero(np.array(_mask))
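
Note: the smoothing here is a per-frame exponential moving average, and the new alpha == 0 guards make a zero alpha mean "keep the previous value untouched" rather than pushing it through int() truncation. A minimal standalone sketch of the behavior (values made up):

    def smooth(prev, curr, alpha):
        if alpha == 0:
            return prev  # freeze: no update at all
        # otherwise blend the new measurement into the running value
        return int(alpha * curr + (1 - alpha) * prev)

    smoothed = 0
    for size in [100, 140, 90, 130]:
        smoothed = smooth(smoothed, size, 0.5)
        # yields 50, 95, 92, 111 -- frame-to-frame jumps are damped
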
@@ -1346,17 +1354,21 @@ class BatchCropFromMask:
             min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
             width = max_x - min_x
             height = max_y - min_y
-            bbox_size = max(width, height)
-            curr_max_bbox_size = max(curr_max_bbox_size, bbox_size)
+            curr_max_bbox_width = max(curr_max_bbox_width, width)
+            curr_max_bbox_height = max(curr_max_bbox_height, height)

         # Smooth the changes in the bounding box size
-        self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha)
+        self.max_bbox_width = self.smooth_bbox_size(self.max_bbox_width, curr_max_bbox_width, bbox_smooth_alpha)
+        self.max_bbox_height = self.smooth_bbox_size(self.max_bbox_height, curr_max_bbox_height, bbox_smooth_alpha)

         # Apply the crop size multiplier
-        self.max_bbox_size = int(self.max_bbox_size * crop_size_mult)
+        self.max_bbox_width = int(self.max_bbox_width * crop_size_mult)
+        self.max_bbox_height = int(self.max_bbox_height * crop_size_mult)
+        bbox_aspect_ratio = self.max_bbox_width / self.max_bbox_height

-        # Make sure max_bbox_size is divisible by 32, if not, round it upwards so it is
-        self.max_bbox_size = math.ceil(self.max_bbox_size / 32) * 32
+        # Make sure the dimensions are divisible by 32, if not, round them upwards so they are
+        self.max_bbox_width = math.ceil(self.max_bbox_width / 32) * 32
+        self.max_bbox_height = math.ceil(self.max_bbox_height / 32) * 32

         # Then, for each mask and corresponding image...
         for i, (mask, img) in enumerate(zip(masks, original_images)):
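
The ceil-based rounding always rounds up, so the mask region is never clipped, and it lands the final crop dimensions on a multiple of 32, presumably for compatibility with models that downscale by powers of two. A quick sketch of the arithmetic:

    import math

    def round_up(x, multiple=32):
        return math.ceil(x / multiple) * multiple

    round_up(100)  # 128
    round_up(128)  # 128 (already a multiple, unchanged)
    round_up(129)  # 160
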
@@ -1364,7 +1376,7 @@ class BatchCropFromMask:
             non_zero_indices = np.nonzero(np.array(_mask))
             min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
             min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])

             # Calculate center of bounding box
             center_x = np.mean(non_zero_indices[1])
             center_y = np.mean(non_zero_indices[0])
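
The per-frame center is simply the centroid of the mask's nonzero pixels; note that np.nonzero returns (rows, cols), so index 1 is x and index 0 is y. A small sketch:

    import numpy as np

    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[2:5, 3:7] = 255                            # a 3x4 blob

    ys, xs = np.nonzero(mask)
    center_x, center_y = np.mean(xs), np.mean(ys)   # 4.5, 3.0
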
@@ -1383,13 +1395,13 @@ class BatchCropFromMask:
             # Update prev_center for the next frame
             self.prev_center = center

-            # Create bounding box using max_bbox_size
-            half_box_size = self.max_bbox_size // 2
-            half_box_size = self.max_bbox_size // 2
-            min_x = max(0, center[0] - half_box_size)
-            max_x = min(img.shape[1], center[0] + half_box_size)
-            min_y = max(0, center[1] - half_box_size)
-            max_y = min(img.shape[0], center[1] + half_box_size)
+            # Create bounding box using max_bbox_width and max_bbox_height
+            half_box_width = round(self.max_bbox_width / 2)
+            half_box_height = round(self.max_bbox_height / 2)
+            min_x = max(0, center[0] - half_box_width)
+            max_x = min(img.shape[1], center[0] + half_box_width)
+            min_y = max(0, center[1] - half_box_height)
+            max_y = min(img.shape[0], center[1] + half_box_height)

             # Append bounding box coordinates
             bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
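
Because min/max are clamped to the image bounds, a box centered near an edge can come out smaller than max_bbox_width x max_bbox_height; the Resize/CenterCrop step in the next hunk restores the uniform output size. Illustrative numbers (made up):

    img_w, img_h = 512, 512
    center = (40, 300)                       # close to the left edge
    half_w, half_h = 128, 96

    min_x = max(0, center[0] - half_w)       # 0, clamped from -88
    max_x = min(img_w, center[0] + half_w)   # 168
    min_y = max(0, center[1] - half_h)       # 204
    max_y = min(img_h, center[1] + half_h)   # 396
    # the crop is 168x192 here, narrower than the 256-wide target
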
@@ -1397,21 +1409,23 @@ class BatchCropFromMask:
             # Crop the image from the bounding box
             cropped_img = img[min_y:max_y, min_x:max_x, :]

-            # Resize the cropped image to a fixed size
-            new_size = max(cropped_img.shape[0], cropped_img.shape[1])
-            resize_transform = Resize(new_size)
+            # Calculate the new dimensions while maintaining the aspect ratio
+            new_height = min(cropped_img.shape[0], self.max_bbox_height)
+            new_width = int(new_height * bbox_aspect_ratio)
+
+            # Resize the image
+            resize_transform = Resize((new_height, new_width))
             resized_img = resize_transform(cropped_img.permute(2, 0, 1))

             # Perform the center crop to the desired size
-            crop_transform = CenterCrop((self.max_bbox_size, self.max_bbox_size))
+            crop_transform = CenterCrop((self.max_bbox_height, self.max_bbox_width))  # CenterCrop expects (height, width)
             cropped_resized_img = crop_transform(resized_img)

             cropped_images.append(cropped_resized_img.permute(1, 2, 0))

         cropped_out = torch.stack(cropped_images, dim=0)
-        return (original_images, cropped_out, bounding_boxes, self.max_bbox_size, self.max_bbox_size, )
+        return (original_images, cropped_out, bounding_boxes, self.max_bbox_width, self.max_bbox_height, )

 def bbox_to_region(bbox, target_size=None):
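
Resize scales the clamped crop while keeping the smoothed aspect ratio, and CenterCrop then forces every frame to exactly (max_bbox_height, max_bbox_width); torchvision's CenterCrop pads with zeros when the input is smaller than the target, so undersized edge crops still come out at the batch size. A self-contained sketch with assumed sizes:

    import torch
    from torchvision.transforms import Resize, CenterCrop

    target_h, target_w = 192, 256               # stand-ins for max_bbox_height/width
    crop = torch.rand(3, 150, 210)              # CHW crop clamped at an image edge

    new_h = min(crop.shape[1], target_h)        # 150
    new_w = int(new_h * (target_w / target_h))  # 200

    resized = Resize((new_h, new_w))(crop)
    uniform = CenterCrop((target_h, target_w))(resized)
    print(uniform.shape)                        # torch.Size([3, 192, 256])
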
@@ -1441,6 +1455,10 @@ class BatchUncrop:
                 "bboxes": ("BBOX",),
                 "border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
                 "crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                "border_top": ("BOOLEAN", {"default": True}),
+                "border_bottom": ("BOOLEAN", {"default": True}),
+                "border_left": ("BOOLEAN", {"default": True}),
+                "border_right": ("BOOLEAN", {"default": True}),
             }
         }
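
Judging by how these flags are used in inset_border below, each toggle controls whether that edge of the paste mask gets a border band: an enabled edge is blended into the surrounding image (via border_blending elsewhere in the method), while a disabled edge is pasted with a hard seam but keeps the full crop content.
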
@@ -1449,16 +1467,19 @@ class BatchUncrop:
     CATEGORY = "KJNodes/masking"

-    def uncrop(self, original_images, cropped_images, bboxes, border_blending, crop_rescale):
-        def inset_border(image, border_width=20, border_color=(0)):
+    def uncrop(self, original_images, cropped_images, bboxes, border_blending, crop_rescale, border_top, border_bottom, border_left, border_right):
+        def inset_border(image, border_width, border_color, border_top, border_bottom, border_left, border_right):
+            draw = ImageDraw.Draw(image)
             width, height = image.size
-            bordered_image = Image.new(image.mode, (width, height), border_color)
-            bordered_image.paste(image, (0, 0))
-            draw = ImageDraw.Draw(bordered_image)
-            draw.rectangle(
-                (0, 0, width - 1, height - 1), outline=border_color, width=border_width
-            )
-            return bordered_image
+            if border_top:
+                draw.rectangle((0, 0, width, border_width), fill=border_color)
+            if border_bottom:
+                draw.rectangle((0, height - border_width, width, height), fill=border_color)
+            if border_left:
+                draw.rectangle((0, 0, border_width, height), fill=border_color)
+            if border_right:
+                draw.rectangle((width - border_width, 0, width, height), fill=border_color)
+            return image

         if len(original_images) != len(cropped_images) or len(original_images) != len(bboxes):
             raise ValueError("The number of images, crop_images, and bboxes should be the same")
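
Design note: the old helper composited the image onto a fresh canvas and drew one outline rectangle around all four sides; the new one mutates the passed-in image directly, painting a filled band per enabled edge and returning the same object. A standalone sketch of the new behavior on a small "L"-mode mask:

    from PIL import Image, ImageDraw

    block = Image.new("L", (64, 64), 255)    # all-white paste mask
    draw = ImageDraw.Draw(block)
    border_width = 8
    # zero out only the top band, as border_top=True (others off) would
    draw.rectangle((0, 0, 64, border_width), fill=0)
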
@@ -1499,7 +1520,7 @@ class BatchUncrop:
             mask = Image.new("L", img.size, 0)
             mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
-            mask_block = inset_border(mask_block, int(blend_ratio / 2), (0))
+            mask_block = inset_border(mask_block, int(blend_ratio / 2), (0), border_top, border_bottom, border_left, border_right)

             mask.paste(mask_block, paste_region)
             blend.paste(crop_img, paste_region)
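
For context, these lines build the blend mask as follows: a black full-size mask receives a white block at the crop's paste region, with the enabled edges already zeroed by inset_border, so only those edges participate in the blend. A minimal sketch with made-up numbers:

    from PIL import Image

    paste_region = (100, 100, 300, 260)      # (left, top, right, bottom)
    mask = Image.new("L", (512, 512), 0)     # black everywhere
    block = Image.new("L", (paste_region[2] - paste_region[0],
                            paste_region[3] - paste_region[1]), 255)
    # inset_border(...) would zero the enabled edges of block here
    mask.paste(block, paste_region)
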