diff --git a/__init__.py b/__init__.py index f338853..0c9c7f1 100644 --- a/__init__.py +++ b/__init__.py @@ -51,6 +51,7 @@ NODE_CONFIG = { "GetImagesFromBatchIndexed": {"class": GetImagesFromBatchIndexed, "name": "Get Images From Batch Indexed"}, "GetImageRangeFromBatch": {"class": GetImageRangeFromBatch, "name": "Get Image or Mask Range From Batch"}, "GetLatentRangeFromBatch": {"class": GetLatentRangeFromBatch, "name": "Get Latent Range From Batch"}, + "GetLatentSizeAndCount": {"class": GetLatentSizeAndCount, "name": "Get Latent Size & Count"}, "GetImageSizeAndCount": {"class": GetImageSizeAndCount, "name": "Get Image Size & Count"}, "FastPreview": {"class": FastPreview, "name": "Fast Preview"}, "ImageBatchFilter": {"class": ImageBatchFilter, "name": "Image Batch Filter"}, diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index fa70f33..72b128a 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -1,3 +1,5 @@ +from itertools import count +from turtle import width import numpy as np import time import torch @@ -795,6 +797,36 @@ and passes it through unchanged. "text": [f"{count}x{width}x{height}"]}, "result": (image, width, height, count) } + +class GetLatentSizeAndCount: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "latent": ("LATENT",), + }} + + RETURN_TYPES = ("LATENT","INT", "INT", "INT", "INT", "INT") + RETURN_NAMES = ("latent", "batch_size", "channels", "frames", "width", "height") + FUNCTION = "getsize" + CATEGORY = "KJNodes/image" + DESCRIPTION = """ +Returns latent tensor dimensions, +and passes the latent through unchanged. + +""" + def getsize(self, latent): + if len(latent["samples"].shape) == 5: + B, C, T, H, W = latent["samples"].shape + elif len(latent["samples"].shape) == 4: + B, C, H, W = latent["samples"].shape + T = 0 + else: + raise ValueError("Invalid latent shape") + + return {"ui": { + "text": [f"{B}x{C}x{T}x{H}x{W}"]}, + "result": (latent, B, C, T, H, W) + } class ImageBatchRepeatInterleaving: diff --git a/web/js/jsnodes.js b/web/js/jsnodes.js index ff820b8..2e67693 100644 --- a/web/js/jsnodes.js +++ b/web/js/jsnodes.js @@ -191,6 +191,33 @@ app.registerExtension({ } break; + case "GetLatentSizeAndCount": + const onGetLatentConnectInput = nodeType.prototype.onConnectInput; + nodeType.prototype.onConnectInput = function (targetSlot, type, output, originNode, originSlot) { + console.log(this) + const v = onGetLatentConnectInput? onGetLatentConnectInput.apply(this, arguments): undefined + //console.log(this) + this.outputs[1]["label"] = "width" + this.outputs[2]["label"] = "height" + this.outputs[3]["label"] = "count" + return v; + } + //const onGetImageSizeExecuted = nodeType.prototype.onExecuted; + const onGetLatentSizeExecuted = nodeType.prototype.onAfterExecuteNode; + nodeType.prototype.onExecuted = function(message) { + console.log(this) + const r = onGetLatentSizeExecuted? onGetLatentSizeExecuted.apply(this,arguments): undefined + let values = message["text"].toString().split('x').map(Number); + console.log(values) + this.outputs[1]["label"] = values[0] + " batch" + this.outputs[2]["label"] = values[1] + " channels" + this.outputs[3]["label"] = values[2] + " frames" + this.outputs[4]["label"] = values[3] + " height" + this.outputs[5]["label"] = values[4] + " width" + return r + } + break; + case "PreviewAnimation": const onPreviewAnimationConnectInput = nodeType.prototype.onConnectInput; nodeType.prototype.onConnectInput = function (targetSlot, type, output, originNode, originSlot) {