Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git (synced 2025-12-09 04:44:30 +08:00)
improve Flux lora block select
This commit is contained in:
parent bdb65e5635
commit 0defb731ac
@@ -136,6 +136,7 @@ NODE_CONFIG = {
     "WebcamCaptureCV2": {"class": WebcamCaptureCV2, "name": "Webcam Capture CV2"},
     "DifferentialDiffusionAdvanced": {"class": DifferentialDiffusionAdvanced, "name": "Differential Diffusion Advanced"},
     "FluxBlockLoraLoader": {"class": FluxBlockLoraLoader, "name": "Flux Block Lora Loader"},
+    "FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
@@ -1796,6 +1796,34 @@ class DifferentialDiffusionAdvanced():

         return (denoise_mask >= threshold).to(denoise_mask.dtype)

+class FluxBlockLoraSelect:
+    def __init__(self):
+        self.loaded_lora = None
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = {}
+        argument = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01})
+
+        for i in range(19):
+            arg_dict["double_blocks.{}.".format(i)] = argument
+
+        for i in range(38):
+            arg_dict["single_blocks.{}.".format(i)] = argument
+
+        return {"required": arg_dict}
+
+    RETURN_TYPES = ("SELECTEDBLOCKS", )
+    RETURN_NAMES = ("blocks", )
+    OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
+    FUNCTION = "load_lora"
+
+    CATEGORY = "KJNodes/experimental"
+    DESCRIPTION = "Select individual block alpha values, value of 0 removes the block altogether"
+
+    def load_lora(self, **kwargs):
+        return (kwargs,)
+
 class FluxBlockLoraLoader:
     def __init__(self):
         self.loaded_lora = None
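For orientation, a minimal sketch (illustrative values, not taken from the commit) of the SELECTEDBLOCKS object this new node emits: load_lora simply returns its keyword arguments, so downstream nodes receive a plain dict mapping block-name prefixes to float weights.

    # Shape of the "blocks" value produced by FluxBlockLoraSelect (illustrative values)
    blocks = {
        "double_blocks.0.": 1.0,    # keep this block's LoRA keys at full alpha
        "double_blocks.1.": 0.0,    # 0 removes the block's LoRA keys altogether
        "single_blocks.37.": 0.5,   # scale this block's alpha down
    }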
@@ -1806,26 +1834,17 @@ class FluxBlockLoraLoader:
            "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
            "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
            "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
+           "blocks": ("SELECTEDBLOCKS",),
            }
-        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01})
-
-        for i in range(19):
-            arg_dict["double_blocks.{}.".format(i)] = argument
-
-        for i in range(38):
-            arg_dict["single_blocks.{}.".format(i)] = argument
-
         return {"required": arg_dict}

     RETURN_TYPES = ("MODEL", )
     OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
     FUNCTION = "load_lora"

     CATEGORY = "KJNodes/experimental"
     DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."

-    def load_lora(self, model,lora_name, strength_model, **kwargs):
+    def load_lora(self, model, lora_name, strength_model, blocks):
         from comfy.utils import load_torch_file
         import comfy.lora
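To make the data flow concrete, a rough sketch of chaining the two nodes by hand, assuming the classes above are importable and a Flux ModelPatcher has already been loaded elsewhere; the LoRA filename is a placeholder:

    # Building the block selection directly (needs only the FluxBlockLoraSelect class above)
    select = FluxBlockLoraSelect()
    (blocks,) = select.load_lora(**{"double_blocks.0.": 1.0, "single_blocks.5.": 0.0})
    print(blocks)  # {'double_blocks.0.': 1.0, 'single_blocks.5.': 0.0}

    # In a workflow the loader then consumes it; "model" is a Flux ModelPatcher, the filename is a placeholder:
    # (patched_model,) = FluxBlockLoraLoader().load_lora(model, "flux_lora_example.safetensors", 1.0, blocks)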
@@ -1848,25 +1867,54 @@ class FluxBlockLoraLoader:
         key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)

         loaded = comfy.lora.load_lora(lora, key_map)
-        #filtered_dict = {k: v for k, v in loaded.items() if 'double_blocks.0' in k}
-        #print(filtered_dict)
-        for arg in kwargs:
+        keys_to_delete = []
+
+        for block in blocks:
             for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change
-                if arg in key:
-                    ratio = kwargs[arg]
+                match = False
+                if isinstance(key, str) and block in key:
+                    match = True
+                elif isinstance(key, tuple):
+                    for k in key:
+                        if block in k:
+                            match = True
+                            break
+
+                if match:
+                    ratio = blocks[block]
                     if ratio == 0:
-                        del loaded[key] # Remove the key if ratio is 0
+                        keys_to_delete.append(key) # Collect keys to delete
                     else:
                         value = loaded[key]
-                        loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1]))
+                        if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
+                            # Handle the tuple format
+                            if len(value[1]) > 3:
+                                loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1]))
+                            else:
+                                loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1]))
+                        else:
+                            # Handle the simpler format directly
+                            loaded[key] = (value[0], ratio)
+
+        # Now perform the deletion of keys
+        for key in keys_to_delete:
+            del loaded[key]

         print("loading lora keys:")
         for key, value in loaded.items():
-            if len(value) > 1 and len(value[1]) > 2:
-                alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple
+            if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
+                # Handle the tuple format
+                if len(value[1]) > 2:
+                    alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple
+                else:
+                    alpha = value[1][-2] # Adjust according to the second format's structure
             else:
-                alpha = None
+                # Handle the simpler format directly
+                alpha = value[1] if len(value) > 1 else None
             print(f"Key: {key}, Alpha: {alpha}")


         if model is not None:
             new_modelpatcher = model.clone()
             k = new_modelpatcher.add_patches(loaded, strength_model)
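The heart of the new filtering step is a prefix match against LoRA patch keys, which may arrive as plain strings or as tuples of strings. A standalone sketch of that rule (the helper name and sample keys are hypothetical):

    # Standalone sketch of the block-matching rule used above (hypothetical helper and sample keys)
    def block_ratio(key, blocks):
        """Return the selected ratio for a patch key, or None if no block prefix matches."""
        names = key if isinstance(key, tuple) else (key,)
        for block, ratio in blocks.items():
            if any(isinstance(name, str) and block in name for name in names):
                return ratio
        return None

    blocks = {"double_blocks.0.": 0.0, "single_blocks.5.": 0.5}
    print(block_ratio("diffusion_model.double_blocks.0.img_attn.qkv.weight", blocks))  # 0.0 -> keys dropped
    print(block_ratio(("diffusion_model.single_blocks.5.linear1.weight",), blocks))    # 0.5 -> alpha rescaled
    print(block_ratio("diffusion_model.single_blocks.6.linear2.weight", blocks))       # None -> left untouched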