From b99d27311fe83d6c917f6f7248a4e6e82d8e6ce5 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 25 Apr 2024 15:34:16 +0300 Subject: [PATCH] Add MaskBatchMulti -node --- nodes.py | 36 ++++++++++++++++++++++++++++++++++-- web/js/jsnodes.js | 22 ++++++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index bf67711..49d1aec 100644 --- a/nodes.py +++ b/nodes.py @@ -1080,7 +1080,7 @@ class ImageBatchMulti: RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("images",) FUNCTION = "combine" - CATEGORY = "KJNodes/masking/conditioning" + CATEGORY = "KJNodes/image" DESCRIPTION = """ Creates an image batch from multiple images. You can set how many inputs the node has, @@ -1094,7 +1094,37 @@ with the **inputcount** and clicking update. for c in range(1, inputcount): new_image = kwargs[f"image_{c + 1}"] image, = image_batch_node.batch(new_image, image) - return (image, inputcount,) + return (image,) + +class MaskBatchMulti: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}), + "mask_1": ("MASK", ), + "mask_2": ("MASK", ), + }, + } + + RETURN_TYPES = ("MASK",) + RETURN_NAMES = ("masks",) + FUNCTION = "combine" + CATEGORY = "KJNodes/masking" + DESCRIPTION = """ +Creates an image batch from multiple masks. +You can set how many inputs the node has, +with the **inputcount** and clicking update. +""" + + def combine(self, inputcount, **kwargs): + mask = kwargs["mask_1"] + for c in range(1, inputcount): + new_mask = kwargs[f"mask_{c + 1}"] + if mask.shape[1:] != new_mask.shape[1:]: + new_mask = F.interpolate(new_mask.unsqueeze(1), size=(mask.shape[1], mask.shape[2]), mode="bicubic").squeeze(1) + mask = torch.cat((mask, new_mask), dim=0) + return (mask,) class CondPassThrough: @classmethod @@ -5145,6 +5175,7 @@ NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, "ImageBatchMulti": ImageBatchMulti, + "MaskBatchMulti": MaskBatchMulti, "ConditioningMultiCombine": ConditioningMultiCombine, "ConditioningSetMaskAndCombine": ConditioningSetMaskAndCombine, "ConditioningSetMaskAndCombine3": ConditioningSetMaskAndCombine3, @@ -5233,6 +5264,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", "FloatConstant": "Float Constant", "ImageBatchMulti": "Image Batch Multi", + "MaskBatchMulti": "Mask Batch Multi", "ConditioningMultiCombine": "Conditioning Multi Combine", "ConditioningSetMaskAndCombine": "ConditioningSetMaskAndCombine", "ConditioningSetMaskAndCombine3": "ConditioningSetMaskAndCombine3", diff --git a/web/js/jsnodes.js b/web/js/jsnodes.js index 49e1ec2..91ab055 100644 --- a/web/js/jsnodes.js +++ b/web/js/jsnodes.js @@ -48,6 +48,28 @@ app.registerExtension({ }); } break; + case "MaskBatchMulti": + nodeType.prototype.onNodeCreated = function () { + this._type = "MASK" + this.inputs_offset = nodeData.name.includes("selective")?1:0 + this.addWidget("button", "Update inputs", null, () => { + if (!this.inputs) { + this.inputs = []; + } + const target_number_of_inputs = this.widgets.find(w => w.name === "inputcount")["value"]; + if(target_number_of_inputs===this.inputs.length)return; // already set, do nothing + + if(target_number_of_inputs < this.inputs.length){ + for(let i = this.inputs.length; i>=this.inputs_offset+target_number_of_inputs; i--) + this.removeInput(i) + } + else{ + for(let i = this.inputs.length+1-this.inputs_offset; i <= target_number_of_inputs; ++i) + this.addInput(`mask_${i}`, this._type) + } + }); + } + break; case "SoundReactive": nodeType.prototype.onNodeCreated = function () { let audioContext;