From 181f15312055a869fffc593b1445f77590645c7d Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 1 Sep 2024 20:37:18 +0300 Subject: [PATCH] Update nodes.py --- nodes/nodes.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/nodes/nodes.py b/nodes/nodes.py index 669a26b..9fff99f 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1842,8 +1842,9 @@ class FluxBlockLoraLoader: } } - RETURN_TYPES = ("MODEL", ) - OUTPUT_TOOLTIPS = ("The modified diffusion model.",) + RETURN_TYPES = ("MODEL", "STRING", ) + RETURN_NAMES = ("model", "rank", ) + OUTPUT_TOOLTIPS = ("The modified diffusion model.", "possible rank of the LoRA.") FUNCTION = "load_lora" CATEGORY = "KJNodes/experimental" @@ -1867,6 +1868,15 @@ class FluxBlockLoraLoader: if lora is None: lora = load_torch_file(lora_path, safe_load=True) + # Find the first key that ends with "weight" + weight_key = next((key for key in lora.keys() if key.endswith('weight')), None) + # Print the shape of the value corresponding to the key + if weight_key: + print(f"Shape of the first 'weight' key ({weight_key}): {lora[weight_key].shape}") + rank = str(lora[weight_key].shape[0]) + else: + print("No key ending with 'weight' found.") + rank = "Couldn't find rank" self.loaded_lora = (lora_path, lora) key_map = {} @@ -1932,4 +1942,4 @@ class FluxBlockLoraLoader: if (x not in k): print("NOT LOADED {}".format(x)) - return (new_modelpatcher,) \ No newline at end of file + return (new_modelpatcher, rank) \ No newline at end of file