mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-07 08:47:57 +08:00
Fix lora block loading
This commit is contained in:
parent
37eb7bddcb
commit
b95b79ee82
@ -2043,13 +2043,13 @@ class DiTBlockLoraLoader:
|
|||||||
if self.loaded_lora[0] == lora_path:
|
if self.loaded_lora[0] == lora_path:
|
||||||
lora = self.loaded_lora[1]
|
lora = self.loaded_lora[1]
|
||||||
else:
|
else:
|
||||||
temp = self.loaded_lora
|
|
||||||
self.loaded_lora = None
|
self.loaded_lora = None
|
||||||
del temp
|
|
||||||
|
|
||||||
if lora is None:
|
if lora is None:
|
||||||
lora = load_torch_file(lora_path, safe_load=True)
|
lora = load_torch_file(lora_path, safe_load=True)
|
||||||
# Find the first key that ends with "weight"
|
self.loaded_lora = (lora_path, lora)
|
||||||
|
|
||||||
|
# Find the first key that ends with "weight"
|
||||||
rank = "unknown"
|
rank = "unknown"
|
||||||
weight_key = next((key for key in lora.keys() if key.endswith('weight')), None)
|
weight_key = next((key for key in lora.keys() if key.endswith('weight')), None)
|
||||||
# Print the shape of the value corresponding to the key
|
# Print the shape of the value corresponding to the key
|
||||||
@ -2086,26 +2086,17 @@ class DiTBlockLoraLoader:
|
|||||||
if ratio == 0:
|
if ratio == 0:
|
||||||
keys_to_delete.append(key)
|
keys_to_delete.append(key)
|
||||||
else:
|
else:
|
||||||
value = loaded[key]
|
value = loaded[key].weights
|
||||||
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
|
weights_list = list(loaded[key].weights)
|
||||||
inner_tuple = value[1]
|
weights_list[2] = ratio
|
||||||
if len(inner_tuple) >= 3:
|
loaded[key].weights = tuple(weights_list)
|
||||||
inner_tuple = (inner_tuple[0], inner_tuple[1], ratio, *inner_tuple[3:])
|
|
||||||
loaded[key] = (value[0], inner_tuple)
|
|
||||||
else:
|
|
||||||
loaded[key] = (value[0], ratio)
|
|
||||||
|
|
||||||
for key in keys_to_delete:
|
for key in keys_to_delete:
|
||||||
del loaded[key]
|
del loaded[key]
|
||||||
|
|
||||||
print("loading lora keys:")
|
print("loading lora keys:")
|
||||||
for key, value in loaded.items():
|
for key, value in loaded.items():
|
||||||
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
|
print(f"Key: {key}, Alpha: {value.weights[2]}")
|
||||||
inner_tuple = value[1]
|
|
||||||
alpha = inner_tuple[2] if len(inner_tuple) >= 3 else None
|
|
||||||
else:
|
|
||||||
alpha = value[1] if len(value) > 1 else None
|
|
||||||
print(f"Key: {key}, Alpha: {alpha}")
|
|
||||||
|
|
||||||
|
|
||||||
if model is not None:
|
if model is not None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user