diff --git a/nodes/nodes.py b/nodes/nodes.py index 30d66af..38a7851 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1830,21 +1830,23 @@ class FluxBlockLoraLoader: @classmethod def INPUT_TYPES(s): - arg_dict = { - "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), - "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), - "blocks": ("SELECTEDBLOCKS",), - } - - return {"required": arg_dict} + return {"required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + + }, + "optional": { + "blocks": ("SELECTEDBLOCKS",), + } + } RETURN_TYPES = ("MODEL", ) OUTPUT_TOOLTIPS = ("The modified diffusion model.",) FUNCTION = "load_lora" CATEGORY = "KJNodes/experimental" - def load_lora(self, model, lora_name, strength_model, blocks): + def load_lora(self, model, lora_name, strength_model, blocks=None): from comfy.utils import load_torch_file import comfy.lora @@ -1868,51 +1870,52 @@ class FluxBlockLoraLoader: loaded = comfy.lora.load_lora(lora, key_map) - keys_to_delete = [] + if blocks is not None: + keys_to_delete = [] - for block in blocks: - for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change - match = False - if isinstance(key, str) and block in key: - match = True - elif isinstance(key, tuple): - for k in key: - if block in k: - match = True - break + for block in blocks: + for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change + match = False + if isinstance(key, str) and block in key: + match = True + elif isinstance(key, tuple): + for k in key: + if block in k: + match = True + break - if match: - ratio = blocks[block] - if ratio == 0: - keys_to_delete.append(key) # Collect keys to delete - else: - value = loaded[key] - if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): - # Handle the tuple format - if len(value[1]) > 3: - loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1])) - else: - loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1])) + if match: + ratio = blocks[block] + if ratio == 0: + keys_to_delete.append(key) # Collect keys to delete else: - # Handle the simpler format directly - loaded[key] = (value[0], ratio) + value = loaded[key] + if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): + # Handle the tuple format + if len(value[1]) > 3: + loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1])) + else: + loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1])) + else: + # Handle the simpler format directly + loaded[key] = (value[0], ratio) - # Now perform the deletion of keys - for key in keys_to_delete: - del loaded[key] + # Now perform the deletion of keys + for key in keys_to_delete: + del loaded[key] - print("loading lora keys:") - for key, value in loaded.items(): - if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): - # Handle the tuple format - if len(value[1]) > 2: - alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple + print("loading lora keys:") + for key, value in loaded.items(): + if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): + # Handle the tuple format + if len(value[1]) > 2: + alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple + else: + alpha = value[1][-2] # Adjust according to the second format's structure else: - alpha = value[1][-2] # Adjust according to the second format's structure - else: - # Handle the simpler format directly - alpha = value[1] if len(value) > 1 else None - print(f"Key: {key}, Alpha: {alpha}") + # Handle the simpler format directly + alpha = value[1] if len(value) > 1 else None + print(f"Key: {key}, Alpha: {alpha}") if model is not None: diff --git a/web/js/jsnodes.js b/web/js/jsnodes.js index 68df804..f0ea6f7 100644 --- a/web/js/jsnodes.js +++ b/web/js/jsnodes.js @@ -75,6 +75,22 @@ app.registerExtension({ }); } break; + + case "FluxBlockLoraSelect": + nodeType.prototype.onNodeCreated = function () { + this.addWidget("button", "Set all", null, () => { + const userValue = parseFloat(prompt("Enter the value to set for all widgets:", "1.0")); + if (!isNaN(userValue)) { + const widgets = this.widgets; + for (const w of widgets) { + w.value = userValue; + } + } else { + alert("Invalid input. Please enter a numeric value."); + } + }); + }; + break; case "GetMaskSizeAndCount": const onGetMaskSizeConnectInput = nodeType.prototype.onConnectInput;