Support multiple LoRAs

This commit is contained in:
kijai 2024-10-21 13:24:21 +03:00
parent dc4bb73d3f
commit 4c3fcd7b01
3 changed files with 70 additions and 35 deletions

View File

@ -130,7 +130,7 @@
"widgets_values": [ "widgets_values": [
"kijai/CogVideoX-5b-Tora", "kijai/CogVideoX-5b-Tora",
"bf16", "bf16",
"fastmode", "disabled",
"disabled", "disabled",
false false
] ]

View File

@ -476,26 +476,35 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
return pipeline return pipeline
def load_lora_into_transformer(state_dict, transformer, adapter_name=None, strength=1.0): def load_lora_into_transformer(lora, transformer):
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft import LoraConfig, set_peft_model_state_dict
from peft.mapping import PEFT_TYPE_TO_TUNER_MAPPING
from peft.tuners.tuners_utils import BaseTunerLayer from peft.tuners.tuners_utils import BaseTunerLayer
from diffusers.utils.peft_utils import get_peft_kwargs, get_adapter_name from diffusers.utils.peft_utils import get_peft_kwargs
from diffusers.utils.import_utils import is_peft_version from diffusers.utils.import_utils import is_peft_version
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith("transformer")] state_dict_list = []
state_dict = { adapter_name_list = []
k.replace(f"transformer.", ""): v for k, v in state_dict.items() if k in transformer_keys strength_list = []
} lora_config_list = []
if len(state_dict.keys()) > 0:
for l in lora:
state_dict = load_file(l["path"])
adapter_name_list.append(l["name"])
strength_list.append(l["strength"])
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith("transformer")]
state_dict = {
k.replace(f"transformer.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
# check with first key if is not in peft format # check with first key if is not in peft format
first_key = next(iter(state_dict.keys())) first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key: if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict) state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {} rank = {}
for key, val in state_dict.items(): for key, val in state_dict.items():
if "lora_B" in key: if "lora_B" in key:
@ -508,13 +517,18 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None, stren
) )
else: else:
lora_config_kwargs.pop("use_dora") lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
transformer = inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) lora_config_list.append(LoraConfig(**lora_config_kwargs))
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) state_dict_list.append(state_dict)
peft_models = []
for i in range(len(lora_config_list)):
tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[lora_config_list[i].peft_type]
peft_model = tuner_cls(transformer, lora_config_list[i], adapter_name=adapter_name_list[i])
incompatible_keys = set_peft_model_state_dict(peft_model.model, state_dict_list[i], adapter_name_list[i])
if incompatible_keys is not None: if incompatible_keys is not None:
# check only for unexpected keys # check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
@ -524,9 +538,20 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None, stren
f" {unexpected_keys}. " f" {unexpected_keys}. "
) )
if strength != 1.0: peft_models.append(peft_model)
if len(peft_models) > 1:
peft_models[0].add_weighted_adapter(
adapters=adapter_name_list,
weights=strength_list,
combination_type="linear",
adapter_name="combined_adapter"
)
peft_models[0].set_adapter("combined_adapter")
else:
if strength_list[0] != 1.0:
for module in transformer.modules(): for module in transformer.modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
#print(f"Setting strength for {module}") #print(f"Setting strength for {module}")
module.scale_layer(strength) module.scale_layer(strength_list[0])
return transformer return peft_model.model

View File

@ -224,6 +224,9 @@ class CogVideoLoraSelect:
{"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}), {"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
}, },
"optional": {
"prev_lora":("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
}
} }
RETURN_TYPES = ("COGLORA",) RETURN_TYPES = ("COGLORA",)
@ -231,15 +234,20 @@ class CogVideoLoraSelect:
FUNCTION = "getlorapath" FUNCTION = "getlorapath"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def getlorapath(self, lora, strength): def getlorapath(self, lora, strength, prev_lora=None):
cog_loras_list = []
cog_lora = { cog_lora = {
"path": folder_paths.get_full_path("cogvideox_loras", lora), "path": folder_paths.get_full_path("cogvideox_loras", lora),
"strength": strength, "strength": strength,
"name": lora.split(".")[0], "name": lora.split(".")[0],
} }
if prev_lora is not None:
cog_loras_list.extend(prev_lora)
return (cog_lora,) cog_loras_list.append(cog_lora)
print(cog_loras_list)
return (cog_loras_list,)
class DownloadAndLoadCogVideoModel: class DownloadAndLoadCogVideoModel:
@classmethod @classmethod
@ -268,7 +276,7 @@ class DownloadAndLoadCogVideoModel:
"precision": (["fp16", "fp32", "bf16"], "precision": (["fp16", "fp32", "bf16"],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"} {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
), ),
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs"}), "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}), "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}), "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
"pab_config": ("PAB_CONFIG", {"default": None}), "pab_config": ("PAB_CONFIG", {"default": None}),
@ -281,6 +289,7 @@ class DownloadAndLoadCogVideoModel:
RETURN_NAMES = ("cogvideo_pipe", ) RETURN_NAMES = ("cogvideo_pipe", )
FUNCTION = "loadmodel" FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None): def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None):
@ -353,13 +362,14 @@ class DownloadAndLoadCogVideoModel:
#LoRAs #LoRAs
if lora is not None: if lora is not None:
from .cogvideox_fun.lora_utils import merge_lora, load_lora_into_transformer from .lora_utils import merge_lora, load_lora_into_transformer
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
if "fun" in model.lower(): if "fun" in model.lower():
transformer = merge_lora(transformer, lora["path"], lora["strength"]) for l in lora:
logging.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
transformer = merge_lora(transformer, l["path"], l["strength"])
else: else:
lora_sd = load_torch_file(lora["path"]) transformer = load_lora_into_transformer(lora, transformer)
transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"], strength=lora["strength"])
if block_edit is not None: if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit) transformer = remove_specific_blocks(transformer, block_edit)
@ -472,7 +482,7 @@ class DownloadAndLoadCogVideoGGUFModel:
], ],
), ),
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}), "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
"fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}), "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs, also requires torch 2.4.0 with cu124 minimum"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}), "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}), "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
}, },