From b9abf5df31d37c713f21073db738f1e46230a618 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 3 Oct 2025 19:01:46 +0300 Subject: [PATCH] optimize blockify mask node --- nodes/mask_nodes.py | 87 +++++++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/nodes/mask_nodes.py b/nodes/mask_nodes.py index 5397e36..8ca69d9 100644 --- a/nodes/mask_nodes.py +++ b/nodes/mask_nodes.py @@ -6,6 +6,7 @@ import scipy.ndimage import numpy as np from contextlib import nullcontext import os +from tqdm import tqdm from comfy import model_management from comfy.utils import ProgressBar @@ -17,7 +18,7 @@ import folder_paths from ..utility.utility import tensor2pil, pil2tensor script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -device = model_management.get_torch_device() +main_device = model_management.get_torch_device() offload_device = model_management.unet_offload_device() class BatchCLIPSeg: @@ -1013,7 +1014,7 @@ class GrowMaskWithBlur: previous_output = None current_expand = expand for m in growmask: - output = m.unsqueeze(0).unsqueeze(0).to(device) # Add batch and channel dims for kornia + output = m.unsqueeze(0).unsqueeze(0).to(main_device) # Add batch and channel dims for kornia if abs(round(current_expand)) > 0: # Create kernel - kornia expects kernel on same device as input if tapered_corners: @@ -1586,8 +1587,11 @@ class BlockifyMask: return {"required": { "masks": ("MASK",), "block_size": ("INT", {"default": 32, "min": 8, "max": 512, "step": 1, "tooltip": "Size of blocks in pixels (smaller = smaller blocks)"}), + }, + "optional": { + "device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}), } - } + } RETURN_TYPES = ("MASK", ) RETURN_NAMES = ("mask",) @@ -1595,49 +1599,70 @@ class BlockifyMask: CATEGORY = "KJNodes/masking" DESCRIPTION = "Creates a block mask by dividing the bounding box of each mask into blocks of the specified size and filling in blocks that contain any part of the original mask." - def process(self, masks, block_size): - batch_size = masks.shape[0] - result_masks = [] + def process(self, masks, block_size, device="cpu"): + processing_device = main_device if device == "gpu" else torch.device("cpu") + print("mask.dtype:", masks.dtype, "device:", processing_device) - for i in range(batch_size): + masks = masks.to(processing_device) + batch_size, height, width = masks.shape + + result_masks = torch.zeros_like(masks) + + for i in tqdm(range(batch_size), desc="BlockifyMask batch"): mask = masks[i] - # Find bounding box using tensor operations - nonzero_coords = torch.nonzero(mask, as_tuple=True) - if len(nonzero_coords[0]) == 0: # Empty mask - result_masks.append(mask) + # Find bounding box efficiently + mask_bool = mask > 0 + if not mask_bool.any(): continue - y_coords, x_coords = nonzero_coords - y_min, y_max = y_coords.min(), y_coords.max() - x_min, x_max = x_coords.min(), x_coords.max() + y_indices = torch.nonzero(mask_bool.any(dim=1), as_tuple=True)[0] + x_indices = torch.nonzero(mask_bool.any(dim=0), as_tuple=True)[0] + + if len(y_indices) == 0 or len(x_indices) == 0: + continue + + y_min, y_max = y_indices[0], y_indices[-1] + x_min, x_max = x_indices[0], x_indices[-1] bbox_width = x_max - x_min + 1 bbox_height = y_max - y_min + 1 - # Calculate number of blocks that fit + # Calculate block grid w_divisions = max(1, bbox_width // block_size) h_divisions = max(1, bbox_height // block_size) - # Calculate actual block sizes (might be slightly larger than block_size) w_slice = bbox_width // w_divisions h_slice = bbox_height // h_divisions - # Create output mask (copy of input) - output_mask = mask.clone() + # Create coordinate grids only for bbox region + y_coords = torch.arange(y_min, y_max + 1, device=processing_device).view(-1, 1) + x_coords = torch.arange(x_min, x_max + 1, device=processing_device).view(1, -1) - # Process grid cells - for w_start in range(x_min, x_max + 1, w_slice): - w_end = min(w_start + w_slice, x_max + 1) - for h_start in range(y_min, y_max + 1, h_slice): - h_end = min(h_start + h_slice, y_max + 1) - - # Check if this cell contains any mask content - cell_region = mask[h_start:h_end, w_start:w_end] - if cell_region.sum() > 0: - # Fill the entire cell - output_mask[h_start:h_end, w_start:w_end] = 1.0 + # Calculate block indices for bbox region + w_block_indices = (x_coords - x_min) // w_slice + h_block_indices = (y_coords - y_min) // h_slice - result_masks.append(output_mask) + # Clamp to valid range + w_block_indices = w_block_indices.clamp(0, w_divisions - 1) + h_block_indices = h_block_indices.clamp(0, h_divisions - 1) + + # Create unique block IDs by combining h and w indices + block_ids = h_block_indices * w_divisions + w_block_indices + + # Get mask region within bbox + mask_region = mask[y_min:y_max+1, x_min:x_max+1] + + # Find which blocks have content using scatter_add + max_blocks = h_divisions * w_divisions + block_content = torch.zeros(max_blocks, device=processing_device) + block_content.scatter_add_(0, block_ids.flatten(), mask_region.flatten()) + + # Create result for blocks that have content + has_content = block_content > 0 + block_mask = has_content[block_ids] + + # Fill the result + result_masks[i, y_min:y_max+1, x_min:x_max+1] = block_mask.float() - return torch.stack(result_masks, dim=0), \ No newline at end of file + return (result_masks.clamp(0, 1),) \ No newline at end of file