mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 21:04:41 +08:00
Add advanced crop/uncrop and some mask tools
This commit is contained in:
parent
dccee73432
commit
c4381dfe36
429
nodes.py
429
nodes.py
@ -1,7 +1,8 @@
|
||||
import nodes
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.utils as vutils
|
||||
from torchvision.transforms import Resize, CenterCrop, InterpolationMode
|
||||
from torchvision.transforms import functional as TF
|
||||
import scipy.ndimage
|
||||
import numpy as np
|
||||
from PIL import ImageFilter, Image, ImageDraw, ImageFont
|
||||
@ -10,6 +11,7 @@ import json
|
||||
import re
|
||||
import os
|
||||
import librosa
|
||||
import random
|
||||
from scipy.special import erf
|
||||
from .fluid import Fluid
|
||||
import comfy.model_management
|
||||
@ -188,8 +190,6 @@ class CreateAudioMask:
|
||||
return (1.0 - torch.cat(out, dim=0),)
|
||||
return (torch.cat(out, dim=0),torch.cat(masks, dim=0),)
|
||||
|
||||
|
||||
|
||||
class CreateGradientMask:
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
@ -531,6 +531,8 @@ class GrowMaskWithBlur:
|
||||
"max": 10.0,
|
||||
"step": 0.1
|
||||
}),
|
||||
"lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
},
|
||||
}
|
||||
|
||||
@ -540,7 +542,9 @@ class GrowMaskWithBlur:
|
||||
RETURN_NAMES = ("mask", "mask_inverted",)
|
||||
FUNCTION = "expand_mask"
|
||||
|
||||
def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate, use_cuda):
|
||||
def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate, use_cuda, lerp_alpha, decay_factor):
|
||||
alpha = lerp_alpha
|
||||
decay = decay_factor
|
||||
if( flip_input ):
|
||||
mask = 1.0 - mask
|
||||
c = 0 if tapered_corners else 1
|
||||
@ -549,6 +553,7 @@ class GrowMaskWithBlur:
|
||||
[c, 1, c]])
|
||||
growmask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
||||
out = []
|
||||
previous_output = None
|
||||
for m in growmask:
|
||||
output = m.numpy()
|
||||
for _ in range(abs(expand)):
|
||||
@ -561,6 +566,14 @@ class GrowMaskWithBlur:
|
||||
else:
|
||||
expand += abs(incremental_expandrate) # Use abs(growrate) to ensure positive change
|
||||
output = torch.from_numpy(output)
|
||||
if alpha < 1.0 and previous_output is not None:
|
||||
# Interpolate between the previous and current frame
|
||||
output = alpha * output + (1 - alpha) * previous_output
|
||||
if decay < 1.0 and previous_output is not None:
|
||||
# Add the decayed previous output to the current frame
|
||||
output += decay * previous_output
|
||||
output = output / output.max()
|
||||
previous_output = output
|
||||
out.append(output)
|
||||
|
||||
blurred = torch.stack(out, dim=0).reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
@ -568,7 +581,7 @@ class GrowMaskWithBlur:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
blurred = blurred.to(device) # Move blurred tensor to the GPU
|
||||
|
||||
batch_size, height, width, channels = blurred.shape
|
||||
channels = blurred.shape[-1]
|
||||
if blur_radius != 0:
|
||||
blurkernel_size = blur_radius * 2 + 1
|
||||
blurkernel = gaussian_kernel(blurkernel_size, sigma, device=blurred.device).repeat(channels, 1, 1).unsqueeze(1)
|
||||
@ -1215,7 +1228,6 @@ class ImageGridComposite3x3:
|
||||
grid = torch.cat((top_row, mid_row, bottom_row), dim=1)
|
||||
return (grid,)
|
||||
|
||||
import random
|
||||
class ImageBatchTestPattern:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1272,7 +1284,7 @@ class ImageBatchTestPattern:
|
||||
#based on nodes from mtb https://github.com/melMass/comfy_mtb
|
||||
|
||||
from .utility import tensor2pil, pil2tensor, tensor2np, np2tensor
|
||||
from torchvision.transforms import Resize, CenterCrop
|
||||
|
||||
|
||||
class BatchCropFromMask:
|
||||
|
||||
@ -1496,6 +1508,276 @@ class BatchUncrop:
|
||||
|
||||
return (pil2tensor(out_images),)
|
||||
|
||||
class BatchCropFromMaskAdvanced:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"original_images": ("IMAGE",),
|
||||
"masks": ("MASK",),
|
||||
"crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (
|
||||
"IMAGE",
|
||||
"IMAGE",
|
||||
"MASK",
|
||||
"IMAGE",
|
||||
"MASK",
|
||||
"BBOX",
|
||||
"BBOX",
|
||||
"INT",
|
||||
"INT",
|
||||
)
|
||||
RETURN_NAMES = (
|
||||
"original_images",
|
||||
"cropped_images",
|
||||
"cropped_masks",
|
||||
"combined_crop_image",
|
||||
"combined_crop_masks",
|
||||
"bboxes",
|
||||
"combined_bounding_box",
|
||||
"bbox_width",
|
||||
"bbox_height",
|
||||
)
|
||||
FUNCTION = "crop"
|
||||
CATEGORY = "KJNodes/masking"
|
||||
|
||||
def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
|
||||
return int(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
|
||||
|
||||
def smooth_center(self, prev_center, curr_center, alpha=0.5):
|
||||
return (int(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
|
||||
int(alpha * curr_center[1] + (1 - alpha) * prev_center[1]))
|
||||
|
||||
def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
|
||||
bounding_boxes = []
|
||||
combined_bounding_box = []
|
||||
cropped_images = []
|
||||
cropped_masks = []
|
||||
cropped_masks_out = []
|
||||
combined_crop_out = []
|
||||
combined_cropped_images = []
|
||||
combined_cropped_masks = []
|
||||
|
||||
def calculate_bbox(mask):
|
||||
non_zero_indices = np.nonzero(np.array(mask))
|
||||
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
|
||||
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
|
||||
width = max_x - min_x
|
||||
height = max_y - min_y
|
||||
bbox_size = max(width, height)
|
||||
return min_x, max_x, min_y, max_y, bbox_size
|
||||
|
||||
combined_mask = torch.max(masks, dim=0)[0]
|
||||
_mask = tensor2pil(combined_mask)[0]
|
||||
new_min_x, new_max_x, new_min_y, new_max_y, combined_bbox_size = calculate_bbox(_mask)
|
||||
center_x = (new_min_x + new_max_x) / 2
|
||||
center_y = (new_min_y + new_max_y) / 2
|
||||
half_box_size = int(combined_bbox_size // 2)
|
||||
new_min_x = max(0, int(center_x - half_box_size))
|
||||
new_max_x = min(original_images[0].shape[1], int(center_x + half_box_size))
|
||||
new_min_y = max(0, int(center_y - half_box_size))
|
||||
new_max_y = min(original_images[0].shape[0], int(center_y + half_box_size))
|
||||
|
||||
combined_bounding_box.append((new_min_x, new_min_y, new_max_x - new_min_x, new_max_y - new_min_y))
|
||||
|
||||
self.max_bbox_size = 0
|
||||
|
||||
# First, calculate the maximum bounding box size across all masks
|
||||
curr_max_bbox_size = max(calculate_bbox(tensor2pil(mask)[0])[-1] for mask in masks)
|
||||
# Smooth the changes in the bounding box size
|
||||
self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha)
|
||||
# Apply the crop size multiplier
|
||||
self.max_bbox_size = int(self.max_bbox_size * crop_size_mult)
|
||||
# Make sure max_bbox_size is divisible by 32, if not, round it upwards so it is
|
||||
self.max_bbox_size = math.ceil(self.max_bbox_size / 32) * 32
|
||||
|
||||
# Then, for each mask and corresponding image...
|
||||
for i, (mask, img) in enumerate(zip(masks, original_images)):
|
||||
_mask = tensor2pil(mask)[0]
|
||||
non_zero_indices = np.nonzero(np.array(_mask))
|
||||
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
|
||||
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
|
||||
|
||||
# Calculate center of bounding box
|
||||
center_x = np.mean(non_zero_indices[1])
|
||||
center_y = np.mean(non_zero_indices[0])
|
||||
curr_center = (int(center_x), int(center_y))
|
||||
|
||||
# If this is the first frame, initialize prev_center with curr_center
|
||||
if not hasattr(self, 'prev_center'):
|
||||
self.prev_center = curr_center
|
||||
|
||||
# Smooth the changes in the center coordinates from the second frame onwards
|
||||
if i > 0:
|
||||
center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
|
||||
else:
|
||||
center = curr_center
|
||||
|
||||
# Update prev_center for the next frame
|
||||
self.prev_center = center
|
||||
|
||||
# Create bounding box using max_bbox_size
|
||||
half_box_size = self.max_bbox_size // 2
|
||||
half_box_size = self.max_bbox_size // 2
|
||||
min_x = max(0, center[0] - half_box_size)
|
||||
max_x = min(img.shape[1], center[0] + half_box_size)
|
||||
min_y = max(0, center[1] - half_box_size)
|
||||
max_y = min(img.shape[0], center[1] + half_box_size)
|
||||
|
||||
# Append bounding box coordinates
|
||||
bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
|
||||
|
||||
# Crop the image from the bounding box
|
||||
cropped_img = img[min_y:max_y, min_x:max_x, :]
|
||||
cropped_mask = mask[min_y:max_y, min_x:max_x]
|
||||
|
||||
# Resize the cropped image to a fixed size
|
||||
new_size = max(cropped_img.shape[0], cropped_img.shape[1])
|
||||
resize_transform = Resize(new_size, interpolation = InterpolationMode.NEAREST)
|
||||
resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
|
||||
resized_img = resize_transform(cropped_img.permute(2, 0, 1))
|
||||
# Perform the center crop to the desired size
|
||||
crop_transform = CenterCrop((self.max_bbox_size, self.max_bbox_size))
|
||||
|
||||
cropped_resized_img = crop_transform(resized_img)
|
||||
cropped_images.append(cropped_resized_img.permute(1, 2, 0))
|
||||
|
||||
cropped_resized_mask = crop_transform(resized_mask)
|
||||
cropped_masks.append(cropped_resized_mask)
|
||||
|
||||
combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
|
||||
combined_cropped_images.append(combined_cropped_img)
|
||||
|
||||
combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
|
||||
combined_cropped_masks.append(combined_cropped_mask)
|
||||
|
||||
cropped_out = torch.stack(cropped_images, dim=0)
|
||||
combined_crop_out = torch.stack(combined_cropped_images, dim=0)
|
||||
cropped_masks_out = torch.stack(cropped_masks, dim=0)
|
||||
combined_crop_mask_out = torch.stack(combined_cropped_masks, dim=0)
|
||||
|
||||
return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)
|
||||
|
||||
|
||||
def bbox_to_region(bbox, target_size=None):
|
||||
bbox = bbox_check(bbox, target_size)
|
||||
return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
|
||||
|
||||
def bbox_check(bbox, target_size=None):
|
||||
if not target_size:
|
||||
return bbox
|
||||
|
||||
new_bbox = (
|
||||
bbox[0],
|
||||
bbox[1],
|
||||
min(target_size[0] - bbox[0], bbox[2]),
|
||||
min(target_size[1] - bbox[1], bbox[3]),
|
||||
)
|
||||
return new_bbox
|
||||
|
||||
class BatchUncropAdvanced:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"original_images": ("IMAGE",),
|
||||
"cropped_images": ("IMAGE",),
|
||||
"cropped_masks": ("MASK",),
|
||||
"combined_crop_mask": ("MASK",),
|
||||
"bboxes": ("BBOX",),
|
||||
"border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
||||
"crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"use_combined_mask": ("BOOLEAN", {"default": False}),
|
||||
"use_square_mask": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
"optional": {
|
||||
"combined_bounding_box": ("BBOX", {"default": None}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "uncrop"
|
||||
|
||||
CATEGORY = "KJNodes/masking"
|
||||
|
||||
def uncrop(self, original_images, cropped_images, cropped_masks, combined_crop_mask, bboxes, border_blending, crop_rescale, use_combined_mask, use_square_mask, combined_bounding_box = None):
|
||||
|
||||
def inset_border(image, border_width=20, border_color=(0)):
|
||||
width, height = image.size
|
||||
bordered_image = Image.new(image.mode, (width, height), border_color)
|
||||
bordered_image.paste(image, (0, 0))
|
||||
draw = ImageDraw.Draw(bordered_image)
|
||||
draw.rectangle((0, 0, width - 1, height - 1), outline=border_color, width=border_width)
|
||||
return bordered_image
|
||||
|
||||
if len(original_images) != len(cropped_images) or len(original_images) != len(bboxes):
|
||||
raise ValueError("The number of images, crop_images, and bboxes should be the same")
|
||||
|
||||
crop_imgs = tensor2pil(cropped_images)
|
||||
input_images = tensor2pil(original_images)
|
||||
out_images = []
|
||||
|
||||
for i in range(len(input_images)):
|
||||
img = input_images[i]
|
||||
crop = crop_imgs[i]
|
||||
bbox = bboxes[i]
|
||||
|
||||
|
||||
if use_combined_mask:
|
||||
bb_x, bb_y, bb_width, bb_height = combined_bounding_box[0]
|
||||
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
|
||||
mask = combined_crop_mask[i]
|
||||
else:
|
||||
bb_x, bb_y, bb_width, bb_height = bbox
|
||||
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
|
||||
mask = cropped_masks[i]
|
||||
|
||||
# scale paste_region
|
||||
scale_x = scale_y = crop_rescale
|
||||
paste_region = (int(paste_region[0]*scale_x), int(paste_region[1]*scale_y), int(paste_region[2]*scale_x), int(paste_region[3]*scale_y))
|
||||
|
||||
# rescale the crop image to fit the paste_region
|
||||
crop = crop.resize((int(paste_region[2]-paste_region[0]), int(paste_region[3]-paste_region[1])))
|
||||
crop_img = crop.convert("RGB")
|
||||
|
||||
#border blending
|
||||
if border_blending > 1.0:
|
||||
border_blending = 1.0
|
||||
elif border_blending < 0.0:
|
||||
border_blending = 0.0
|
||||
|
||||
blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
|
||||
blend = img.convert("RGBA")
|
||||
|
||||
if use_square_mask:
|
||||
mask = Image.new("L", img.size, 0)
|
||||
mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
|
||||
mask_block = inset_border(mask_block, int(blend_ratio / 2), (0))
|
||||
mask.paste(mask_block, paste_region)
|
||||
else:
|
||||
original_mask = tensor2pil(mask)[0]
|
||||
original_mask = original_mask.resize((paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]))
|
||||
mask = Image.new("L", img.size, 0)
|
||||
mask.paste(original_mask, paste_region)
|
||||
|
||||
mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
|
||||
mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
|
||||
|
||||
blend.paste(crop_img, paste_region)
|
||||
blend.putalpha(mask)
|
||||
|
||||
img = Image.alpha_composite(img.convert("RGBA"), blend)
|
||||
out_images.append(img.convert("RGB"))
|
||||
|
||||
return (pil2tensor(out_images),)
|
||||
|
||||
|
||||
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
||||
|
||||
class BatchCLIPSeg:
|
||||
@ -1512,6 +1794,8 @@ class BatchCLIPSeg:
|
||||
"text": ("STRING", {"multiline": False}),
|
||||
"threshold": ("FLOAT", {"default": 0.4,"min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"binary_mask": ("BOOLEAN", {"default": True}),
|
||||
"combine_mask": ("BOOLEAN", {"default": True}),
|
||||
"use_cuda": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
|
||||
@ -1521,13 +1805,14 @@ class BatchCLIPSeg:
|
||||
|
||||
FUNCTION = "segment_image"
|
||||
|
||||
def segment_image(self, images, text, threshold, binary_mask):
|
||||
def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda):
|
||||
|
||||
out = []
|
||||
height, width, _ = images[0].shape
|
||||
print(height)
|
||||
print(width)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if use_cuda and torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
model.to(device) # Ensure the model is on the correct device
|
||||
images = images.to(device)
|
||||
@ -1561,6 +1846,11 @@ class BatchCLIPSeg:
|
||||
out.append(resized_tensor)
|
||||
|
||||
results = torch.stack(out).cpu()
|
||||
|
||||
if combine_mask:
|
||||
combined_results = torch.max(results, dim=0)[0]
|
||||
results = combined_results.unsqueeze(0).repeat(len(images),1,1)
|
||||
|
||||
if binary_mask:
|
||||
results = results.round()
|
||||
|
||||
@ -1581,6 +1871,115 @@ class RoundMask:
|
||||
mask = mask.round()
|
||||
return (mask,)
|
||||
|
||||
class ResizeMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
|
||||
"height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
|
||||
"keep_proportions": ("BOOLEAN", { "default": False }),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MASK", "INT", "INT",)
|
||||
RETURN_NAMES = ("mask", "width", "height",)
|
||||
FUNCTION = "resize"
|
||||
CATEGORY = "KJNodes/masking"
|
||||
|
||||
def resize(self, mask, width, height, keep_proportions):
|
||||
if keep_proportions:
|
||||
_, oh, ow, _ = mask.shape
|
||||
width = ow if width == 0 else width
|
||||
height = oh if height == 0 else height
|
||||
ratio = min(width / ow, height / oh)
|
||||
width = round(ow*ratio)
|
||||
height = round(oh*ratio)
|
||||
|
||||
outputs = mask.unsqueeze(0) # Add an extra dimension for batch size
|
||||
outputs = F.interpolate(outputs, size=(height, width), mode="nearest")
|
||||
outputs = outputs.squeeze(0) # Remove the extra dimension after interpolation
|
||||
|
||||
return(outputs, outputs.shape[2], outputs.shape[1],)
|
||||
|
||||
class OffsetMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
|
||||
"y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
|
||||
"angle": ("INT", { "default": 0, "min": -360, "max": 360, "step": 1, "display": "number" }),
|
||||
"duplication_factor": ("INT", { "default": 1, "min": 1, "max": 1000, "step": 1, "display": "number" }),
|
||||
"roll": ("BOOLEAN", { "default": False }),
|
||||
"incremental": ("BOOLEAN", { "default": False }),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "offset"
|
||||
CATEGORY = "KJNodes/masking"
|
||||
|
||||
def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1):
|
||||
# Create duplicates of the mask batch
|
||||
mask = mask.repeat(duplication_factor, 1, 1)
|
||||
|
||||
batch_size, height, width = mask.shape
|
||||
|
||||
if angle is not 0 and incremental:
|
||||
for i in range(batch_size):
|
||||
rotation_angle = angle * (i+1)
|
||||
mask[i] = TF.rotate(mask[i].unsqueeze(0), rotation_angle).squeeze(0)
|
||||
elif angle > 0:
|
||||
for i in range(batch_size):
|
||||
mask[i] = TF.rotate(mask[i].unsqueeze(0), angle).squeeze(0)
|
||||
|
||||
if roll:
|
||||
if incremental:
|
||||
for i in range(batch_size):
|
||||
shift_x = min(x*(i+1), width-1)
|
||||
shift_y = min(y*(i+1), height-1)
|
||||
if shift_x != 0:
|
||||
mask[i] = torch.roll(mask[i], shifts=shift_x, dims=1)
|
||||
if shift_y != 0:
|
||||
mask[i] = torch.roll(mask[i], shifts=shift_y, dims=0)
|
||||
else:
|
||||
shift_x = min(x, width-1)
|
||||
shift_y = min(y, height-1)
|
||||
if shift_x != 0:
|
||||
mask = torch.roll(mask, shifts=shift_x, dims=2)
|
||||
if shift_y != 0:
|
||||
mask = torch.roll(mask, shifts=shift_y, dims=1)
|
||||
else:
|
||||
if incremental:
|
||||
for i in range(batch_size):
|
||||
temp_x = min(x * (i+1), width-1)
|
||||
temp_y = min(y * (i+1), height-1)
|
||||
if temp_x > 0:
|
||||
mask[i] = torch.cat([torch.zeros((height, temp_x)), mask[i, :, :-temp_x]], dim=1)
|
||||
elif temp_x < 0:
|
||||
mask[i] = torch.cat([mask[i, :, -temp_x:], torch.zeros((height, -temp_x))], dim=1)
|
||||
if temp_y > 0:
|
||||
mask[i] = torch.cat([torch.zeros((temp_y, width)), mask[i, :-temp_y, :]], dim=0)
|
||||
elif temp_y < 0:
|
||||
mask[i] = torch.cat([mask[i, -temp_y:, :], torch.zeros((-temp_y, width))], dim=0)
|
||||
else:
|
||||
temp_x = min(x, width-1)
|
||||
temp_y = min(y, height-1)
|
||||
if temp_x > 0:
|
||||
mask = torch.cat([torch.zeros((batch_size, height, temp_x)), mask[:, :, :-temp_x]], dim=2)
|
||||
elif temp_x < 0:
|
||||
mask = torch.cat([mask[:, :, -temp_x:], torch.zeros((batch_size, height, -temp_x))], dim=2)
|
||||
if temp_y > 0:
|
||||
mask = torch.cat([torch.zeros((batch_size, temp_y, width)), mask[:, :-temp_y, :]], dim=1)
|
||||
elif temp_y < 0:
|
||||
mask = torch.cat([mask[:, -temp_y:, :], torch.zeros((batch_size, -temp_y, width))], dim=1)
|
||||
|
||||
return mask,
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"INTConstant": INTConstant,
|
||||
"FloatConstant": FloatConstant,
|
||||
@ -1610,9 +2009,13 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ImageBatchTestPattern": ImageBatchTestPattern,
|
||||
"ReplaceImagesInBatch": ReplaceImagesInBatch,
|
||||
"BatchCropFromMask": BatchCropFromMask,
|
||||
"BatchCropFromMaskAdvanced": BatchCropFromMaskAdvanced,
|
||||
"BatchUncrop": BatchUncrop,
|
||||
"BatchUncropAdvanced": BatchUncropAdvanced,
|
||||
"BatchCLIPSeg": BatchCLIPSeg,
|
||||
"RoundMask": RoundMask,
|
||||
"ResizeMask": ResizeMask,
|
||||
"OffsetMask": OffsetMask,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"INTConstant": "INT Constant",
|
||||
@ -1642,7 +2045,11 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageBatchTestPattern": "ImageBatchTestPattern",
|
||||
"ReplaceImagesInBatch": "ReplaceImagesInBatch",
|
||||
"BatchCropFromMask": "BatchCropFromMask",
|
||||
"BatchCropFromMaskAdvanced": "BatchCropFromMaskAdvanced",
|
||||
"BatchUncrop": "BatchUncrop",
|
||||
"BatchUncropAdvanced": "BatchUncropAdvanced",
|
||||
"BatchCLIPSeg": "BatchCLIPSeg",
|
||||
"RoundMask": "RoundMask",
|
||||
"ResizeMask": "ResizeMask",
|
||||
"OffsetMask": "OffsetMask",
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user