From a99524762432bda75e7ef3a95cf8923917f9078b Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 12 Oct 2024 22:23:54 +0300
Subject: [PATCH] possibly fix LoRA scaling

---
 cogvideox_fun/lora_utils.py | 9 ++++++++-
 nodes.py                    | 2 +-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/cogvideox_fun/lora_utils.py b/cogvideox_fun/lora_utils.py
index ccb3f65..fecc923 100644
--- a/cogvideox_fun/lora_utils.py
+++ b/cogvideox_fun/lora_utils.py
@@ -476,8 +476,9 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
     return pipeline
 
 
-def load_lora_into_transformer(state_dict, transformer, adapter_name=None):
+def load_lora_into_transformer(state_dict, transformer, adapter_name=None, strength=1.0):
     from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+    from peft.tuners.tuners_utils import BaseTunerLayer
     from diffusers.utils.peft_utils import get_peft_kwargs, get_adapter_name
     from diffusers.utils.import_utils import is_peft_version
     from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
@@ -522,4 +523,10 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None):
             f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
             f" {unexpected_keys}. "
         )
+
+    if strength != 1.0:
+        for module in transformer.modules():
+            if isinstance(module, BaseTunerLayer):
+                #print(f"Setting strength for {module}")
+                module.scale_layer(strength)
     return transformer
\ No newline at end of file
diff --git a/nodes.py b/nodes.py
index e1caadd..dcac384 100644
--- a/nodes.py
+++ b/nodes.py
@@ -356,7 +356,7 @@ class DownloadAndLoadCogVideoModel:
                     transformer = merge_lora(transformer, lora["path"], lora["strength"])
                 else:
                     lora_sd = load_torch_file(lora["path"])
-                    transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"])
+                    transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"], strength=lora["strength"])
 
             if block_edit is not None:
                 transformer = remove_specific_blocks(transformer, block_edit)
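
Note (not part of the patch): the fix relies on PEFT's `BaseTunerLayer.scale_layer`, which multiplies the active adapter's scaling factor (`lora_alpha / r`) in place, so the LoRA delta is applied at the requested strength instead of its default weight. Below is a minimal, self-contained sketch of that mechanism; `TinyBlock`, the adapter name `"example"`, and the config values are illustrative assumptions, not code from this repository.

```python
# Minimal sketch of the PEFT mechanism used by the patch: inject a LoRA
# adapter, then multiply its scaling factor with scale_layer(strength).
import torch
from peft import LoraConfig, inject_adapter_in_model
from peft.tuners.tuners_utils import BaseTunerLayer


class TinyBlock(torch.nn.Module):  # stand-in for a transformer block
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


model = TinyBlock()
config = LoraConfig(r=4, lora_alpha=4, target_modules=["proj"])
model = inject_adapter_in_model(config, model, adapter_name="example")

strength = 0.5  # corresponds to lora["strength"] in the node
if strength != 1.0:
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # multiplies the active adapter's scaling (lora_alpha / r) by strength
            module.scale_layer(strength)
```

Because `scale_layer` multiplies the existing factor rather than overwriting it, it should be called only once per adapter load; the `strength != 1.0` check in the patch simply skips the call when it would be a no-op.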