diff --git a/nodes.py b/nodes.py
index 835fcd2..23a4308 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,7 +1,8 @@
 import nodes
 import torch
 import torch.nn.functional as F
-import torchvision.utils as vutils
+from torchvision.transforms import Resize, CenterCrop, InterpolationMode
+from torchvision.transforms import functional as TF
 import scipy.ndimage
 import numpy as np
 from PIL import ImageFilter, Image, ImageDraw, ImageFont
@@ -10,6 +11,7 @@ import json
 import re
 import os
 import librosa
+import random
 from scipy.special import erf
 from .fluid import Fluid
 import comfy.model_management
@@ -188,8 +190,6 @@ class CreateAudioMask:
             return (1.0 - torch.cat(out, dim=0),)
         return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
-
-

 class CreateGradientMask:
     RETURN_TYPES = ("MASK",)
@@ -531,6 +531,8 @@ class GrowMaskWithBlur:
                     "max": 10.0,
                     "step": 0.1
                 }),
+                "lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             },
         }
@@ -540,7 +542,9 @@ class GrowMaskWithBlur:
     RETURN_NAMES = ("mask", "mask_inverted",)
     FUNCTION = "expand_mask"

-    def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate, use_cuda):
+    def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate, use_cuda, lerp_alpha, decay_factor):
+        alpha = lerp_alpha
+        decay = decay_factor
         if( flip_input ):
             mask = 1.0 - mask
         c = 0 if tapered_corners else 1
@@ -549,6 +553,7 @@ class GrowMaskWithBlur:
                            [c, 1, c]])
         growmask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
         out = []
+        previous_output = None
         for m in growmask:
             output = m.numpy()
             for _ in range(abs(expand)):
@@ -561,6 +566,14 @@ class GrowMaskWithBlur:
                 else:
                     expand += abs(incremental_expandrate) # Use abs(growrate) to ensure positive change
             output = torch.from_numpy(output)
+            if alpha < 1.0 and previous_output is not None:
+                # Interpolate between the previous and current frame
+                output = alpha * output + (1 - alpha) * previous_output
+            if decay < 1.0 and previous_output is not None:
+                # Add the decayed previous output to the current frame
+                output += decay * previous_output
+                output = output / output.max()
+            previous_output = output
             out.append(output)

         blurred = torch.stack(out, dim=0).reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
@@ -568,7 +581,7 @@ class GrowMaskWithBlur:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             blurred = blurred.to(device) # Move blurred tensor to the GPU
-        batch_size, height, width, channels = blurred.shape
+        channels = blurred.shape[-1]
         if blur_radius != 0:
             blurkernel_size = blur_radius * 2 + 1
             blurkernel = gaussian_kernel(blurkernel_size, sigma, device=blurred.device).repeat(channels, 1, 1).unsqueeze(1)
@@ -1215,7 +1228,6 @@ class ImageGridComposite3x3:
         grid = torch.cat((top_row, mid_row, bottom_row), dim=1)
         return (grid,)

-import random
 class ImageBatchTestPattern:
     @classmethod
     def INPUT_TYPES(s):
@@ -1272,7 +1284,7 @@

 #based on nodes from mtb https://github.com/melMass/comfy_mtb
 from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
-from torchvision.transforms import Resize, CenterCrop
+

 class BatchCropFromMask:
@@ -1496,6 +1508,276 @@ class BatchUncrop:
         return (pil2tensor(out_images),)

+class BatchCropFromMaskAdvanced:
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "original_images": ("IMAGE",),
+                "masks": ("MASK",),
"crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + }, + } + + RETURN_TYPES = ( + "IMAGE", + "IMAGE", + "MASK", + "IMAGE", + "MASK", + "BBOX", + "BBOX", + "INT", + "INT", + ) + RETURN_NAMES = ( + "original_images", + "cropped_images", + "cropped_masks", + "combined_crop_image", + "combined_crop_masks", + "bboxes", + "combined_bounding_box", + "bbox_width", + "bbox_height", + ) + FUNCTION = "crop" + CATEGORY = "KJNodes/masking" + + def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha): + return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size) + + def smooth_center(self, prev_center, curr_center, alpha=0.5): + return (int(alpha * curr_center[0] + (1 - alpha) * prev_center[0]), + int(alpha * curr_center[1] + (1 - alpha) * prev_center[1])) + + def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha): + bounding_boxes = [] + combined_bounding_box = [] + cropped_images = [] + cropped_masks = [] + cropped_masks_out = [] + combined_crop_out = [] + combined_cropped_images = [] + combined_cropped_masks = [] + + def calculate_bbox(mask): + non_zero_indices = np.nonzero(np.array(mask)) + min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1]) + min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0]) + width = max_x - min_x + height = max_y - min_y + bbox_size = max(width, height) + return min_x, max_x, min_y, max_y, bbox_size + + combined_mask = torch.max(masks, dim=0)[0] + _mask = tensor2pil(combined_mask)[0] + new_min_x, new_max_x, new_min_y, new_max_y, combined_bbox_size = calculate_bbox(_mask) + center_x = (new_min_x + new_max_x) / 2 + center_y = (new_min_y + new_max_y) / 2 + half_box_size = int(combined_bbox_size // 2) + new_min_x = max(0, int(center_x - half_box_size)) + new_max_x = min(original_images[0].shape[1], int(center_x + half_box_size)) + new_min_y = max(0, int(center_y - half_box_size)) + new_max_y = min(original_images[0].shape[0], int(center_y + half_box_size)) + + combined_bounding_box.append((new_min_x, new_min_y, new_max_x - new_min_x, new_max_y - new_min_y)) + + self.max_bbox_size = 0 + + # First, calculate the maximum bounding box size across all masks + curr_max_bbox_size = max(calculate_bbox(tensor2pil(mask)[0])[-1] for mask in masks) + # Smooth the changes in the bounding box size + self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha) + # Apply the crop size multiplier + self.max_bbox_size = int(self.max_bbox_size * crop_size_mult) + # Make sure max_bbox_size is divisible by 32, if not, round it upwards so it is + self.max_bbox_size = math.ceil(self.max_bbox_size / 32) * 32 + + # Then, for each mask and corresponding image... 
+        for i, (mask, img) in enumerate(zip(masks, original_images)):
+            _mask = tensor2pil(mask)[0]
+            non_zero_indices = np.nonzero(np.array(_mask))
+            min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+            min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+
+            # Calculate center of bounding box
+            center_x = np.mean(non_zero_indices[1])
+            center_y = np.mean(non_zero_indices[0])
+            curr_center = (int(center_x), int(center_y))
+
+            # If this is the first frame, initialize prev_center with curr_center
+            if not hasattr(self, 'prev_center'):
+                self.prev_center = curr_center
+
+            # Smooth the changes in the center coordinates from the second frame onwards
+            if i > 0:
+                center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
+            else:
+                center = curr_center
+
+            # Update prev_center for the next frame
+            self.prev_center = center
+
+            # Create bounding box using max_bbox_size
+            half_box_size = self.max_bbox_size // 2
+            min_x = max(0, center[0] - half_box_size)
+            max_x = min(img.shape[1], center[0] + half_box_size)
+            min_y = max(0, center[1] - half_box_size)
+            max_y = min(img.shape[0], center[1] + half_box_size)
+
+            # Append bounding box coordinates
+            bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
+
+            # Crop the image from the bounding box
+            cropped_img = img[min_y:max_y, min_x:max_x, :]
+            cropped_mask = mask[min_y:max_y, min_x:max_x]
+
+            # Resize the cropped image to a fixed size
+            new_size = max(cropped_img.shape[0], cropped_img.shape[1])
+            resize_transform = Resize(new_size, interpolation=InterpolationMode.NEAREST)
+            resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
+            resized_img = resize_transform(cropped_img.permute(2, 0, 1))
+            # Perform the center crop to the desired size
+            crop_transform = CenterCrop((self.max_bbox_size, self.max_bbox_size))
+
+            cropped_resized_img = crop_transform(resized_img)
+            cropped_images.append(cropped_resized_img.permute(1, 2, 0))
+
+            cropped_resized_mask = crop_transform(resized_mask)
+            cropped_masks.append(cropped_resized_mask)
+
+            combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
+            combined_cropped_images.append(combined_cropped_img)
+
+            combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
+            combined_cropped_masks.append(combined_cropped_mask)
+
+        cropped_out = torch.stack(cropped_images, dim=0)
+        combined_crop_out = torch.stack(combined_cropped_images, dim=0)
+        cropped_masks_out = torch.stack(cropped_masks, dim=0)
+        combined_crop_mask_out = torch.stack(combined_cropped_masks, dim=0)
+
+        return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)
+
+
+def bbox_to_region(bbox, target_size=None):
+    bbox = bbox_check(bbox, target_size)
+    return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
+
+def bbox_check(bbox, target_size=None):
+    if not target_size:
+        return bbox
+
+    new_bbox = (
+        bbox[0],
+        bbox[1],
+        min(target_size[0] - bbox[0], bbox[2]),
+        min(target_size[1] - bbox[1], bbox[3]),
+    )
+    return new_bbox
+
+class BatchUncropAdvanced:
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "original_images": ("IMAGE",),
+                "cropped_images": ("IMAGE",),
+                "cropped_masks": ("MASK",),
+                "combined_crop_mask": ("MASK",),
+                "bboxes": ("BBOX",),
"step": 0.01}, ), + "crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "use_combined_mask": ("BOOLEAN", {"default": False}), + "use_square_mask": ("BOOLEAN", {"default": True}), + }, + "optional": { + "combined_bounding_box": ("BBOX", {"default": None}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "uncrop" + + CATEGORY = "KJNodes/masking" + + def uncrop(self, original_images, cropped_images, cropped_masks, combined_crop_mask, bboxes, border_blending, crop_rescale, use_combined_mask, use_square_mask, combined_bounding_box = None): + + def inset_border(image, border_width=20, border_color=(0)): + width, height = image.size + bordered_image = Image.new(image.mode, (width, height), border_color) + bordered_image.paste(image, (0, 0)) + draw = ImageDraw.Draw(bordered_image) + draw.rectangle((0, 0, width - 1, height - 1), outline=border_color, width=border_width) + return bordered_image + + if len(original_images) != len(cropped_images) or len(original_images) != len(bboxes): + raise ValueError("The number of images, crop_images, and bboxes should be the same") + + crop_imgs = tensor2pil(cropped_images) + input_images = tensor2pil(original_images) + out_images = [] + + for i in range(len(input_images)): + img = input_images[i] + crop = crop_imgs[i] + bbox = bboxes[i] + + + if use_combined_mask: + bb_x, bb_y, bb_width, bb_height = combined_bounding_box[0] + paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size) + mask = combined_crop_mask[i] + else: + bb_x, bb_y, bb_width, bb_height = bbox + paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size) + mask = cropped_masks[i] + + # scale paste_region + scale_x = scale_y = crop_rescale + paste_region = (int(paste_region[0]*scale_x), int(paste_region[1]*scale_y), int(paste_region[2]*scale_x), int(paste_region[3]*scale_y)) + + # rescale the crop image to fit the paste_region + crop = crop.resize((int(paste_region[2]-paste_region[0]), int(paste_region[3]-paste_region[1]))) + crop_img = crop.convert("RGB") + + #border blending + if border_blending > 1.0: + border_blending = 1.0 + elif border_blending < 0.0: + border_blending = 0.0 + + blend_ratio = (max(crop_img.size) / 2) * float(border_blending) + blend = img.convert("RGBA") + + if use_square_mask: + mask = Image.new("L", img.size, 0) + mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255) + mask_block = inset_border(mask_block, int(blend_ratio / 2), (0)) + mask.paste(mask_block, paste_region) + else: + original_mask = tensor2pil(mask)[0] + original_mask = original_mask.resize((paste_region[2]-paste_region[0], paste_region[3]-paste_region[1])) + mask = Image.new("L", img.size, 0) + mask.paste(original_mask, paste_region) + + mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4)) + mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4)) + + blend.paste(crop_img, paste_region) + blend.putalpha(mask) + + img = Image.alpha_composite(img.convert("RGBA"), blend) + out_images.append(img.convert("RGB")) + + return (pil2tensor(out_images),) + + from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation class BatchCLIPSeg: @@ -1512,6 +1794,8 @@ class BatchCLIPSeg: "text": ("STRING", {"multiline": False}), "threshold": ("FLOAT", {"default": 0.4,"min": 0.0, "max": 10.0, "step": 0.01}), "binary_mask": ("BOOLEAN", {"default": True}), + "combine_mask": ("BOOLEAN", {"default": True}), + "use_cuda": ("BOOLEAN", {"default": True}), }, } @@ -1521,13 +1805,14 @@ 
     FUNCTION = "segment_image"

-    def segment_image(self, images, text, threshold, binary_mask):
+    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda):
         out = []
         height, width, _ = images[0].shape
-        print(height)
-        print(width)
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if use_cuda and torch.cuda.is_available():
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
         model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
         model.to(device) # Ensure the model is on the correct device
         images = images.to(device)
@@ -1561,6 +1846,11 @@ class BatchCLIPSeg:
             out.append(resized_tensor)
         results = torch.stack(out).cpu()
+
+        if combine_mask:
+            combined_results = torch.max(results, dim=0)[0]
+            results = combined_results.unsqueeze(0).repeat(len(images),1,1)
+
         if binary_mask:
             results = results.round()
@@ -1581,6 +1871,115 @@ class RoundMask:
         mask = mask.round()
         return (mask,)

+class ResizeMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "mask": ("MASK",),
+                "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
+                "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
+                "keep_proportions": ("BOOLEAN", { "default": False }),
+            }
+        }
+
+    RETURN_TYPES = ("MASK", "INT", "INT",)
+    RETURN_NAMES = ("mask", "width", "height",)
+    FUNCTION = "resize"
+    CATEGORY = "KJNodes/masking"
+
+    def resize(self, mask, width, height, keep_proportions):
+        if keep_proportions:
+            _, oh, ow = mask.shape
+            width = ow if width == 0 else width
+            height = oh if height == 0 else height
+            ratio = min(width / ow, height / oh)
+            width = round(ow*ratio)
+            height = round(oh*ratio)
+
+        outputs = mask.unsqueeze(0) # Add an extra dimension for batch size
+        outputs = F.interpolate(outputs, size=(height, width), mode="nearest")
+        outputs = outputs.squeeze(0) # Remove the extra dimension after interpolation
+
+        return (outputs, outputs.shape[2], outputs.shape[1],)
+
+class OffsetMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "mask": ("MASK",),
+                "x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+                "y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+                "angle": ("INT", { "default": 0, "min": -360, "max": 360, "step": 1, "display": "number" }),
+                "duplication_factor": ("INT", { "default": 1, "min": 1, "max": 1000, "step": 1, "display": "number" }),
+                "roll": ("BOOLEAN", { "default": False }),
+                "incremental": ("BOOLEAN", { "default": False }),
+            }
+        }
+
+    RETURN_TYPES = ("MASK",)
+    RETURN_NAMES = ("mask",)
+    FUNCTION = "offset"
+    CATEGORY = "KJNodes/masking"
+
+    def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1):
+        # Create duplicates of the mask batch
+        mask = mask.repeat(duplication_factor, 1, 1)
+
+        batch_size, height, width = mask.shape
+
+        if angle != 0 and incremental:
+            for i in range(batch_size):
+                rotation_angle = angle * (i+1)
+                mask[i] = TF.rotate(mask[i].unsqueeze(0), rotation_angle).squeeze(0)
+        elif angle != 0:
+            for i in range(batch_size):
+                mask[i] = TF.rotate(mask[i].unsqueeze(0), angle).squeeze(0)
+
+        if roll:
+            if incremental:
+                for i in range(batch_size):
+                    shift_x = min(x*(i+1), width-1)
+                    shift_y = min(y*(i+1), height-1)
+                    if shift_x != 0:
+                        mask[i] = torch.roll(mask[i], shifts=shift_x, dims=1)
+                    if shift_y != 0:
+                        mask[i] = torch.roll(mask[i], shifts=shift_y, dims=0)
+            else:
+                shift_x = min(x, width-1)
+                shift_y = min(y, height-1)
+                if shift_x != 0:
+                    mask = torch.roll(mask, shifts=shift_x, dims=2)
+                if shift_y != 0:
+                    mask = torch.roll(mask, shifts=shift_y, dims=1)
+        else:
+            if incremental:
+                for i in range(batch_size):
+                    temp_x = min(x * (i+1), width-1)
+                    temp_y = min(y * (i+1), height-1)
+                    if temp_x > 0:
+                        mask[i] = torch.cat([torch.zeros((height, temp_x)), mask[i, :, :-temp_x]], dim=1)
+                    elif temp_x < 0:
+                        mask[i] = torch.cat([mask[i, :, -temp_x:], torch.zeros((height, -temp_x))], dim=1)
+                    if temp_y > 0:
+                        mask[i] = torch.cat([torch.zeros((temp_y, width)), mask[i, :-temp_y, :]], dim=0)
+                    elif temp_y < 0:
+                        mask[i] = torch.cat([mask[i, -temp_y:, :], torch.zeros((-temp_y, width))], dim=0)
+            else:
+                temp_x = min(x, width-1)
+                temp_y = min(y, height-1)
+                if temp_x > 0:
+                    mask = torch.cat([torch.zeros((batch_size, height, temp_x)), mask[:, :, :-temp_x]], dim=2)
+                elif temp_x < 0:
+                    mask = torch.cat([mask[:, :, -temp_x:], torch.zeros((batch_size, height, -temp_x))], dim=2)
+                if temp_y > 0:
+                    mask = torch.cat([torch.zeros((batch_size, temp_y, width)), mask[:, :-temp_y, :]], dim=1)
+                elif temp_y < 0:
+                    mask = torch.cat([mask[:, -temp_y:, :], torch.zeros((batch_size, -temp_y, width))], dim=1)
+
+        return (mask,)
+
 NODE_CLASS_MAPPINGS = {
     "INTConstant": INTConstant,
     "FloatConstant": FloatConstant,
@@ -1610,9 +2009,13 @@ NODE_CLASS_MAPPINGS = {
     "ImageBatchTestPattern": ImageBatchTestPattern,
     "ReplaceImagesInBatch": ReplaceImagesInBatch,
     "BatchCropFromMask": BatchCropFromMask,
+    "BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced,
     "BatchUncrop": BatchUncrop,
+    "BatchUncropAdvanced": BatchUncropAdvanced,
     "BatchCLIPSeg": BatchCLIPSeg,
     "RoundMask": RoundMask,
+    "ResizeMask": ResizeMask,
+    "OffsetMask": OffsetMask,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "INTConstant": "INT Constant",
@@ -1642,7 +2045,11 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageBatchTestPattern": "ImageBatchTestPattern",
     "ReplaceImagesInBatch": "ReplaceImagesInBatch",
     "BatchCropFromMask": "BatchCropFromMask",
+    "BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced",
     "BatchUncrop": "BatchUncrop",
+    "BatchUncropAdvanced": "BatchUncropAdvanced",
     "BatchCLIPSeg": "BatchCLIPSeg",
     "RoundMask": "RoundMask",
+    "ResizeMask": "ResizeMask",
+    "OffsetMask": "OffsetMask",
 }
\ No newline at end of file
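
A standalone sketch of the lerp_alpha/decay_factor frame blending added to GrowMaskWithBlur.expand_mask, with invented tensor shapes and parameter values; useful for sanity-checking the math outside the node:

    # Sketch: frame-to-frame blending as in the new lerp_alpha/decay_factor logic.
    # Shapes and values here are illustrative, not the node's defaults.
    import torch

    alpha, decay = 0.5, 0.8
    frames = [torch.rand(4, 4) for _ in range(3)]  # stand-ins for per-frame masks

    previous_output = None
    out = []
    for output in frames:
        if alpha < 1.0 and previous_output is not None:
            # linear interpolation toward the previous frame
            output = alpha * output + (1 - alpha) * previous_output
        if decay < 1.0 and previous_output is not None:
            # accumulate a decayed copy of the previous frame, then renormalize
            output = output + decay * previous_output
            output = output / output.max()
        previous_output = output
        out.append(output)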
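The bbox_smooth_alpha input of BatchCropFromMaskAdvanced feeds smooth_bbox_size, which is a plain exponential moving average. A minimal sketch with invented sizes, showing how smaller alpha lags behind the current measurement and damps jitter:

    # Sketch: smooth_bbox_size is an EMA; alpha=1.0 follows the current frame
    # exactly, smaller alpha damps sudden size changes. Sizes are invented.
    def smooth_bbox_size(prev, curr, alpha):
        return int(alpha * curr + (1 - alpha) * prev)

    sizes = [200, 260, 190, 250]
    smoothed, prev = [], 0
    for s in sizes:
        prev = smooth_bbox_size(prev, s, alpha=0.5)
        smoothed.append(prev)
    print(smoothed)  # [100, 180, 185, 217]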
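The roll toggle in OffsetMask switches between wrap-around shifting (torch.roll) and the zero-fill torch.cat branch. A small standalone comparison on a dummy mask (shape and shift invented):

    # Sketch: wrap-around vs zero-fill horizontal shift, mirroring OffsetMask's branches.
    import torch

    mask = torch.zeros(1, 4, 4)
    mask[0, :, 0] = 1.0  # a single bright column at x=0
    shift_x = 2

    rolled = torch.roll(mask, shifts=shift_x, dims=2)   # roll=True: content wraps around
    padded = torch.cat([torch.zeros(1, 4, shift_x),     # roll=False: zeros pushed in,
                        mask[:, :, :-shift_x]], dim=2)  # right-edge content discarded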