Support multiple LoRAs

This commit is contained in:
kijai 2024-10-21 13:24:21 +03:00
parent dc4bb73d3f
commit 4c3fcd7b01
3 changed files with 70 additions and 35 deletions

View File

@ -130,7 +130,7 @@
"widgets_values": [
"kijai/CogVideoX-5b-Tora",
"bf16",
"fastmode",
"disabled",
"disabled",
false
]

View File

@ -476,26 +476,35 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
return pipeline
def load_lora_into_transformer(state_dict, transformer, adapter_name=None, strength=1.0):
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
def load_lora_into_transformer(lora, transformer):
from peft import LoraConfig, set_peft_model_state_dict
from peft.mapping import PEFT_TYPE_TO_TUNER_MAPPING
from peft.tuners.tuners_utils import BaseTunerLayer
from diffusers.utils.peft_utils import get_peft_kwargs, get_adapter_name
from diffusers.utils.peft_utils import get_peft_kwargs
from diffusers.utils.import_utils import is_peft_version
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith("transformer")]
state_dict = {
k.replace(f"transformer.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if len(state_dict.keys()) > 0:
state_dict_list = []
adapter_name_list = []
strength_list = []
lora_config_list = []
for l in lora:
state_dict = load_file(l["path"])
adapter_name_list.append(l["name"])
strength_list.append(l["strength"])
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith("transformer")]
state_dict = {
k.replace(f"transformer.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
@ -508,13 +517,18 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None, stren
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
transformer = inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
lora_config_list.append(LoraConfig(**lora_config_kwargs))
state_dict_list.append(state_dict)
peft_models = []
for i in range(len(lora_config_list)):
tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[lora_config_list[i].peft_type]
peft_model = tuner_cls(transformer, lora_config_list[i], adapter_name=adapter_name_list[i])
incompatible_keys = set_peft_model_state_dict(peft_model.model, state_dict_list[i], adapter_name_list[i])
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
@ -523,10 +537,21 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None, stren
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
peft_models.append(peft_model)
if strength != 1.0:
if len(peft_models) > 1:
peft_models[0].add_weighted_adapter(
adapters=adapter_name_list,
weights=strength_list,
combination_type="linear",
adapter_name="combined_adapter"
)
peft_models[0].set_adapter("combined_adapter")
else:
if strength_list[0] != 1.0:
for module in transformer.modules():
if isinstance(module, BaseTunerLayer):
#print(f"Setting strength for {module}")
module.scale_layer(strength)
return transformer
module.scale_layer(strength_list[0])
return peft_model.model

View File

@ -224,6 +224,9 @@ class CogVideoLoraSelect:
{"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
},
"optional": {
"prev_lora":("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
}
}
RETURN_TYPES = ("COGLORA",)
@ -231,15 +234,20 @@ class CogVideoLoraSelect:
FUNCTION = "getlorapath"
CATEGORY = "CogVideoWrapper"
def getlorapath(self, lora, strength):
def getlorapath(self, lora, strength, prev_lora=None):
cog_loras_list = []
cog_lora = {
"path": folder_paths.get_full_path("cogvideox_loras", lora),
"strength": strength,
"name": lora.split(".")[0],
}
return (cog_lora,)
if prev_lora is not None:
cog_loras_list.extend(prev_lora)
cog_loras_list.append(cog_lora)
print(cog_loras_list)
return (cog_loras_list,)
class DownloadAndLoadCogVideoModel:
@classmethod
@ -268,7 +276,7 @@ class DownloadAndLoadCogVideoModel:
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
),
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs"}),
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
"pab_config": ("PAB_CONFIG", {"default": None}),
@ -281,6 +289,7 @@ class DownloadAndLoadCogVideoModel:
RETURN_NAMES = ("cogvideo_pipe", )
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None):
@ -353,14 +362,15 @@ class DownloadAndLoadCogVideoModel:
#LoRAs
if lora is not None:
from .cogvideox_fun.lora_utils import merge_lora, load_lora_into_transformer
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
from .lora_utils import merge_lora, load_lora_into_transformer
if "fun" in model.lower():
transformer = merge_lora(transformer, lora["path"], lora["strength"])
for l in lora:
logging.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
transformer = merge_lora(transformer, l["path"], l["strength"])
else:
lora_sd = load_torch_file(lora["path"])
transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"], strength=lora["strength"])
transformer = load_lora_into_transformer(lora, transformer)
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
@ -472,7 +482,7 @@ class DownloadAndLoadCogVideoGGUFModel:
],
),
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
"fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}),
"fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs, also requires torch 2.4.0 with cu124 minimum"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
},