diff --git a/__init__.py b/__init__.py index fb1923f..1df5bca 100644 --- a/__init__.py +++ b/__init__.py @@ -136,6 +136,7 @@ NODE_CONFIG = { "WebcamCaptureCV2": {"class": WebcamCaptureCV2, "name": "Webcam Capture CV2"}, "DifferentialDiffusionAdvanced": {"class": DifferentialDiffusionAdvanced, "name": "Differential Diffusion Advanced"}, "FluxBlockLoraLoader": {"class": FluxBlockLoraLoader, "name": "Flux Block Lora Loader"}, + "FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, diff --git a/nodes/nodes.py b/nodes/nodes.py index 6a10605..30d66af 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1796,6 +1796,34 @@ class DifferentialDiffusionAdvanced(): return (denoise_mask >= threshold).to(denoise_mask.dtype) +class FluxBlockLoraSelect: + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + arg_dict = {} + argument = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01}) + + for i in range(19): + arg_dict["double_blocks.{}.".format(i)] = argument + + for i in range(38): + arg_dict["single_blocks.{}.".format(i)] = argument + + return {"required": arg_dict} + + RETURN_TYPES = ("SELECTEDBLOCKS", ) + RETURN_NAMES = ("blocks", ) + OUTPUT_TOOLTIPS = ("The modified diffusion model.",) + FUNCTION = "load_lora" + + CATEGORY = "KJNodes/experimental" + DESCRIPTION = "Select individual block alpha values, value of 0 removes the block altogether" + + def load_lora(self, **kwargs): + return (kwargs,) + class FluxBlockLoraLoader: def __init__(self): self.loaded_lora = None @@ -1806,26 +1834,17 @@ class FluxBlockLoraLoader: "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + "blocks": ("SELECTEDBLOCKS",), } - argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}) - - for i in range(19): - arg_dict["double_blocks.{}.".format(i)] = argument - - for i in range(38): - arg_dict["single_blocks.{}.".format(i)] = argument - return {"required": arg_dict} RETURN_TYPES = ("MODEL", ) OUTPUT_TOOLTIPS = ("The modified diffusion model.",) FUNCTION = "load_lora" - CATEGORY = "KJNodes/experimental" - DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together." - def load_lora(self, model,lora_name, strength_model, **kwargs): + def load_lora(self, model, lora_name, strength_model, blocks): from comfy.utils import load_torch_file import comfy.lora @@ -1848,25 +1867,54 @@ class FluxBlockLoraLoader: key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) loaded = comfy.lora.load_lora(lora, key_map) - #filtered_dict = {k: v for k, v in loaded.items() if 'double_blocks.0' in k} - #print(filtered_dict) - for arg in kwargs: + keys_to_delete = [] + + for block in blocks: for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change - if arg in key: - ratio = kwargs[arg] + match = False + if isinstance(key, str) and block in key: + match = True + elif isinstance(key, tuple): + for k in key: + if block in k: + match = True + break + + if match: + ratio = blocks[block] if ratio == 0: - del loaded[key] # Remove the key if ratio is 0 + keys_to_delete.append(key) # Collect keys to delete else: value = loaded[key] - loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1])) + if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): + # Handle the tuple format + if len(value[1]) > 3: + loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1])) + else: + loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1])) + else: + # Handle the simpler format directly + loaded[key] = (value[0], ratio) + + # Now perform the deletion of keys + for key in keys_to_delete: + del loaded[key] + print("loading lora keys:") for key, value in loaded.items(): - if len(value) > 1 and len(value[1]) > 2: - alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple + if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): + # Handle the tuple format + if len(value[1]) > 2: + alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple + else: + alpha = value[1][-2] # Adjust according to the second format's structure else: - alpha = None + # Handle the simpler format directly + alpha = value[1] if len(value) > 1 else None print(f"Key: {key}, Alpha: {alpha}") + + if model is not None: new_modelpatcher = model.clone() k = new_modelpatcher.add_patches(loaded, strength_model)