From af90b7b842a136fcb135ffe21205add60fe3f611 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 10 Nov 2023 01:40:12 +0200 Subject: [PATCH] Add mask rounding node --- nodes.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index db6af0b..a5e8e35 100644 --- a/nodes.py +++ b/nodes.py @@ -577,7 +577,7 @@ class GrowMaskWithBlur: blurred = F.conv2d(padded_image, blurkernel, padding=blurkernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) blurred = blurred[:, :, :, 0] - return (blurred, 1.0 - blurred,) + return (blurred.cpu(), 1.0 - blurred.cpu(),) return (torch.stack(out, dim=0), 1.0 -torch.stack(out, dim=0),) @@ -1566,6 +1566,21 @@ class BatchCLIPSeg: return results, +class RoundMask: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "mask": ("MASK",), + }} + + RETURN_TYPES = ("MASK",) + FUNCTION = "round" + CATEGORY = "KJNodes" + + def round(self, mask): + mask = mask.round() + return (mask,) + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -1597,6 +1612,7 @@ NODE_CLASS_MAPPINGS = { "BatchCropFromMask": BatchCropFromMask, "BatchUncrop": BatchUncrop, "BatchCLIPSeg": BatchCLIPSeg, + "RoundMask": RoundMask, } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -1628,4 +1644,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "BatchCropFromMask": "BatchCropFromMask", "BatchUncrop": "BatchUncrop", "BatchCLIPSeg": "BatchCLIPSeg", + "RoundMask": "RoundMask", } \ No newline at end of file