Add advanced crop/uncrop and some mask tools
commit c4381dfe36
parent dccee73432

nodes.py: 429 changes
@@ -1,7 +1,8 @@
 import nodes
 import torch
 import torch.nn.functional as F
-import torchvision.utils as vutils
+from torchvision.transforms import Resize, CenterCrop, InterpolationMode
+from torchvision.transforms import functional as TF
 import scipy.ndimage
 import numpy as np
 from PIL import ImageFilter, Image, ImageDraw, ImageFont
@@ -10,6 +11,7 @@ import json
 import re
 import os
 import librosa
+import random
 from scipy.special import erf
 from .fluid import Fluid
 import comfy.model_management
@@ -188,8 +190,6 @@ class CreateAudioMask:
             return (1.0 - torch.cat(out, dim=0),)
         return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
-
-
 
 class CreateGradientMask:
 
     RETURN_TYPES = ("MASK",)
@@ -531,6 +531,8 @@ class GrowMaskWithBlur:
                     "max": 10.0,
                     "step": 0.1
                 }),
+                "lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             },
         }
 
@@ -540,7 +542,9 @@ class GrowMaskWithBlur:
     RETURN_NAMES = ("mask", "mask_inverted",)
     FUNCTION = "expand_mask"
 
-    def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate, use_cuda):
+    def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate, use_cuda, lerp_alpha, decay_factor):
+        alpha = lerp_alpha
+        decay = decay_factor
         if( flip_input ):
             mask = 1.0 - mask
         c = 0 if tapered_corners else 1
@@ -549,6 +553,7 @@ class GrowMaskWithBlur:
                             [c, 1, c]])
         growmask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
         out = []
+        previous_output = None
         for m in growmask:
             output = m.numpy()
             for _ in range(abs(expand)):
@@ -561,6 +566,14 @@ class GrowMaskWithBlur:
                 else:
                     expand += abs(incremental_expandrate) # Use abs(growrate) to ensure positive change
             output = torch.from_numpy(output)
+            if alpha < 1.0 and previous_output is not None:
+                # Interpolate between the previous and current frame
+                output = alpha * output + (1 - alpha) * previous_output
+            if decay < 1.0 and previous_output is not None:
+                # Add the decayed previous output to the current frame
+                output += decay * previous_output
+                output = output / output.max()
+            previous_output = output
             out.append(output)
 
         blurred = torch.stack(out, dim=0).reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
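A quick sanity check of the temporal blend added here (a sketch with toy tensors, not part of the commit): lerp_alpha below 1.0 eases each frame toward the previous output, and decay_factor below 1.0 leaves a fading trail that is renormalized back into the 0..1 mask range.

```python
import torch

prev = torch.zeros(4, 4)          # previous frame's mask
curr = torch.ones(4, 4)           # current frame's mask
alpha, decay = 0.5, 0.5

lerped = alpha * curr + (1 - alpha) * prev   # eased halfway: 0.5 everywhere
trailed = curr + decay * prev                # decayed previous added on top
trailed = trailed / trailed.max()            # rescaled back into 0..1
```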
@@ -568,7 +581,7 @@ class GrowMaskWithBlur:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             blurred = blurred.to(device) # Move blurred tensor to the GPU
 
-        batch_size, height, width, channels = blurred.shape
+        channels = blurred.shape[-1]
         if blur_radius != 0:
             blurkernel_size = blur_radius * 2 + 1
             blurkernel = gaussian_kernel(blurkernel_size, sigma, device=blurred.device).repeat(channels, 1, 1).unsqueeze(1)
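For context, the kernel layout used above, (channels, 1, k, k) built with repeat/unsqueeze, is the depthwise-convolution form. A self-contained sketch, with a hypothetical gaussian_kernel_2d standing in for the repo's gaussian_kernel helper:

```python
import torch
import torch.nn.functional as F

def gaussian_kernel_2d(kernel_size: int, sigma: float) -> torch.Tensor:
    ax = torch.arange(kernel_size).float() - kernel_size // 2
    g = torch.exp(-ax ** 2 / (2 * sigma ** 2))
    k = torch.outer(g, g)
    return k / k.sum()                              # normalized 2-D Gaussian

imgs = torch.rand(2, 64, 64, 3)                     # NHWC, like a ComfyUI IMAGE
x = imgs.permute(0, 3, 1, 2)                        # NCHW for conv2d
k = gaussian_kernel_2d(5, 1.5)
weight = k.repeat(3, 1, 1).unsqueeze(1)             # (3, 1, 5, 5): one kernel per channel
blurred = F.conv2d(x, weight, padding=2, groups=3)  # groups=3 keeps channels independent
blurred = blurred.permute(0, 2, 3, 1)               # back to NHWC
```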
@@ -1215,7 +1228,6 @@ class ImageGridComposite3x3:
         grid = torch.cat((top_row, mid_row, bottom_row), dim=1)
         return (grid,)
 
-import random
 class ImageBatchTestPattern:
     @classmethod
     def INPUT_TYPES(s):
@@ -1272,7 +1284,7 @@ class ImageBatchTestPattern:
 #based on nodes from mtb https://github.com/melMass/comfy_mtb
 
 from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
-from torchvision.transforms import Resize, CenterCrop
+
 
 class BatchCropFromMask:
 
@@ -1496,6 +1508,276 @@ class BatchUncrop:
 
         return (pil2tensor(out_images),)
 
+class BatchCropFromMaskAdvanced:
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "original_images": ("IMAGE",),
+                "masks": ("MASK",),
+                "crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                "bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+            },
+        }
+
+    RETURN_TYPES = (
+        "IMAGE",
+        "IMAGE",
+        "MASK",
+        "IMAGE",
+        "MASK",
+        "BBOX",
+        "BBOX",
+        "INT",
+        "INT",
+    )
+    RETURN_NAMES = (
+        "original_images",
+        "cropped_images",
+        "cropped_masks",
+        "combined_crop_image",
+        "combined_crop_masks",
+        "bboxes",
+        "combined_bounding_box",
+        "bbox_width",
+        "bbox_height",
+    )
+    FUNCTION = "crop"
+    CATEGORY = "KJNodes/masking"
+
+    def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
+        return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
+
+    def smooth_center(self, prev_center, curr_center, alpha=0.5):
+        return (int(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
+                int(alpha * curr_center[1] + (1 - alpha) * prev_center[1]))
+
+    def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
+        bounding_boxes = []
+        combined_bounding_box = []
+        cropped_images = []
+        cropped_masks = []
+        cropped_masks_out = []
+        combined_crop_out = []
+        combined_cropped_images = []
+        combined_cropped_masks = []
+
+        def calculate_bbox(mask):
+            non_zero_indices = np.nonzero(np.array(mask))
+            min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+            min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+            width = max_x - min_x
+            height = max_y - min_y
+            bbox_size = max(width, height)
+            return min_x, max_x, min_y, max_y, bbox_size
+
+        combined_mask = torch.max(masks, dim=0)[0]
+        _mask = tensor2pil(combined_mask)[0]
+        new_min_x, new_max_x, new_min_y, new_max_y, combined_bbox_size = calculate_bbox(_mask)
+        center_x = (new_min_x + new_max_x) / 2
+        center_y = (new_min_y + new_max_y) / 2
+        half_box_size = int(combined_bbox_size // 2)
+        new_min_x = max(0, int(center_x - half_box_size))
+        new_max_x = min(original_images[0].shape[1], int(center_x + half_box_size))
+        new_min_y = max(0, int(center_y - half_box_size))
+        new_max_y = min(original_images[0].shape[0], int(center_y + half_box_size))
+
+        combined_bounding_box.append((new_min_x, new_min_y, new_max_x - new_min_x, new_max_y - new_min_y))
+
+        self.max_bbox_size = 0
+
+        # First, calculate the maximum bounding box size across all masks
+        curr_max_bbox_size = max(calculate_bbox(tensor2pil(mask)[0])[-1] for mask in masks)
+        # Smooth the changes in the bounding box size
+        self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha)
+        # Apply the crop size multiplier
+        self.max_bbox_size = int(self.max_bbox_size * crop_size_mult)
+        # Make sure max_bbox_size is divisible by 32, if not, round it upwards so it is
+        self.max_bbox_size = math.ceil(self.max_bbox_size / 32) * 32
+
+        # Then, for each mask and corresponding image...
+        for i, (mask, img) in enumerate(zip(masks, original_images)):
+            _mask = tensor2pil(mask)[0]
+            non_zero_indices = np.nonzero(np.array(_mask))
+            min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
+            min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
+
+            # Calculate center of bounding box
+            center_x = np.mean(non_zero_indices[1])
+            center_y = np.mean(non_zero_indices[0])
+            curr_center = (int(center_x), int(center_y))
+
+            # If this is the first frame, initialize prev_center with curr_center
+            if not hasattr(self, 'prev_center'):
+                self.prev_center = curr_center
+
+            # Smooth the changes in the center coordinates from the second frame onwards
+            if i > 0:
+                center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
+            else:
+                center = curr_center
+
+            # Update prev_center for the next frame
+            self.prev_center = center
+
+            # Create bounding box using max_bbox_size
+            half_box_size = self.max_bbox_size // 2
+            min_x = max(0, center[0] - half_box_size)
+            max_x = min(img.shape[1], center[0] + half_box_size)
+            min_y = max(0, center[1] - half_box_size)
+            max_y = min(img.shape[0], center[1] + half_box_size)
+
+            # Append bounding box coordinates
+            bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
+
+            # Crop the image from the bounding box
+            cropped_img = img[min_y:max_y, min_x:max_x, :]
+            cropped_mask = mask[min_y:max_y, min_x:max_x]
+
+            # Resize the cropped image to a fixed size
+            new_size = max(cropped_img.shape[0], cropped_img.shape[1])
+            resize_transform = Resize(new_size, interpolation = InterpolationMode.NEAREST)
+            resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
+            resized_img = resize_transform(cropped_img.permute(2, 0, 1))
+            # Perform the center crop to the desired size
+            crop_transform = CenterCrop((self.max_bbox_size, self.max_bbox_size))
+
+            cropped_resized_img = crop_transform(resized_img)
+            cropped_images.append(cropped_resized_img.permute(1, 2, 0))
+
+            cropped_resized_mask = crop_transform(resized_mask)
+            cropped_masks.append(cropped_resized_mask)
+
+            combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
+            combined_cropped_images.append(combined_cropped_img)
+
+            combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
+            combined_cropped_masks.append(combined_cropped_mask)
+
+        cropped_out = torch.stack(cropped_images, dim=0)
+        combined_crop_out = torch.stack(combined_cropped_images, dim=0)
+        cropped_masks_out = torch.stack(cropped_masks, dim=0)
+        combined_crop_mask_out = torch.stack(combined_cropped_masks, dim=0)
+
+        return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)
+
+
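Two details of the sizing logic above, traced in isolation (a sketch, not part of the commit): the exponential-smoothing form used by smooth_bbox_size, and the round-up to a multiple of 32 so the crop stays friendly to latent downscaling.

```python
import math

def smooth(prev_size: int, curr_size: int, alpha: float) -> int:
    # alpha = 1.0 tracks the current frame exactly; lower values damp jitter
    return int(alpha * curr_size + (1 - alpha) * prev_size)

print(smooth(200, 260, 0.5))  # 230, halfway between the two sizes

for size in (65, 96, 100):
    # math.ceil always rounds upwards, so the crop never shrinks below the bbox
    print(size, "->", math.ceil(size / 32) * 32)  # 96, 96, 128
```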
|
+def bbox_to_region(bbox, target_size=None):
+    bbox = bbox_check(bbox, target_size)
+    return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
+
+def bbox_check(bbox, target_size=None):
+    if not target_size:
+        return bbox
+
+    new_bbox = (
+        bbox[0],
+        bbox[1],
+        min(target_size[0] - bbox[0], bbox[2]),
+        min(target_size[1] - bbox[1], bbox[3]),
+    )
+    return new_bbox
+
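Usage of the two helpers just added: bbox_check clamps an (x, y, width, height) box so it stays inside target_size, and bbox_to_region converts it to the (left, top, right, bottom) form PIL's paste expects.

```python
# assumes the bbox_to_region/bbox_check helpers defined above are in scope
bbox = (80, 90, 50, 50)                     # x, y, w, h; hangs off a 100x100 image
region = bbox_to_region(bbox, (100, 100))   # width/height clamped -> (80, 90, 100, 100)
```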
|
+class BatchUncropAdvanced:
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "original_images": ("IMAGE",),
+                "cropped_images": ("IMAGE",),
+                "cropped_masks": ("MASK",),
+                "combined_crop_mask": ("MASK",),
+                "bboxes": ("BBOX",),
+                "border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
+                "crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                "use_combined_mask": ("BOOLEAN", {"default": False}),
+                "use_square_mask": ("BOOLEAN", {"default": True}),
+            },
+            "optional": {
+                "combined_bounding_box": ("BBOX", {"default": None}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "uncrop"
+    CATEGORY = "KJNodes/masking"
+
+    def uncrop(self, original_images, cropped_images, cropped_masks, combined_crop_mask, bboxes, border_blending, crop_rescale, use_combined_mask, use_square_mask, combined_bounding_box = None):
+
+        def inset_border(image, border_width=20, border_color=(0)):
+            width, height = image.size
+            bordered_image = Image.new(image.mode, (width, height), border_color)
+            bordered_image.paste(image, (0, 0))
+            draw = ImageDraw.Draw(bordered_image)
+            draw.rectangle((0, 0, width - 1, height - 1), outline=border_color, width=border_width)
+            return bordered_image
+
+        if len(original_images) != len(cropped_images) or len(original_images) != len(bboxes):
+            raise ValueError("The number of images, crop_images, and bboxes should be the same")
+
+        crop_imgs = tensor2pil(cropped_images)
+        input_images = tensor2pil(original_images)
+        out_images = []
+
+        for i in range(len(input_images)):
+            img = input_images[i]
+            crop = crop_imgs[i]
+            bbox = bboxes[i]
+
+            if use_combined_mask:
+                bb_x, bb_y, bb_width, bb_height = combined_bounding_box[0]
+                paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
+                mask = combined_crop_mask[i]
+            else:
+                bb_x, bb_y, bb_width, bb_height = bbox
+                paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
+                mask = cropped_masks[i]
+
+            # scale paste_region
+            scale_x = scale_y = crop_rescale
+            paste_region = (int(paste_region[0]*scale_x), int(paste_region[1]*scale_y), int(paste_region[2]*scale_x), int(paste_region[3]*scale_y))
+
+            # rescale the crop image to fit the paste_region
+            crop = crop.resize((int(paste_region[2]-paste_region[0]), int(paste_region[3]-paste_region[1])))
+            crop_img = crop.convert("RGB")
+
+            # border blending
+            if border_blending > 1.0:
+                border_blending = 1.0
+            elif border_blending < 0.0:
+                border_blending = 0.0
+
+            blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
+            blend = img.convert("RGBA")
+
+            if use_square_mask:
+                mask = Image.new("L", img.size, 0)
+                mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
+                mask_block = inset_border(mask_block, int(blend_ratio / 2), (0))
+                mask.paste(mask_block, paste_region)
+            else:
+                original_mask = tensor2pil(mask)[0]
+                original_mask = original_mask.resize((paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]))
+                mask = Image.new("L", img.size, 0)
+                mask.paste(original_mask, paste_region)
+
+            mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
+            mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
+
+            blend.paste(crop_img, paste_region)
+            blend.putalpha(mask)
+
+            img = Image.alpha_composite(img.convert("RGBA"), blend)
+            out_images.append(img.convert("RGB"))
+
+        return (pil2tensor(out_images),)
+
+
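The seam handling in uncrop, reduced to a standalone sketch (hypothetical sizes, not the node's exact parameters): a white paste mask is shrunk by a black inset border, blurred, and used as the alpha when compositing the crop back over the original.

```python
from PIL import Image, ImageFilter

base = Image.new("RGB", (256, 256), "gray")      # stand-in original frame
crop = Image.new("RGB", (128, 128), "red")       # stand-in processed crop
region = (64, 64, 192, 192)                      # left, top, right, bottom

mask = Image.new("L", base.size, 0)
block = Image.new("L", (128, 128), 255)
mask.paste(block, region)
mask = mask.filter(ImageFilter.GaussianBlur(radius=8))  # feather the edge

blend = base.convert("RGBA")
blend.paste(crop, region)
blend.putalpha(mask)
result = Image.alpha_composite(base.convert("RGBA"), blend).convert("RGB")
```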
 from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
 
 class BatchCLIPSeg:
@@ -1512,6 +1794,8 @@ class BatchCLIPSeg:
                 "text": ("STRING", {"multiline": False}),
                 "threshold": ("FLOAT", {"default": 0.4,"min": 0.0, "max": 10.0, "step": 0.01}),
                 "binary_mask": ("BOOLEAN", {"default": True}),
+                "combine_mask": ("BOOLEAN", {"default": True}),
+                "use_cuda": ("BOOLEAN", {"default": True}),
             },
         }
 
@@ -1521,13 +1805,14 @@ class BatchCLIPSeg:
 
     FUNCTION = "segment_image"
 
-    def segment_image(self, images, text, threshold, binary_mask):
+    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda):
 
         out = []
         height, width, _ = images[0].shape
-        print(height)
-        print(width)
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if use_cuda and torch.cuda.is_available():
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
         model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
         model.to(device) # Ensure the model is on the correct device
         images = images.to(device)
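For reference, a minimal standalone CLIPSeg call through transformers, shaped the way the node uses it (a sketch; assumes the checkpoint downloads on first run). The raw logits come back at the model's fixed low resolution, which is why the node resizes them afterwards.

```python
import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

image = Image.new("RGB", (512, 512), "white")   # stand-in for a real frame
inputs = processor(text=["a face"], images=[image], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
mask = torch.sigmoid(outputs.logits)            # low-res soft mask, resize as needed
```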
@@ -1561,6 +1846,11 @@ class BatchCLIPSeg:
             out.append(resized_tensor)
 
         results = torch.stack(out).cpu()
+
+        if combine_mask:
+            combined_results = torch.max(results, dim=0)[0]
+            results = combined_results.unsqueeze(0).repeat(len(images),1,1)
+
         if binary_mask:
             results = results.round()
 
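The combine_mask step above in isolation: torch.max over the batch dimension keeps, per pixel, the strongest response from any frame, and that union mask is then repeated so every frame gets the same mask.

```python
import torch

results = torch.rand(8, 64, 64)                    # one soft mask per frame
combined = torch.max(results, dim=0)[0]            # (64, 64) union of all frames
results = combined.unsqueeze(0).repeat(8, 1, 1)    # broadcast back to the batch
```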
@@ -1581,6 +1871,115 @@ class RoundMask:
         mask = mask.round()
         return (mask,)
 
+class ResizeMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "mask": ("MASK",),
+                "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
+                "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
+                "keep_proportions": ("BOOLEAN", { "default": False }),
+            }
+        }
+
+    RETURN_TYPES = ("MASK", "INT", "INT",)
+    RETURN_NAMES = ("mask", "width", "height",)
+    FUNCTION = "resize"
+    CATEGORY = "KJNodes/masking"
+
+    def resize(self, mask, width, height, keep_proportions):
+        if keep_proportions:
+            _, oh, ow = mask.shape
+            width = ow if width == 0 else width
+            height = oh if height == 0 else height
+            ratio = min(width / ow, height / oh)
+            width = round(ow*ratio)
+            height = round(oh*ratio)
+
+        outputs = mask.unsqueeze(0)  # Add an extra dimension for batch size
+        outputs = F.interpolate(outputs, size=(height, width), mode="nearest")
+        outputs = outputs.squeeze(0)  # Remove the extra dimension after interpolation
+
+        return (outputs, outputs.shape[2], outputs.shape[1],)
+
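The keep_proportions arithmetic above, traced for one case (a sketch, assuming the batched (B, H, W) MASK layout): a single ratio is chosen so neither side overflows the requested box, and nearest-neighbour interpolation keeps the mask hard-edged.

```python
import torch
import torch.nn.functional as F

mask = (torch.rand(2, 512, 768) > 0.5).float()        # (batch, H, W) mask
width, height = 512, 512
_, oh, ow = mask.shape
ratio = min(width / ow, height / oh)                  # 0.666..., limited by width
width, height = round(ow * ratio), round(oh * ratio)  # 512 x 341
resized = F.interpolate(mask.unsqueeze(0), size=(height, width), mode="nearest").squeeze(0)
```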
|
+class OffsetMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "mask": ("MASK",),
+                "x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+                "y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+                "angle": ("INT", { "default": 0, "min": -360, "max": 360, "step": 1, "display": "number" }),
+                "duplication_factor": ("INT", { "default": 1, "min": 1, "max": 1000, "step": 1, "display": "number" }),
+                "roll": ("BOOLEAN", { "default": False }),
+                "incremental": ("BOOLEAN", { "default": False }),
+            }
+        }
+
+    RETURN_TYPES = ("MASK",)
+    RETURN_NAMES = ("mask",)
+    FUNCTION = "offset"
+    CATEGORY = "KJNodes/masking"
+
+    def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1):
+        # Create duplicates of the mask batch
+        mask = mask.repeat(duplication_factor, 1, 1)
+
+        batch_size, height, width = mask.shape
+
+        if angle != 0 and incremental:
+            for i in range(batch_size):
+                rotation_angle = angle * (i+1)
+                mask[i] = TF.rotate(mask[i].unsqueeze(0), rotation_angle).squeeze(0)
+        elif angle > 0:
+            for i in range(batch_size):
+                mask[i] = TF.rotate(mask[i].unsqueeze(0), angle).squeeze(0)
+
+        if roll:
+            if incremental:
+                for i in range(batch_size):
+                    shift_x = min(x*(i+1), width-1)
+                    shift_y = min(y*(i+1), height-1)
+                    if shift_x != 0:
+                        mask[i] = torch.roll(mask[i], shifts=shift_x, dims=1)
+                    if shift_y != 0:
+                        mask[i] = torch.roll(mask[i], shifts=shift_y, dims=0)
+            else:
+                shift_x = min(x, width-1)
+                shift_y = min(y, height-1)
+                if shift_x != 0:
+                    mask = torch.roll(mask, shifts=shift_x, dims=2)
+                if shift_y != 0:
+                    mask = torch.roll(mask, shifts=shift_y, dims=1)
+        else:
+            if incremental:
+                for i in range(batch_size):
+                    temp_x = min(x * (i+1), width-1)
+                    temp_y = min(y * (i+1), height-1)
+                    if temp_x > 0:
+                        mask[i] = torch.cat([torch.zeros((height, temp_x)), mask[i, :, :-temp_x]], dim=1)
+                    elif temp_x < 0:
+                        mask[i] = torch.cat([mask[i, :, -temp_x:], torch.zeros((height, -temp_x))], dim=1)
+                    if temp_y > 0:
+                        mask[i] = torch.cat([torch.zeros((temp_y, width)), mask[i, :-temp_y, :]], dim=0)
+                    elif temp_y < 0:
+                        mask[i] = torch.cat([mask[i, -temp_y:, :], torch.zeros((-temp_y, width))], dim=0)
+            else:
+                temp_x = min(x, width-1)
+                temp_y = min(y, height-1)
+                if temp_x > 0:
+                    mask = torch.cat([torch.zeros((batch_size, height, temp_x)), mask[:, :, :-temp_x]], dim=2)
+                elif temp_x < 0:
+                    mask = torch.cat([mask[:, :, -temp_x:], torch.zeros((batch_size, height, -temp_x))], dim=2)
+                if temp_y > 0:
+                    mask = torch.cat([torch.zeros((batch_size, temp_y, width)), mask[:, :-temp_y, :]], dim=1)
+                elif temp_y < 0:
+                    mask = torch.cat([mask[:, -temp_y:, :], torch.zeros((batch_size, -temp_y, width))], dim=1)
+
+        return mask,
+
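The difference between the two offset modes above, on a one-row strip: roll wraps pixels around the edge, while the cat-with-zeros branch pushes them out of frame and backfills with empty mask.

```python
import torch

m = torch.tensor([[1., 2., 3., 4.]])
rolled = torch.roll(m, shifts=1, dims=1)                     # [[4., 1., 2., 3.]]
shifted = torch.cat([torch.zeros(1, 1), m[:, :-1]], dim=1)   # [[0., 1., 2., 3.]]
```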
 NODE_CLASS_MAPPINGS = {
     "INTConstant": INTConstant,
     "FloatConstant": FloatConstant,
@@ -1610,9 +2009,13 @@ NODE_CLASS_MAPPINGS = {
     "ImageBatchTestPattern": ImageBatchTestPattern,
     "ReplaceImagesInBatch": ReplaceImagesInBatch,
     "BatchCropFromMask": BatchCropFromMask,
+    "BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced,
     "BatchUncrop": BatchUncrop,
+    "BatchUncropAdvanced": BatchUncropAdvanced,
     "BatchCLIPSeg": BatchCLIPSeg,
     "RoundMask": RoundMask,
+    "ResizeMask": ResizeMask,
+    "OffsetMask": OffsetMask,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "INTConstant": "INT Constant",
@@ -1642,7 +2045,11 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageBatchTestPattern": "ImageBatchTestPattern",
     "ReplaceImagesInBatch": "ReplaceImagesInBatch",
     "BatchCropFromMask": "BatchCropFromMask",
+    "BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced",
     "BatchUncrop": "BatchUncrop",
+    "BatchUncropAdvanced": "BatchUncropAdvanced",
     "BatchCLIPSeg": "BatchCLIPSeg",
     "RoundMask": "RoundMask",
+    "ResizeMask": "ResizeMask",
+    "OffsetMask": "OffsetMask",
 }