From 8d7a3a3bab7e9bd7005f9289a4cda89a3c661fbf Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:05:34 +0300 Subject: [PATCH] Add DrawMaskOnImage --- __init__.py | 1 + nodes/mask_nodes.py | 58 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/__init__.py b/__init__.py index 5ed7113..7ba74db 100644 --- a/__init__.py +++ b/__init__.py @@ -22,6 +22,7 @@ NODE_CONFIG = { "ConditioningSetMaskAndCombine5": {"class": ConditioningSetMaskAndCombine5, "name": "ConditioningSetMaskAndCombine5"}, "CondPassThrough": {"class": CondPassThrough}, #masking + "DrawMaskOnImage": {"class": DrawMaskOnImage, "name": "Draw Mask On Image"}, "DownloadAndLoadCLIPSeg": {"class": DownloadAndLoadCLIPSeg, "name": "(Down)load CLIPSeg"}, "BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"}, "ColorToMask": {"class": ColorToMask, "name": "Color To Mask"}, diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index c519ee0..9b5172d 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1502,3 +1502,61 @@ class ConsolidateMasksKJ: print(f"Consolidated {B} masks into {len(final_masks)}") return (torch.stack(final_masks, dim=0),) + +class DrawMaskOnImage: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE", ), + "mask": ("MASK", ), + "color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255 or 0.0-1.0, separated by commas."}), + } + } + + RETURN_TYPES = ("IMAGE", ) + RETURN_NAMES = ("images",) + FUNCTION = "apply" + CATEGORY = "KJNodes/image" + DESCRIPTION = "Applies the provided masks to the input images." + + def apply(self, image, mask, color): + B, H, W, C = image.shape + BM, HM, WM = mask.shape + + in_masks = mask.clone() + + if HM != H or WM != W: + in_masks = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1) + if B > BM: + in_masks = in_masks.repeat((B + BM - 1) // BM, 1, 1)[:B] + elif BM > B: + in_masks = in_masks[:B] + + output_images = [] + + # Parse background color - detect if values are integers or floats + bg_values = [] + for x in color.split(","): + val_str = x.strip() + if '.' in val_str: + bg_values.append(float(val_str)) + else: + bg_values.append(int(val_str) / 255.0) + + background_color = torch.tensor(bg_values, dtype=torch.float32, device=image.device) + + for i in range(B): + curr_mask = in_masks[i] + img_idx = min(i, B - 1) + curr_image = image[img_idx] + mask_expanded = curr_mask.unsqueeze(-1).expand(-1, -1, 3) + masked_image = curr_image * (1 - mask_expanded) + background_color * (mask_expanded) + output_images.append(masked_image) + + # If no masks were processed, return empty tensor + if not output_images: + return (torch.zeros((0, H, W, 3), dtype=image.dtype),) + + out_rgb = torch.stack(output_images, dim=0) + + return (out_rgb, ) \ No newline at end of file