Add MaskBatchMulti -node

This commit is contained in:
kijai 2024-04-25 15:34:16 +03:00
parent 9b07de5f55
commit b99d27311f
2 changed files with 56 additions and 2 deletions

View File

@ -1080,7 +1080,7 @@ class ImageBatchMulti:
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",) RETURN_NAMES = ("images",)
FUNCTION = "combine" FUNCTION = "combine"
CATEGORY = "KJNodes/masking/conditioning" CATEGORY = "KJNodes/image"
DESCRIPTION = """ DESCRIPTION = """
Creates an image batch from multiple images. Creates an image batch from multiple images.
You can set how many inputs the node has, You can set how many inputs the node has,
@ -1094,7 +1094,37 @@ with the **inputcount** and clicking update.
for c in range(1, inputcount): for c in range(1, inputcount):
new_image = kwargs[f"image_{c + 1}"] new_image = kwargs[f"image_{c + 1}"]
image, = image_batch_node.batch(new_image, image) image, = image_batch_node.batch(new_image, image)
return (image, inputcount,) return (image,)
class MaskBatchMulti:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
"mask_1": ("MASK", ),
"mask_2": ("MASK", ),
},
}
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("masks",)
FUNCTION = "combine"
CATEGORY = "KJNodes/masking"
DESCRIPTION = """
Creates an image batch from multiple masks.
You can set how many inputs the node has,
with the **inputcount** and clicking update.
"""
def combine(self, inputcount, **kwargs):
mask = kwargs["mask_1"]
for c in range(1, inputcount):
new_mask = kwargs[f"mask_{c + 1}"]
if mask.shape[1:] != new_mask.shape[1:]:
new_mask = F.interpolate(new_mask.unsqueeze(1), size=(mask.shape[1], mask.shape[2]), mode="bicubic").squeeze(1)
mask = torch.cat((mask, new_mask), dim=0)
return (mask,)
class CondPassThrough: class CondPassThrough:
@classmethod @classmethod
@ -5145,6 +5175,7 @@ NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant, "INTConstant": INTConstant,
"FloatConstant": FloatConstant, "FloatConstant": FloatConstant,
"ImageBatchMulti": ImageBatchMulti, "ImageBatchMulti": ImageBatchMulti,
"MaskBatchMulti": MaskBatchMulti,
"ConditioningMultiCombine": ConditioningMultiCombine, "ConditioningMultiCombine": ConditioningMultiCombine,
"ConditioningSetMaskAndCombine": ConditioningSetMaskAndCombine, "ConditioningSetMaskAndCombine": ConditioningSetMaskAndCombine,
"ConditioningSetMaskAndCombine3": ConditioningSetMaskAndCombine3, "ConditioningSetMaskAndCombine3": ConditioningSetMaskAndCombine3,
@ -5233,6 +5264,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant", "INTConstant": "INT Constant",
"FloatConstant": "Float Constant", "FloatConstant": "Float Constant",
"ImageBatchMulti": "Image Batch Multi", "ImageBatchMulti": "Image Batch Multi",
"MaskBatchMulti": "Mask Batch Multi",
"ConditioningMultiCombine": "Conditioning Multi Combine", "ConditioningMultiCombine": "Conditioning Multi Combine",
"ConditioningSetMaskAndCombine": "ConditioningSetMaskAndCombine", "ConditioningSetMaskAndCombine": "ConditioningSetMaskAndCombine",
"ConditioningSetMaskAndCombine3": "ConditioningSetMaskAndCombine3", "ConditioningSetMaskAndCombine3": "ConditioningSetMaskAndCombine3",

View File

@ -48,6 +48,28 @@ app.registerExtension({
}); });
} }
break; break;
case "MaskBatchMulti":
nodeType.prototype.onNodeCreated = function () {
this._type = "MASK"
this.inputs_offset = nodeData.name.includes("selective")?1:0
this.addWidget("button", "Update inputs", null, () => {
if (!this.inputs) {
this.inputs = [];
}
const target_number_of_inputs = this.widgets.find(w => w.name === "inputcount")["value"];
if(target_number_of_inputs===this.inputs.length)return; // already set, do nothing
if(target_number_of_inputs < this.inputs.length){
for(let i = this.inputs.length; i>=this.inputs_offset+target_number_of_inputs; i--)
this.removeInput(i)
}
else{
for(let i = this.inputs.length+1-this.inputs_offset; i <= target_number_of_inputs; ++i)
this.addInput(`mask_${i}`, this._type)
}
});
}
break;
case "SoundReactive": case "SoundReactive":
nodeType.prototype.onNodeCreated = function () { nodeType.prototype.onNodeCreated = function () {
let audioContext; let audioContext;