diff --git a/nodes.py b/nodes.py index d6e43e5..a62063e 100644 --- a/nodes.py +++ b/nodes.py @@ -330,6 +330,7 @@ class GrowMaskWithBlur: "incremental_expandrate": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1}), "tapered_corners": ("BOOLEAN", {"default": True}), "flip_input": ("BOOLEAN", {"default": False}), + "use_cuda": ("BOOLEAN", {"default": True}), "blur_radius": ("INT", { "default": 0, "min": 0, @@ -351,7 +352,7 @@ class GrowMaskWithBlur: RETURN_NAMES = ("mask", "mask_inverted",) FUNCTION = "expand_mask" - def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate): + def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, sigma, incremental_expandrate, use_cuda): if( flip_input ): mask = 1.0 - mask c = 0 if tapered_corners else 1 @@ -373,8 +374,12 @@ class GrowMaskWithBlur: expand += abs(incremental_expandrate) # Use abs(growrate) to ensure positive change output = torch.from_numpy(output) out.append(output) - + blurred = torch.stack(out, dim=0).reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + if use_cuda: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + blurred = blurred.to(device) # Move blurred tensor to the GPU + batch_size, height, width, channels = blurred.shape if blur_radius != 0: blurkernel_size = blur_radius * 2 + 1 @@ -642,7 +647,7 @@ class VRAM_Debug: return { "required": { "model": ("MODEL",), - + "empty_cuda_cache": ("BOOLEAN", {"default": False}), }, "optional": { "clip_vision": ("CLIP_VISION", ), @@ -653,11 +658,12 @@ class VRAM_Debug: FUNCTION = "VRAMdebug" CATEGORY = "KJNodes" - def VRAMdebug(self, model, clip_vision=None): + def VRAMdebug(self, model, empty_cuda_cache, clip_vision=None): freemem_before = comfy.model_management.get_free_memory() print(freemem_before) - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + if empty_cuda_cache: + torch.cuda.empty_cache() + torch.cuda.ipc_collect() if clip_vision is not None: print("unloading clip_vision_clone") comfy.model_management.unload_model_clones(clip_vision.patcher)