Add GetLatentSizeAndCount

This commit is contained in:
kijai 2025-08-20 12:36:50 +03:00
parent 876a6dd292
commit e435e999e4
3 changed files with 60 additions and 0 deletions

View File

@ -51,6 +51,7 @@ NODE_CONFIG = {
"GetImagesFromBatchIndexed": {"class": GetImagesFromBatchIndexed, "name": "Get Images From Batch Indexed"},
"GetImageRangeFromBatch": {"class": GetImageRangeFromBatch, "name": "Get Image or Mask Range From Batch"},
"GetLatentRangeFromBatch": {"class": GetLatentRangeFromBatch, "name": "Get Latent Range From Batch"},
"GetLatentSizeAndCount": {"class": GetLatentSizeAndCount, "name": "Get Latent Size & Count"},
"GetImageSizeAndCount": {"class": GetImageSizeAndCount, "name": "Get Image Size & Count"},
"FastPreview": {"class": FastPreview, "name": "Fast Preview"},
"ImageBatchFilter": {"class": ImageBatchFilter, "name": "Image Batch Filter"},

View File

@ -1,3 +1,5 @@
from itertools import count
from turtle import width
import numpy as np
import time
import torch
@ -795,6 +797,36 @@ and passes it through unchanged.
"text": [f"{count}x{width}x{height}"]},
"result": (image, width, height, count)
}
class GetLatentSizeAndCount:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"latent": ("LATENT",),
}}
RETURN_TYPES = ("LATENT","INT", "INT", "INT", "INT", "INT")
RETURN_NAMES = ("latent", "batch_size", "channels", "frames", "width", "height")
FUNCTION = "getsize"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Returns latent tensor dimensions,
and passes the latent through unchanged.
"""
def getsize(self, latent):
if len(latent["samples"].shape) == 5:
B, C, T, H, W = latent["samples"].shape
elif len(latent["samples"].shape) == 4:
B, C, H, W = latent["samples"].shape
T = 0
else:
raise ValueError("Invalid latent shape")
return {"ui": {
"text": [f"{B}x{C}x{T}x{H}x{W}"]},
"result": (latent, B, C, T, H, W)
}
class ImageBatchRepeatInterleaving:

View File

@ -191,6 +191,33 @@ app.registerExtension({
}
break;
case "GetLatentSizeAndCount":
const onGetLatentConnectInput = nodeType.prototype.onConnectInput;
nodeType.prototype.onConnectInput = function (targetSlot, type, output, originNode, originSlot) {
console.log(this)
const v = onGetLatentConnectInput? onGetLatentConnectInput.apply(this, arguments): undefined
//console.log(this)
this.outputs[1]["label"] = "width"
this.outputs[2]["label"] = "height"
this.outputs[3]["label"] = "count"
return v;
}
//const onGetImageSizeExecuted = nodeType.prototype.onExecuted;
const onGetLatentSizeExecuted = nodeType.prototype.onAfterExecuteNode;
nodeType.prototype.onExecuted = function(message) {
console.log(this)
const r = onGetLatentSizeExecuted? onGetLatentSizeExecuted.apply(this,arguments): undefined
let values = message["text"].toString().split('x').map(Number);
console.log(values)
this.outputs[1]["label"] = values[0] + " batch"
this.outputs[2]["label"] = values[1] + " channels"
this.outputs[3]["label"] = values[2] + " frames"
this.outputs[4]["label"] = values[3] + " height"
this.outputs[5]["label"] = values[4] + " width"
return r
}
break;
case "PreviewAnimation":
const onPreviewAnimationConnectInput = nodeType.prototype.onConnectInput;
nodeType.prototype.onConnectInput = function (targetSlot, type, output, originNode, originSlot) {