From d1f6cf77b0f149bd86f3c6f14b669b9f4cf30abd Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 3 Oct 2025 19:06:09 +0300 Subject: [PATCH] optimize draw mask on image --- nodes/mask_nodes.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 8ca69d9..72d7d45 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -1529,8 +1529,11 @@ class DrawMaskOnImage: "image": ("IMAGE", ), "mask": ("MASK", ), "color": ("STRING", {"default": "0, 0, 0", "tooltip": "Color as RGB values in range 0-255 or 0.0-1.0, separated by commas."}), - } + }, + "optional": { + "device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}), } + } RETURN_TYPES = ("IMAGE", ) RETURN_NAMES = ("images",) @@ -1538,11 +1541,14 @@ class DrawMaskOnImage: CATEGORY = "KJNodes/masking" DESCRIPTION = "Applies the provided masks to the input images." - def apply(self, image, mask, color): + def apply(self, image, mask, color, device="cpu"): B, H, W, C = image.shape BM, HM, WM = mask.shape - in_masks = mask.clone() + processing_device = main_device if device == "gpu" else torch.device("cpu") + + in_masks = mask.clone().to(processing_device) + in_images = image.clone().to(processing_device) if HM != H or WM != W: in_masks = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1) @@ -1562,12 +1568,12 @@ class DrawMaskOnImage: else: bg_values.append(int(val_str) / 255.0) - background_color = torch.tensor(bg_values, dtype=torch.float32, device=image.device) + background_color = torch.tensor(bg_values, dtype=torch.float32, device=in_images.device) - for i in range(B): + for i in tqdm(range(B), desc="DrawMaskOnImage batch"): curr_mask = in_masks[i] img_idx = min(i, B - 1) - curr_image = image[img_idx] + curr_image = in_images[img_idx] mask_expanded = curr_mask.unsqueeze(-1).expand(-1, -1, 3) masked_image = curr_image * (1 - mask_expanded) + background_color * (mask_expanded) output_images.append(masked_image) @@ -1576,7 +1582,7 @@ class DrawMaskOnImage: if not output_images: return (torch.zeros((0, H, W, 3), dtype=image.dtype),) - out_rgb = torch.stack(output_images, dim=0) + out_rgb = torch.stack(output_images, dim=0).cpu() return (out_rgb, )