Add MaskBatchMulti -node

This commit is contained in:
kijai 2024-04-25 15:34:16 +03:00
parent 9b07de5f55
commit b99d27311f
2 changed files with 56 additions and 2 deletions

View File

@ -1080,7 +1080,7 @@ class ImageBatchMulti:
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "combine"
CATEGORY = "KJNodes/masking/conditioning"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Creates an image batch from multiple images.
You can set how many inputs the node has,
@ -1094,7 +1094,37 @@ with the **inputcount** and clicking update.
for c in range(1, inputcount):
new_image = kwargs[f"image_{c + 1}"]
image, = image_batch_node.batch(new_image, image)
return (image, inputcount,)
return (image,)
class MaskBatchMulti:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
"mask_1": ("MASK", ),
"mask_2": ("MASK", ),
},
}
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("masks",)
FUNCTION = "combine"
CATEGORY = "KJNodes/masking"
DESCRIPTION = """
Creates an image batch from multiple masks.
You can set how many inputs the node has,
with the **inputcount** and clicking update.
"""
def combine(self, inputcount, **kwargs):
mask = kwargs["mask_1"]
for c in range(1, inputcount):
new_mask = kwargs[f"mask_{c + 1}"]
if mask.shape[1:] != new_mask.shape[1:]:
new_mask = F.interpolate(new_mask.unsqueeze(1), size=(mask.shape[1], mask.shape[2]), mode="bicubic").squeeze(1)
mask = torch.cat((mask, new_mask), dim=0)
return (mask,)
class CondPassThrough:
@classmethod
@ -5145,6 +5175,7 @@ NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
"FloatConstant": FloatConstant,
"ImageBatchMulti": ImageBatchMulti,
"MaskBatchMulti": MaskBatchMulti,
"ConditioningMultiCombine": ConditioningMultiCombine,
"ConditioningSetMaskAndCombine": ConditioningSetMaskAndCombine,
"ConditioningSetMaskAndCombine3": ConditioningSetMaskAndCombine3,
@ -5233,6 +5264,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
"FloatConstant": "Float Constant",
"ImageBatchMulti": "Image Batch Multi",
"MaskBatchMulti": "Mask Batch Multi",
"ConditioningMultiCombine": "Conditioning Multi Combine",
"ConditioningSetMaskAndCombine": "ConditioningSetMaskAndCombine",
"ConditioningSetMaskAndCombine3": "ConditioningSetMaskAndCombine3",

View File

@ -48,6 +48,28 @@ app.registerExtension({
});
}
break;
case "MaskBatchMulti":
nodeType.prototype.onNodeCreated = function () {
this._type = "MASK"
this.inputs_offset = nodeData.name.includes("selective")?1:0
this.addWidget("button", "Update inputs", null, () => {
if (!this.inputs) {
this.inputs = [];
}
const target_number_of_inputs = this.widgets.find(w => w.name === "inputcount")["value"];
if(target_number_of_inputs===this.inputs.length)return; // already set, do nothing
if(target_number_of_inputs < this.inputs.length){
for(let i = this.inputs.length; i>=this.inputs_offset+target_number_of_inputs; i--)
this.removeInput(i)
}
else{
for(let i = this.inputs.length+1-this.inputs_offset; i <= target_number_of_inputs; ++i)
this.addInput(`mask_${i}`, this._type)
}
});
}
break;
case "SoundReactive":
nodeType.prototype.onNodeCreated = function () {
let audioContext;