Add mask rounding node

This commit is contained in:
kijai 2023-11-10 01:40:12 +02:00
parent 74b54f41a5
commit af90b7b842

View File

@ -577,7 +577,7 @@ class GrowMaskWithBlur:
blurred = F.conv2d(padded_image, blurkernel, padding=blurkernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = F.conv2d(padded_image, blurkernel, padding=blurkernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
blurred = blurred.permute(0, 2, 3, 1) blurred = blurred.permute(0, 2, 3, 1)
blurred = blurred[:, :, :, 0] blurred = blurred[:, :, :, 0]
return (blurred, 1.0 - blurred,) return (blurred.cpu(), 1.0 - blurred.cpu(),)
return (torch.stack(out, dim=0), 1.0 -torch.stack(out, dim=0),) return (torch.stack(out, dim=0), 1.0 -torch.stack(out, dim=0),)
@ -1566,6 +1566,21 @@ class BatchCLIPSeg:
return results, return results,
class RoundMask:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"mask": ("MASK",),
}}
RETURN_TYPES = ("MASK",)
FUNCTION = "round"
CATEGORY = "KJNodes"
def round(self, mask):
mask = mask.round()
return (mask,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant, "INTConstant": INTConstant,
"FloatConstant": FloatConstant, "FloatConstant": FloatConstant,
@ -1597,6 +1612,7 @@ NODE_CLASS_MAPPINGS = {
"BatchCropFromMask": BatchCropFromMask, "BatchCropFromMask": BatchCropFromMask,
"BatchUncrop": BatchUncrop, "BatchUncrop": BatchUncrop,
"BatchCLIPSeg": BatchCLIPSeg, "BatchCLIPSeg": BatchCLIPSeg,
"RoundMask": RoundMask,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant", "INTConstant": "INT Constant",
@ -1628,4 +1644,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"BatchCropFromMask": "BatchCropFromMask", "BatchCropFromMask": "BatchCropFromMask",
"BatchUncrop": "BatchUncrop", "BatchUncrop": "BatchUncrop",
"BatchCLIPSeg": "BatchCLIPSeg", "BatchCLIPSeg": "BatchCLIPSeg",
"RoundMask": "RoundMask",
} }