From bfe72cc964ded40a2fdf3735b9a8914377155aee Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 8 Feb 2025 16:37:08 +0200 Subject: [PATCH] Add ImageCropByMaskBatch, SeparateMasks --- __init__.py | 3 + nodes/image_nodes.py | 140 +++++++++++++++++++++++++++++++++++++++++-- nodes/mask_nodes.py | 137 ++++++++++++++++++++++++++++++++++++++++++ nodes/nodes.py | 19 +++--- 4 files changed, 286 insertions(+), 13 deletions(-) diff --git a/__init__.py b/__init__.py index 8888f3d..4dc0053 100644 --- a/__init__.py +++ b/__init__.py @@ -40,9 +40,11 @@ NODE_CONFIG = { "RemapMaskRange": {"class": RemapMaskRange, "name": "Remap Mask Range"}, "ResizeMask": {"class": ResizeMask, "name": "Resize Mask"}, "RoundMask": {"class": RoundMask, "name": "Round Mask"}, + "SeparateMasks": {"class": SeparateMasks, "name": "Separate Masks"}, #images "AddLabel": {"class": AddLabel, "name": "Add Label"}, "ColorMatch": {"class": ColorMatch, "name": "Color Match"}, + "ImageTensorList": {"class": ImageTensorList, "name": "Image Tensor List"}, "CrossFadeImages": {"class": CrossFadeImages, "name": "Cross Fade Images"}, "CrossFadeImagesMulti": {"class": CrossFadeImagesMulti, "name": "Cross Fade Images Multi"}, "GetImagesFromBatchIndexed": {"class": GetImagesFromBatchIndexed, "name": "Get Images From Batch Indexed"}, @@ -59,6 +61,7 @@ NODE_CONFIG = { "ImageConcatFromBatch": {"class": ImageConcatFromBatch, "name": "Image Concatenate From Batch"}, "ImageConcatMulti": {"class": ImageConcatMulti, "name": "Image Concatenate Multi"}, "ImageCropByMaskAndResize": {"class": ImageCropByMaskAndResize, "name": "Image Crop By Mask And Resize"}, + "ImageCropByMaskBatch": {"class": ImageCropByMaskBatch, "name": "Image Crop By Mask Batch"}, "ImageUncropByMask": {"class": ImageUncropByMask, "name": "Image Uncrop By Mask"}, "ImageGrabPIL": {"class": ImageGrabPIL, "name": "Image Grab PIL"}, "ImageGridComposite2x2": {"class": ImageGridComposite2x2, "name": "Image Grid Composite 2x2"}, diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index c26fed0..728dba1 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -389,6 +389,7 @@ class ImageConcatFromBatch: grid[row*height:(row+1)*height, col*width:(col+1)*width, :] = resized_image return grid.unsqueeze(0), + class ImageGridComposite2x2: @classmethod @@ -502,7 +503,7 @@ class ImageGrabPIL: RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("image",) FUNCTION = "screencap" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/image" DESCRIPTION = """ Captures an area specified by screen coordinates. Can be used for realtime diffusion with autoqueue. @@ -550,7 +551,7 @@ class Screencap_mss: RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("image",) FUNCTION = "screencap" - CATEGORY = "KJNodes/experimental" + CATEGORY = "KJNodes/image" DESCRIPTION = """ Captures an area specified by screen coordinates. Can be used for realtime diffusion with autoqueue. @@ -1116,7 +1117,7 @@ class ImageAndMaskPreview(SaveImage): RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("composite",) FUNCTION = "execute" - CATEGORY = "KJNodes" + CATEGORY = "KJNodes/masking" DESCRIPTION = """ Preview an image or a mask, when both inputs are used composites the mask on top of the image. @@ -1804,6 +1805,35 @@ with the **inputcount** and clicking update. 
image, = image_batch_node.batch(image, new_image) return (image,) + +class ImageTensorList: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image1": ("IMAGE",), + "image2": ("IMAGE",), + }} + + RETURN_TYPES = ("IMAGE",) + #OUTPUT_IS_LIST = (True,) + FUNCTION = "append" + CATEGORY = "KJNodes/image" + DESCRIPTION = """ +Creates an image list from the input images. +""" + + def append(self, image1, image2): + image_list = [] + if isinstance(image1, torch.Tensor) and isinstance(image2, torch.Tensor): + image_list = [image1, image2] + elif isinstance(image1, list) and isinstance(image2, torch.Tensor): + image_list = image1 + [image2] + elif isinstance(image1, torch.Tensor) and isinstance(image2, list): + image_list = [image1] + image2 + elif isinstance(image1, list) and isinstance(image2, list): + image_list = image1 + image2 + return image_list, + class ImageAddMulti: @classmethod def INPUT_TYPES(s): @@ -2723,4 +2753,106 @@ class ImageUncropByMask: output_list.append(result) - return (torch.stack(output_list),) \ No newline at end of file + return (torch.stack(output_list),) + +class ImageCropByMaskBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE", ), + "masks": ("MASK", ), + "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }), + "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }), + "padding": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1, }), + "preserve_size": ("BOOLEAN", {"default": False}), + "bg_color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255, separated by commas."}), + } + } + + RETURN_TYPES = ("IMAGE", "MASK", ) + RETURN_NAMES = ("images", "masks",) + FUNCTION = "crop" + CATEGORY = "KJNodes/image" + DESCRIPTION = "Crops the input images based on the provided masks." 
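+
+    # Overview of crop() below: for each mask, take its nonzero bounding box
+    # (expanded by `padding`), crop the first image in the batch and the mask
+    # to that box, scale the crop to fit within width x height when needed,
+    # center it on a blank canvas, and composite the result over `bg_color`.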
+
+    def crop(self, image, masks, width, height, bg_color, padding, preserve_size):
+        B, H, W, C = image.shape
+        BM, HM, WM = masks.shape
+        if HM != H or WM != W:
+            masks = torch.nn.functional.interpolate(masks.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
+
+        output_images = []
+        output_masks = []
+
+        bg_color = [int(x.strip()) / 255.0 for x in bg_color.split(",")]
+
+        # Crop the first image in the batch once per mask
+        for i in range(BM):
+            curr_mask = masks[i]
+
+            # Find the mask bounds; skip fully empty masks
+            y_indices, x_indices = torch.nonzero(curr_mask, as_tuple=True)
+            if len(y_indices) == 0 or len(x_indices) == 0:
+                continue
+
+            # Get exact bounds with padding, clamped to the image
+            min_y = max(0, y_indices.min().item() - padding)
+            max_y = min(H, y_indices.max().item() + 1 + padding)
+            min_x = max(0, x_indices.min().item() - padding)
+            max_x = min(W, x_indices.max().item() + 1 + padding)
+
+            # Ensure mask has correct shape for multiplication
+            curr_mask = curr_mask.unsqueeze(-1).expand(-1, -1, C)
+
+            # Crop image and mask together
+            cropped_img = image[0, min_y:max_y, min_x:max_x, :]
+            cropped_mask = curr_mask[min_y:max_y, min_x:max_x, :]
+
+            crop_h, crop_w = cropped_img.shape[0:2]
+            new_w = crop_w
+            new_h = crop_h
+
+            if not preserve_size or crop_w > width or crop_h > height:
+                # Scale to fit the target size while keeping the aspect ratio
+                scale = min(width / crop_w, height / crop_h)
+                new_w = int(crop_w * scale)
+                new_h = int(crop_h * scale)
+
+                # Resize RGB with lanczos, mask with nearest to keep it hard-edged
+                resized_img = common_upscale(cropped_img.permute(2,0,1).unsqueeze(0), new_w, new_h, "lanczos", "disabled").squeeze(0).permute(1,2,0)
+                resized_mask = torch.nn.functional.interpolate(
+                    cropped_mask.permute(2,0,1).unsqueeze(0),
+                    size=(new_h, new_w),
+                    mode='nearest'
+                ).squeeze(0).permute(1,2,0)
+            else:
+                resized_img = cropped_img
+                resized_mask = cropped_mask
+
+            # Create empty canvases at the target size
+            new_img = torch.zeros((height, width, 3), dtype=image.dtype, device=image.device)
+            new_mask = torch.zeros((height, width), dtype=image.dtype, device=image.device)
+
+            # Center the crop on the canvas
+            pad_x = (width - new_w) // 2
+            pad_y = (height - new_h) // 2
+            new_img[pad_y:pad_y+new_h, pad_x:pad_x+new_w, :] = resized_img
+            if len(resized_mask.shape) == 3:
+                resized_mask = resized_mask[:,:,0]  # Take first channel if 3D
+            new_mask[pad_y:pad_y+new_h, pad_x:pad_x+new_w] = resized_mask
+
+            output_images.append(new_img)
+            output_masks.append(new_mask)
+
+        if not output_images:
+            # All masks were empty: return empty batches for both outputs
+            return (torch.zeros((0, height, width, 3), dtype=image.dtype),
+                    torch.zeros((0, height, width), dtype=image.dtype))
+
+        out_rgb = torch.stack(output_images, dim=0)
+        out_masks = torch.stack(output_masks, dim=0)
+
+        # Composite the masked RGB over the background color
+        mask_expanded = out_masks.unsqueeze(-1).expand(-1, -1, -1, 3)
+        background_color = torch.tensor(bg_color, dtype=torch.float32, device=image.device)
+        out_rgb = out_rgb * mask_expanded + background_color * (1 - mask_expanded)
+
+        return (out_rgb, out_masks)
diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py
index 87d7602..3f7a2e3 100644
--- a/nodes/mask_nodes.py
+++ b/nodes/mask_nodes.py
@@ -1251,3 +1251,140 @@ Sets new min and max values for the mask.
        scaled_mask = torch.clamp(scaled_mask, min=0.0, max=1.0)
 
        return (scaled_mask, )
+
+
+import cv2
+
+class SeparateMasks:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "mask": ("MASK", ),
+                "size_threshold_width": ("INT", {"default": 256, "min": 0, "max": 4096, "step": 1}),
+                "size_threshold_height": ("INT", {"default": 256, "min": 0, "max": 4096, "step": 1}),
+                "mode": (["convex_polygons", "area"],),
+                "max_poly_points": ("INT", {"default": 8, "min": 3, "max": 32, "step": 1}),
+            },
+        }
+
+    RETURN_TYPES = ("MASK",)
+    RETURN_NAMES = ("mask",)
+    FUNCTION = "separate"
+    CATEGORY = "KJNodes/masking"
+    OUTPUT_NODE = True
+
+    def polygon_to_mask(self, polygon, shape):
+        # Rasterize the polygon into a single-channel uint8 mask
+        mask = np.zeros((shape[0], shape[1]), dtype=np.uint8)
+
+        if len(polygon.shape) == 2:  # Only rasterize valid (N, 2) point arrays
+            polygon = polygon.astype(np.int32)
+            cv2.fillPoly(mask, [polygon], 1)
+        return mask
+
+    def get_mask_polygon(self, mask_np, max_points):
+        # Approximate the convex hull of the largest contour with a polygon of
+        # (about) max_points vertices by binary-searching approxPolyDP's epsilon.
+        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        if not contours:
+            return None
+
+        largest_contour = max(contours, key=cv2.contourArea)
+        hull = cv2.convexHull(largest_contour)
+
+        # Search bounds for epsilon, relative to the hull perimeter
+        perimeter = cv2.arcLength(hull, True)
+        min_eps = perimeter * 0.001
+        max_eps = perimeter * 0.2
+
+        best_approx = None
+        best_diff = float('inf')
+        max_iterations = 20
+
+        for i in range(max_iterations):
+            curr_eps = (min_eps + max_eps) / 2
+            approx = cv2.approxPolyDP(hull, curr_eps, True)
+            points_diff = len(approx) - max_points
+
+            if abs(points_diff) < best_diff:
+                best_approx = approx
+                best_diff = abs(points_diff)
+
+            if len(approx) > max_points:
+                min_eps = curr_eps * 1.1  # too many points: increase epsilon
+            elif len(approx) < max_points:
+                max_eps = curr_eps * 0.9  # too few points: decrease epsilon
+            else:
+                return approx.squeeze()
+
+            if abs(max_eps - min_eps) < perimeter * 0.0001:  # converged
+                break
+
+        # No exact match found, return the closest approximation
+        return best_approx.squeeze() if best_approx is not None else hull.squeeze()
+
+    def separate(self, mask: torch.Tensor, size_threshold_width: int, size_threshold_height: int, max_poly_points: int, mode: str):
+        from scipy.ndimage import label
+
+        B, H, W = mask.shape
+        separated = []
+
+        mask = mask.round()
+
+        for b in range(B):
+            mask_np = mask[b].cpu().numpy().astype(np.uint8)
+            # 8-connectivity: diagonally touching pixels join the same component
+            structure = np.ones((3, 3), dtype=np.int8)
+            labeled, ncomponents = label(mask_np, structure=structure)
+            pbar = ProgressBar(ncomponents)
+            for component in range(1, ncomponents + 1):
+                component_mask_np = (labeled == component).astype(np.uint8)
+
+                # Bounding box of the component
+                rows = np.any(component_mask_np, axis=1)
+                cols = np.any(component_mask_np, axis=0)
+                y_min, y_max = np.where(rows)[0][[0, -1]]
+                x_min, x_max = np.where(cols)[0][[0, -1]]
+
+                width = x_max - x_min + 1
+                height = y_max - y_min + 1
+
+                # Keep only components that meet both size thresholds
+                if width >= size_threshold_width and height >= size_threshold_height:
+                    if mode == "convex_polygons":
+                        polygon = self.get_mask_polygon(component_mask_np, max_poly_points)
+                        if polygon is not None:
+                            poly_mask = self.polygon_to_mask(polygon, (H, W))
+                            separated.append(torch.tensor(poly_mask, dtype=torch.float32, device=mask.device))
+                    else:
+                        separated.append(torch.tensor(component_mask_np, dtype=torch.float32, device=mask.device))
+                pbar.update(1)
+
+        if len(separated) > 0:
+            out_masks = torch.stack(separated, dim=0)
+            return out_masks,
+        else:
+            # Nothing passed the size thresholds: return a single empty mask
+            return torch.zeros((1, 64, 64), device=mask.device),
\ No newline at end of file
diff --git a/nodes/nodes.py b/nodes/nodes.py
index 2e343a0..abcfcf7 100644
--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -111,7 +111,7 @@ class ScaleBatchPromptSchedule:
 
     RETURN_TYPES = ("STRING",)
     FUNCTION = "scaleschedule"
-    CATEGORY = "KJNodes"
+    CATEGORY = "KJNodes/misc"
     DESCRIPTION = """
 Scales a batch schedule from Fizz' nodes BatchPromptSchedule
 to a different frame count.
 
@@ -155,7 +155,7 @@ class GetLatentsFromBatchIndexed:
 
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "indexedlatentsfrombatch"
-    CATEGORY = "KJNodes"
+    CATEGORY = "KJNodes/latents"
     DESCRIPTION = """
 Selects and returns the latents at the specified indices as a latent batch.
 """
 
@@ -238,7 +238,7 @@ class AppendStringsToList:
     }
     RETURN_TYPES = ("STRING",)
     FUNCTION = "joinstring"
-    CATEGORY = "KJNodes/constants"
+    CATEGORY = "KJNodes/text"
 
     def joinstring(self, string1, string2):
         if not isinstance(string1, list):
 
@@ -261,7 +261,7 @@ class JoinStrings:
     }
     RETURN_TYPES = ("STRING",)
     FUNCTION = "joinstring"
-    CATEGORY = "KJNodes/constants"
+    CATEGORY = "KJNodes/text"
 
     def joinstring(self, string1, string2, delimiter):
         joined_string = string1 + delimiter + string2
 
@@ -283,7 +283,7 @@ class JoinStringMulti:
     RETURN_TYPES = ("STRING",)
     RETURN_NAMES = ("string",)
     FUNCTION = "combine"
-    CATEGORY = "KJNodes"
+    CATEGORY = "KJNodes/text"
     DESCRIPTION = """
 Creates a single string, or a list of strings,
 from multiple input strings.
 
@@ -738,7 +738,7 @@ class EmptyLatentImagePresets:
     RETURN_TYPES = ("LATENT", "INT", "INT")
     RETURN_NAMES = ("Latent", "Width", "Height")
     FUNCTION = "generate"
-    CATEGORY = "KJNodes"
+    CATEGORY = "KJNodes/latents"
 
     def generate(self, dimensions, invert, batch_size):
         from nodes import EmptyLatentImage
 
@@ -784,7 +784,7 @@ class EmptyLatentImageCustomPresets:
     RETURN_TYPES = ("LATENT", "INT", "INT")
     RETURN_NAMES = ("Latent", "Width", "Height")
     FUNCTION = "generate"
-    CATEGORY = "KJNodes"
+    CATEGORY = "KJNodes/latents"
     DESCRIPTION = """
 Generates an empty latent image with the specified dimensions.
 The choices are loaded from 'custom_dimensions.json' in the nodes folder.
@@ -1113,7 +1113,7 @@ class GenerateNoise:
                 "model": ("MODEL", ),
                 "sigmas": ("SIGMAS", ),
                 "latent_channels": (['4', '16', ],),
-                "shape": (["BCHW", "BCTHW"],),
+                "shape": (["BCHW", "BCTHW", "BTCHW"],),
             }
         }
 
@@ -1131,7 +1131,8 @@ Generates noise for injection or to be used as empty latents on samplers with ad
             noise = torch.randn([batch_size, int(latent_channels), height // 8, width // 8], dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")
         elif shape == "BCTHW":
             noise = torch.randn([1, int(latent_channels), batch_size, height // 8, width // 8], dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")
-        print(noise.shape)
+        elif shape == "BTCHW":
+            noise = torch.randn([1, batch_size, int(latent_channels), height // 8, width // 8], dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu")
         if sigmas is not None:
             sigma = sigmas[0] - sigmas[-1]
             sigma /= model.model.latent_format.scale_factor
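
A minimal sketch (illustration only, not part of the patch) of the three noise
layouts GenerateNoise can now produce; the concrete sizes here are made-up
examples, mirroring the torch.randn calls in the hunk above:

    import torch

    batch_size, channels, height, width = 25, 16, 512, 512
    # "BCHW":  one latent per frame                -> [25, 16, 64, 64]
    bchw = torch.randn([batch_size, channels, height // 8, width // 8])
    # "BCTHW": channel axis before the frame axis  -> [1, 16, 25, 64, 64]
    bcthw = torch.randn([1, channels, batch_size, height // 8, width // 8])
    # "BTCHW": frame axis before the channel axis  -> [1, 25, 16, 64, 64]
    btchw = torch.randn([1, batch_size, channels, height // 8, width // 8])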