From 171b70bfa51ff7202a62a4a4ee2f33e9211a28ba Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 22 Apr 2024 19:50:34 +0300
Subject: [PATCH] Remake ColorToMask node, that code was atrocious

---
 nodes.py | 50 ++++++++++++++++++++++++++++----------------------
 1 file changed, 28 insertions(+), 22 deletions(-)

diff --git a/nodes.py b/nodes.py
index aa779cf..be07541 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1000,35 +1000,41 @@ Converts chosen RGB value to a mask
             "green": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
             "blue": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
             "threshold": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
+            "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
             },
     }
 
-    def clip(self, images, red, green, blue, threshold, invert):
-        color = np.array([red, green, blue])
-        images = 255. * images.cpu().numpy()
-        images = np.clip(images, 0, 255).astype(np.uint8)
-        images = [Image.fromarray(image) for image in images]
-        images = [np.array(image) for image in images]
+    def clip(self, images, red, green, blue, threshold, invert, per_batch):
 
-        black = [0, 0, 0]
-        white = [255, 255, 255]
+        color = torch.tensor([red, green, blue], dtype=torch.uint8)
+        black = torch.tensor([0, 0, 0], dtype=torch.uint8)
+        white = torch.tensor([255, 255, 255], dtype=torch.uint8)
+
         if invert:
-            black, white = white, black
+             black, white = white, black
 
-        new_images = []
-        for image in images:
-            new_image = np.full_like(image, black)
+        steps = images.shape[0]
+        pbar = comfy.utils.ProgressBar(steps)
+        tensors_out = []
+
+        for start_idx in range(0, images.shape[0], per_batch):
 
-            color_distances = np.linalg.norm(image - color, axis=-1)
-            complement_indexes = color_distances <= threshold
+            # Euclidean distance of each pixel from the target color, on the 0-255 scale
+            color_distances = torch.norm(images[start_idx:start_idx+per_batch] * 255 - color, dim=-1)
+
+            # Boolean mask: True where a pixel lies within threshold of the target color
+            mask = color_distances <= threshold
+
+            # Matches become white, the rest black; the channel mean collapses to a single-channel mask
+            mask_out = torch.where(mask.unsqueeze(-1), white, black).float()
+            mask_out = mask_out.mean(dim=-1)
 
-            new_image[complement_indexes] = white
-
-            new_images.append(new_image)
-
-        new_images = np.array(new_images).astype(np.float32) / 255.0
-        new_images = torch.from_numpy(new_images).permute(3, 0, 1, 2)
-        return new_images
+            tensors_out.append(mask_out.cpu())
+            batch_count = mask_out.shape[0]
+            pbar.update(batch_count)
+
+        tensors_out = torch.cat(tensors_out, dim=0)
+        return tensors_out,
 
 class ConditioningMultiCombine:
     @classmethod
@@ -5058,7 +5064,7 @@ Each mask is generated with the specified width and height.
             mask = torch.ones((height, width), dtype=torch.float32) * value
             masks.append(mask)
         masks_out = torch.stack(masks, dim=0)
-        print(masks_out.shape)
+
         return(masks_out,)
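
For reference, a minimal standalone sketch of the batched color-distance masking this
patch implements. It assumes `images` is a float tensor in [0, 1] with shape
(B, H, W, 3), as ComfyUI IMAGE batches are; the function name `color_to_mask` and the
`1.0 - mask` inversion (instead of swapping the black/white tensors) are illustrative,
not the node's exact code. Note the patched node emits 0/255-valued masks after the
channel mean, whereas this sketch normalizes to the usual 0..1 mask range.

    import torch

    def color_to_mask(images, rgb=(0, 0, 0), threshold=10, invert=False, per_batch=16):
        # Target color on the same 0-255 scale the distances are computed in
        color = torch.tensor(rgb, dtype=torch.float32)
        out = []
        # Process the batch in chunks of per_batch images to bound peak memory
        for start in range(0, images.shape[0], per_batch):
            chunk = images[start:start + per_batch] * 255.0
            # Per-pixel Euclidean distance from the target color
            dist = torch.norm(chunk - color, dim=-1)
            # 1.0 where the pixel is within threshold of the target, else 0.0
            mask = (dist <= threshold).float()
            if invert:
                mask = 1.0 - mask
            out.append(mask.cpu())
        return torch.cat(out, dim=0)  # (B, H, W) float mask

    # Example: mask near-black pixels in a random 4-image batch
    imgs = torch.rand(4, 64, 64, 3)
    print(color_to_mask(imgs, threshold=30).shape)  # torch.Size([4, 64, 64])

Chunking by per_batch is the point of the new input: it bounds peak memory on large
batches, while the old code's per-image PIL round-trip only added copies without
changing the result.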