improve Flux lora block select

This commit is contained in:
kijai 2024-08-31 15:10:18 +03:00
parent bdb65e5635
commit 0defb731ac
2 changed files with 70 additions and 21 deletions

View File

@ -136,6 +136,7 @@ NODE_CONFIG = {
"WebcamCaptureCV2": {"class": WebcamCaptureCV2, "name": "Webcam Capture CV2"},
"DifferentialDiffusionAdvanced": {"class": DifferentialDiffusionAdvanced, "name": "Differential Diffusion Advanced"},
"FluxBlockLoraLoader": {"class": FluxBlockLoraLoader, "name": "Flux Block Lora Loader"},
"FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@ -1796,6 +1796,34 @@ class DifferentialDiffusionAdvanced():
return (denoise_mask >= threshold).to(denoise_mask.dtype)
class FluxBlockLoraSelect:
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
arg_dict = {}
argument = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01})
for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument
for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument
return {"required": arg_dict}
RETURN_TYPES = ("SELECTEDBLOCKS", )
RETURN_NAMES = ("blocks", )
OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
FUNCTION = "load_lora"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = "Select individual block alpha values, value of 0 removes the block altogether"
def load_lora(self, **kwargs):
return (kwargs,)
class FluxBlockLoraLoader:
def __init__(self):
self.loaded_lora = None
@ -1806,26 +1834,17 @@ class FluxBlockLoraLoader:
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
"blocks": ("SELECTEDBLOCKS",),
}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01})
for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument
for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument
return {"required": arg_dict}
RETURN_TYPES = ("MODEL", )
OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
FUNCTION = "load_lora"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."
def load_lora(self, model,lora_name, strength_model, **kwargs):
def load_lora(self, model, lora_name, strength_model, blocks):
from comfy.utils import load_torch_file
import comfy.lora
@ -1848,25 +1867,54 @@ class FluxBlockLoraLoader:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
loaded = comfy.lora.load_lora(lora, key_map)
#filtered_dict = {k: v for k, v in loaded.items() if 'double_blocks.0' in k}
#print(filtered_dict)
for arg in kwargs:
keys_to_delete = []
for block in blocks:
for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change
if arg in key:
ratio = kwargs[arg]
match = False
if isinstance(key, str) and block in key:
match = True
elif isinstance(key, tuple):
for k in key:
if block in k:
match = True
break
if match:
ratio = blocks[block]
if ratio == 0:
del loaded[key] # Remove the key if ratio is 0
keys_to_delete.append(key) # Collect keys to delete
else:
value = loaded[key]
loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1]))
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
# Handle the tuple format
if len(value[1]) > 3:
loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1]))
else:
loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1]))
else:
# Handle the simpler format directly
loaded[key] = (value[0], ratio)
# Now perform the deletion of keys
for key in keys_to_delete:
del loaded[key]
print("loading lora keys:")
for key, value in loaded.items():
if len(value) > 1 and len(value[1]) > 2:
alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
# Handle the tuple format
if len(value[1]) > 2:
alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple
else:
alpha = value[1][-2] # Adjust according to the second format's structure
else:
alpha = None
# Handle the simpler format directly
alpha = value[1] if len(value) > 1 else None
print(f"Key: {key}, Alpha: {alpha}")
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)