Add ImageCropByMaskBatch, SeparateMasks

kijai 2025-02-08 16:37:08 +02:00
parent a22b269242
commit bfe72cc964
4 changed files with 286 additions and 13 deletions

View File

@@ -40,9 +40,11 @@ NODE_CONFIG = {
"RemapMaskRange": {"class": RemapMaskRange, "name": "Remap Mask Range"},
"ResizeMask": {"class": ResizeMask, "name": "Resize Mask"},
"RoundMask": {"class": RoundMask, "name": "Round Mask"},
"SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"},
#images
"AddLabel": {"class": AddLabel, "name": "Add Label"},
"ColorMatch": {"class": ColorMatch, "name": "Color Match"},
"ImageTensorList": {"class": ImageTensorList, "name": "Image Tensor List"},
"CrossFadeImages": {"class": CrossFadeImages, "name": "Cross Fade Images"},
"CrossFadeImagesMulti": {"class": CrossFadeImagesMulti, "name": "Cross Fade Images Multi"},
"GetImagesFromBatchIndexed": {"class": GetImagesFromBatchIndexed, "name": "Get Images From Batch Indexed"},
@@ -59,6 +61,7 @@ NODE_CONFIG = {
"ImageConcatFromBatch": {"class": ImageConcatFromBatch, "name": "Image Concatenate From Batch"},
"ImageConcatMulti": {"class": ImageConcatMulti, "name": "Image Concatenate Multi"},
"ImageCropByMaskAndResize": {"class": ImageCropByMaskAndResize, "name": "Image Crop By Mask And Resize"},
"ImageCropByMaskBatch": {"class": ImageCropByMaskBatch, "name": "Image Crop By Mask Batch"},
"ImageUncropByMask": {"class": ImageUncropByMask, "name": "Image Uncrop By Mask"},
"ImageGrabPIL": {"class": ImageGrabPIL, "name": "Image Grab PIL"},
"ImageGridComposite2x2": {"class": ImageGridComposite2x2, "name": "Image Grid Composite 2x2"},

View File

@@ -389,6 +389,7 @@ class ImageConcatFromBatch:
grid[row*height:(row+1)*height, col*width:(col+1)*width, :] = resized_image
return grid.unsqueeze(0),
class ImageGridComposite2x2:
@classmethod
@@ -502,7 +503,7 @@ class ImageGrabPIL:
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "screencap"
CATEGORY = "KJNodes/experimental"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Captures an area specified by screen coordinates.
Can be used for realtime diffusion with autoqueue.
@@ -550,7 +551,7 @@ class Screencap_mss:
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "screencap"
CATEGORY = "KJNodes/experimental"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Captures an area specified by screen coordinates.
Can be used for realtime diffusion with autoqueue.
@@ -1116,7 +1117,7 @@ class ImageAndMaskPreview(SaveImage):
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("composite",)
FUNCTION = "execute"
CATEGORY = "KJNodes"
CATEGORY = "KJNodes/masking"
DESCRIPTION = """
Preview an image or a mask; when both inputs are used,
the mask is composited on top of the image.
@@ -1804,6 +1805,35 @@ with the **inputcount** and clicking update.
image, = image_batch_node.batch(image, new_image)
return (image,)
class ImageTensorList:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE",)
#OUTPUT_IS_LIST = (True,)
FUNCTION = "append"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Creates an image list from the input images.
"""
def append(self, image1, image2):
image_list = []
if isinstance(image1, torch.Tensor) and isinstance(image2, torch.Tensor):
image_list = [image1, image2]
elif isinstance(image1, list) and isinstance(image2, torch.Tensor):
image_list = image1 + [image2]
elif isinstance(image1, torch.Tensor) and isinstance(image2, list):
image_list = [image1] + image2
elif isinstance(image1, list) and isinstance(image2, list):
image_list = image1 + image2
return image_list,
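Since ImageTensorList returns a plain Python list rather than a batched tensor, images of different sizes can be chained through it to grow a list one image at a time. A minimal standalone sketch of that chaining (hypothetical shapes, not part of the commit):

import torch

node = ImageTensorList()
a = torch.rand(1, 512, 512, 3)
b = torch.rand(1, 256, 256, 3)
c = torch.rand(1, 128, 128, 3)

images, = node.append(a, b)       # tensor + tensor -> [a, b]
images, = node.append(images, c)  # list + tensor   -> [a, b, c]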
class ImageAddMulti:
@classmethod
def INPUT_TYPES(s):
@@ -2723,4 +2753,106 @@ class ImageUncropByMask:
output_list.append(result)
return (torch.stack(output_list),)
class ImageCropByMaskBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE", ),
"masks": ("MASK", ),
"width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
"height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }),
"padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1, }),
"preserve_size": ("BOOLEAN", {"default": False}),
"bg_color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255, separated by commas."}),
}
}
RETURN_TYPES = ("IMAGE", "MASK", )
RETURN_NAMES = ("images", "masks",)
FUNCTION = "crop"
CATEGORY = "KJNodes/image"
DESCRIPTION = "Crops the input images based on the provided masks."
def crop(self, image, masks, width, height, bg_color, padding, preserve_size):
B, H, W, C = image.shape
BM, HM, WM = masks.shape
mask_count = BM
if HM != H or WM != W:
masks = F.interpolate(masks.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
output_images = []
output_masks = []
bg_color = [int(x.strip())/255.0 for x in bg_color.split(",")]
# For each mask
for i in range(mask_count):
curr_mask = masks[i]
# Find bounds
y_indices, x_indices = torch.nonzero(curr_mask, as_tuple=True)
if len(y_indices) == 0 or len(x_indices) == 0:
continue
# Get exact bounds with padding
min_y = max(0, y_indices.min().item() - padding)
max_y = min(H, y_indices.max().item() + 1 + padding)
min_x = max(0, x_indices.min().item() - padding)
max_x = min(W, x_indices.max().item() + 1 + padding)
# Ensure mask has correct shape for multiplication
curr_mask = curr_mask.unsqueeze(-1).expand(-1, -1, C)
# Crop image and mask together
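# Note: only the first image in the batch (image[0]) is cropped; each mask yields one output crop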
cropped_img = image[0, min_y:max_y, min_x:max_x, :]
cropped_mask = curr_mask[min_y:max_y, min_x:max_x, :]
crop_h, crop_w = cropped_img.shape[0:2]
new_w = crop_w
new_h = crop_h
if not preserve_size or crop_w > width or crop_h > height:
scale = min(width/crop_w, height/crop_h)
new_w = int(crop_w * scale)
new_h = int(crop_h * scale)
# Resize RGB
resized_img = common_upscale(cropped_img.permute(2,0,1).unsqueeze(0), new_w, new_h, "lanczos", "disabled").squeeze(0).permute(1,2,0)
resized_mask = torch.nn.functional.interpolate(
cropped_mask.permute(2,0,1).unsqueeze(0),
size=(new_h, new_w),
mode='nearest'
).squeeze(0).permute(1,2,0)
else:
resized_img = cropped_img
resized_mask = cropped_mask
# Create empty tensors
new_img = torch.zeros((height, width, 3), dtype=image.dtype, device=image.device)
new_mask = torch.zeros((height, width), dtype=image.dtype, device=image.device)
# Pad both
pad_x = (width - new_w) // 2
pad_y = (height - new_h) // 2
new_img[pad_y:pad_y+new_h, pad_x:pad_x+new_w, :] = resized_img
if len(resized_mask.shape) == 3:
resized_mask = resized_mask[:,:,0] # Take first channel if 3D
new_mask[pad_y:pad_y+new_h, pad_x:pad_x+new_w] = resized_mask
output_images.append(new_img)
output_masks.append(new_mask)
if not output_images:
return (torch.zeros((0, height, width, 3), dtype=image.dtype), torch.zeros((0, height, width), dtype=image.dtype))
out_rgb = torch.stack(output_images, dim=0)
out_masks = torch.stack(output_masks, dim=0)
# Apply mask to RGB
mask_expanded = out_masks.unsqueeze(-1).expand(-1, -1, -1, 3)
background_color = torch.tensor(bg_color, dtype=torch.float32, device=image.device)
out_rgb = out_rgb * mask_expanded + background_color * (1 - mask_expanded)
return (out_rgb, out_masks)
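A minimal sanity-check sketch of the new node called directly (hypothetical shapes and values, not part of the commit; assumes a ComfyUI environment for common_upscale and MAX_RESOLUTION). One image plus a batch of three masks yields three 512x512 crops, each centered, letterboxed with the background color, and masked to its foreground:

import torch

node = ImageCropByMaskBatch()
image = torch.rand(1, 768, 1024, 3)               # B, H, W, C
masks = (torch.rand(3, 768, 1024) > 0.9).float()  # three rough masks

images, out_masks = node.crop(image, masks, width=512, height=512, bg_color="0, 0, 0", padding=16, preserve_size=False)
print(images.shape, out_masks.shape)  # torch.Size([3, 512, 512, 3]) torch.Size([3, 512, 512])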

View File

@@ -1251,3 +1251,140 @@ Sets new min and max values for the mask.
scaled_mask = torch.clamp(scaled_mask, min=0.0, max=1.0)
return (scaled_mask, )
def get_mask_polygon(self, mask_np):
"""Helper function to get polygon points from mask"""
import cv2
# Find contours
contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
return None
# Get the largest contour
largest_contour = max(contours, key=cv2.contourArea)
# Approximate polygon
epsilon = 0.02 * cv2.arcLength(largest_contour, True)
polygon = cv2.approxPolyDP(largest_contour, epsilon, True)
return polygon.squeeze()
import cv2
class SeparateMasks:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK", ),
"size_threshold_width" : ("INT", {"default": 256, "min": 0.0, "max": 4096, "step": 1}),
"size_threshold_height" : ("INT", {"default": 256, "min": 0.0, "max": 4096, "step": 1}),
"mode": (["convex_polygons", "area"],),
"max_poly_points": ("INT", {"default": 8, "min": 3, "max": 32, "step": 1}),
},
}
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("mask",)
FUNCTION = "separate"
CATEGORY = "KJNodes/masking"
OUTPUT_NODE = True
def polygon_to_mask(self, polygon, shape):
mask = np.zeros((shape[0], shape[1]), dtype=np.uint8) # blank canvas matching the source mask's height and width
if len(polygon.shape) == 2: # Check if polygon points are valid
polygon = polygon.astype(np.int32)
cv2.fillPoly(mask, [polygon], 1)
return mask
def get_mask_polygon(self, mask_np, max_points):
contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
return None
largest_contour = max(contours, key=cv2.contourArea)
hull = cv2.convexHull(largest_contour)
# Initialize with smaller epsilon for more points
perimeter = cv2.arcLength(hull, True)
epsilon = perimeter * 0.01 # Start smaller
min_eps = perimeter * 0.001 # Much smaller minimum
max_eps = perimeter * 0.2 # Smaller maximum
best_approx = None
best_diff = float('inf')
max_iterations = 20
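# Bisect epsilon: a larger epsilon merges more vertices into fewer polygon points,
# so raise min_eps when the approximation has too many points and lower max_eps when it has too few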
#print(f"Target points: {max_points}, Perimeter: {perimeter}")
for i in range(max_iterations):
curr_eps = (min_eps + max_eps) / 2
approx = cv2.approxPolyDP(hull, curr_eps, True)
points_diff = len(approx) - max_points
#print(f"Iteration {i}: points={len(approx)}, eps={curr_eps:.4f}")
if abs(points_diff) < best_diff:
best_approx = approx
best_diff = abs(points_diff)
if len(approx) > max_points:
min_eps = curr_eps * 1.1 # More gradual adjustment
elif len(approx) < max_points:
max_eps = curr_eps * 0.9 # More gradual adjustment
else:
return approx.squeeze()
if abs(max_eps - min_eps) < perimeter * 0.0001: # Relative tolerance
break
# If we didn't find exact match, return best approximation
return best_approx.squeeze() if best_approx is not None else hull.squeeze()
def separate(self, mask: torch.Tensor, size_threshold_width: int, size_threshold_height: int, max_poly_points: int, mode: str):
from scipy.ndimage import label
import numpy as np
B, H, W = mask.shape
separated = []
mask = mask.round()
for b in range(B):
mask_np = mask[b].cpu().numpy().astype(np.uint8)
structure = np.ones((3, 3), dtype=np.int8)
labeled, ncomponents = label(mask_np, structure=structure)
pbar = ProgressBar(ncomponents)
for component in range(1, ncomponents + 1):
component_mask_np = (labeled == component).astype(np.uint8)
rows = np.any(component_mask_np, axis=1)
cols = np.any(component_mask_np, axis=0)
y_min, y_max = np.where(rows)[0][[0, -1]]
x_min, x_max = np.where(cols)[0][[0, -1]]
width = x_max - x_min + 1
height = y_max - y_min + 1
print(f"Component {component}: width={width}, height={height}")
if width >= size_threshold_width and height >= size_threshold_height:
polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
if mode != "area" and polygon is not None:
poly_mask = self.polygon_to_mask(polygon, (H, W))
poly_mask = torch.tensor(poly_mask, device=mask.device)
separated.append(poly_mask)
elif mode == "area":
area_mask = torch.tensor(component_mask_np, device=mask.device)
separated.append(area_mask)
pbar.update(1)
if len(separated) > 0:
out_masks = torch.stack(separated, dim=0)
return out_masks,
else:
return torch.zeros((1, 64, 64), device=mask.device), # zeros, not torch.empty, which would return uninitialized memory
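A minimal sketch of SeparateMasks on a synthetic two-blob mask (hypothetical values, not part of the commit; assumes scipy and ComfyUI's ProgressBar are available). In "area" mode, each connected component that passes both size thresholds comes back as its own mask in the output batch:

import torch

mask = torch.zeros(1, 512, 512)
mask[0, 50:200, 50:200] = 1.0      # component 1: 150 x 150
mask[0, 300:480, 260:500] = 1.0    # component 2: 180 x 240

node = SeparateMasks()
parts, = node.separate(mask, size_threshold_width=100, size_threshold_height=100, max_poly_points=8, mode="area")
print(parts.shape)  # torch.Size([2, 512, 512])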

View File

@@ -111,7 +111,7 @@ class ScaleBatchPromptSchedule:
RETURN_TYPES = ("STRING",)
FUNCTION = "scaleschedule"
CATEGORY = "KJNodes"
CATEGORY = "KJNodes/misc"
DESCRIPTION = """
Scales a batch schedule from Fizz' nodes BatchPromptSchedule
to a different frame count.
@@ -155,7 +155,7 @@ class GetLatentsFromBatchIndexed:
RETURN_TYPES = ("LATENT",)
FUNCTION = "indexedlatentsfrombatch"
CATEGORY = "KJNodes"
CATEGORY = "KJNodes/latents"
DESCRIPTION = """
Selects and returns the latents at the specified indices as a latent batch.
"""
@@ -238,7 +238,7 @@ class AppendStringsToList:
}
RETURN_TYPES = ("STRING",)
FUNCTION = "joinstring"
CATEGORY = "KJNodes/constants"
CATEGORY = "KJNodes/text"
def joinstring(self, string1, string2):
if not isinstance(string1, list):
@@ -261,7 +261,7 @@ class JoinStrings:
}
RETURN_TYPES = ("STRING",)
FUNCTION = "joinstring"
CATEGORY = "KJNodes/constants"
CATEGORY = "KJNodes/text"
def joinstring(self, string1, string2, delimiter):
joined_string = string1 + delimiter + string2
@@ -283,7 +283,7 @@ class JoinStringMulti:
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("string",)
FUNCTION = "combine"
CATEGORY = "KJNodes"
CATEGORY = "KJNodes/text"
DESCRIPTION = """
Creates a single string, or a list of strings, from
multiple input strings.
@@ -738,7 +738,7 @@ class EmptyLatentImagePresets:
RETURN_TYPES = ("LATENT", "INT", "INT")
RETURN_NAMES = ("Latent", "Width", "Height")
FUNCTION = "generate"
CATEGORY = "KJNodes"
CATEGORY = "KJNodes/latents"
def generate(self, dimensions, invert, batch_size):
from nodes import EmptyLatentImage
@@ -784,7 +784,7 @@ class EmptyLatentImageCustomPresets:
RETURN_TYPES = ("LATENT", "INT", "INT")
RETURN_NAMES = ("Latent", "Width", "Height")
FUNCTION = "generate"
CATEGORY = "KJNodes"
CATEGORY = "KJNodes/latents"
DESCRIPTION = """
Generates an empty latent image with the specified dimensions.
The choices are loaded from 'custom_dimensions.json' in the nodes folder.
@@ -1113,7 +1113,7 @@ class GenerateNoise:
"model": ("MODEL", ),
"sigmas": ("SIGMAS", ),
"latent_channels": (['4', '16', ],),
"shape": (["BCHW", "BCTHW"],),
"shape": (["BCHW", "BCTHW","BTCHW",],),
}
}
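For reference, the three layouts the updated node can emit, sketched with assumed values (latent_channels=4, batch_size=16, 512x512 input, hence 64x64 latents); only the BTCHW variant is new in this commit:

import torch

bchw  = torch.randn(16, 4, 64, 64)      # one 4-channel latent per batch entry
bcthw = torch.randn(1, 4, 16, 64, 64)   # channels before the time axis
btchw = torch.randn(1, 16, 4, 64, 64)   # time axis before channels (added here)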
@@ -1131,7 +1131,8 @@ Generates noise for injection or to be used as empty latents on samplers with ad
noise = torch.randn([batch_size, int(latent_channels), height // 8, width // 8], dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")
elif shape == "BCTHW":
noise = torch.randn([1, int(latent_channels), batch_size,height // 8, width // 8], dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")
elif shape == "BTCHW":
noise = torch.randn([1, batch_size, int(latent_channels), height // 8, width // 8], dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")
if sigmas is not None:
sigma = sigmas[0] - sigmas[-1]
sigma /= model.model.latent_format.scale_factor