FluxBlockLoraSelect update

This commit is contained in:
kijai 2024-08-31 16:01:44 +03:00
parent 0defb731ac
commit d85a3c92a8
2 changed files with 67 additions and 48 deletions

View File

@ -1830,21 +1830,23 @@ class FluxBlockLoraLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
arg_dict = { return {"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
"blocks": ("SELECTEDBLOCKS",),
} },
"optional": {
return {"required": arg_dict} "blocks": ("SELECTEDBLOCKS",),
}
}
RETURN_TYPES = ("MODEL", ) RETURN_TYPES = ("MODEL", )
OUTPUT_TOOLTIPS = ("The modified diffusion model.",) OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
FUNCTION = "load_lora" FUNCTION = "load_lora"
CATEGORY = "KJNodes/experimental" CATEGORY = "KJNodes/experimental"
def load_lora(self, model, lora_name, strength_model, blocks): def load_lora(self, model, lora_name, strength_model, blocks=None):
from comfy.utils import load_torch_file from comfy.utils import load_torch_file
import comfy.lora import comfy.lora
@ -1868,51 +1870,52 @@ class FluxBlockLoraLoader:
loaded = comfy.lora.load_lora(lora, key_map) loaded = comfy.lora.load_lora(lora, key_map)
keys_to_delete = [] if blocks is not None:
keys_to_delete = []
for block in blocks: for block in blocks:
for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change
match = False match = False
if isinstance(key, str) and block in key: if isinstance(key, str) and block in key:
match = True match = True
elif isinstance(key, tuple): elif isinstance(key, tuple):
for k in key: for k in key:
if block in k: if block in k:
match = True match = True
break break
if match: if match:
ratio = blocks[block] ratio = blocks[block]
if ratio == 0: if ratio == 0:
keys_to_delete.append(key) # Collect keys to delete keys_to_delete.append(key) # Collect keys to delete
else:
value = loaded[key]
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
# Handle the tuple format
if len(value[1]) > 3:
loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1]))
else:
loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1]))
else: else:
# Handle the simpler format directly value = loaded[key]
loaded[key] = (value[0], ratio) if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
# Handle the tuple format
if len(value[1]) > 3:
loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1]))
else:
loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1]))
else:
# Handle the simpler format directly
loaded[key] = (value[0], ratio)
# Now perform the deletion of keys # Now perform the deletion of keys
for key in keys_to_delete: for key in keys_to_delete:
del loaded[key] del loaded[key]
print("loading lora keys:") print("loading lora keys:")
for key, value in loaded.items(): for key, value in loaded.items():
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
# Handle the tuple format # Handle the tuple format
if len(value[1]) > 2: if len(value[1]) > 2:
alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple
else:
alpha = value[1][-2] # Adjust according to the second format's structure
else: else:
alpha = value[1][-2] # Adjust according to the second format's structure # Handle the simpler format directly
else: alpha = value[1] if len(value) > 1 else None
# Handle the simpler format directly print(f"Key: {key}, Alpha: {alpha}")
alpha = value[1] if len(value) > 1 else None
print(f"Key: {key}, Alpha: {alpha}")
if model is not None: if model is not None:

View File

@ -75,6 +75,22 @@ app.registerExtension({
}); });
} }
break; break;
case "FluxBlockLoraSelect":
nodeType.prototype.onNodeCreated = function () {
this.addWidget("button", "Set all", null, () => {
const userValue = parseFloat(prompt("Enter the value to set for all widgets:", "1.0"));
if (!isNaN(userValue)) {
const widgets = this.widgets;
for (const w of widgets) {
w.value = userValue;
}
} else {
alert("Invalid input. Please enter a numeric value.");
}
});
};
break;
case "GetMaskSizeAndCount": case "GetMaskSizeAndCount":
const onGetMaskSizeConnectInput = nodeType.prototype.onConnectInput; const onGetMaskSizeConnectInput = nodeType.prototype.onConnectInput;