From b95b79ee822cbe21bd30de6c89cb91bd11e393fd Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 28 May 2025 14:31:22 +0300 Subject: [PATCH] Fix lora block loading --- nodes/nodes.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/nodes/nodes.py b/nodes/nodes.py index 065813d..ddd90bb 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -2043,13 +2043,13 @@ class DiTBlockLoraLoader: if self.loaded_lora[0] == lora_path: lora = self.loaded_lora[1] else: - temp = self.loaded_lora self.loaded_lora = None - del temp if lora is None: lora = load_torch_file(lora_path, safe_load=True) - # Find the first key that ends with "weight" + self.loaded_lora = (lora_path, lora) + + # Find the first key that ends with "weight" rank = "unknown" weight_key = next((key for key in lora.keys() if key.endswith('weight')), None) # Print the shape of the value corresponding to the key @@ -2086,26 +2086,17 @@ class DiTBlockLoraLoader: if ratio == 0: keys_to_delete.append(key) else: - value = loaded[key] - if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): - inner_tuple = value[1] - if len(inner_tuple) >= 3: - inner_tuple = (inner_tuple[0], inner_tuple[1], ratio, *inner_tuple[3:]) - loaded[key] = (value[0], inner_tuple) - else: - loaded[key] = (value[0], ratio) + value = loaded[key].weights + weights_list = list(loaded[key].weights) + weights_list[2] = ratio + loaded[key].weights = tuple(weights_list) for key in keys_to_delete: del loaded[key] print("loading lora keys:") for key, value in loaded.items(): - if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): - inner_tuple = value[1] - alpha = inner_tuple[2] if len(inner_tuple) >= 3 else None - else: - alpha = value[1] if len(value) > 1 else None - print(f"Key: {key}, Alpha: {alpha}") + print(f"Key: {key}, Alpha: {value.weights[2]}") if model is not None: