possibly fix LoRA scaling

kijai 2024-10-12 22:23:54 +03:00
parent 276b3b86d9
commit a995247624
2 changed files with 9 additions and 2 deletions

@@ -476,8 +476,9 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
     return pipeline
-def load_lora_into_transformer(state_dict, transformer, adapter_name=None):
+def load_lora_into_transformer(state_dict, transformer, adapter_name=None, strength=1.0):
     from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+    from peft.tuners.tuners_utils import BaseTunerLayer
     from diffusers.utils.peft_utils import get_peft_kwargs, get_adapter_name
     from diffusers.utils.import_utils import is_peft_version
     from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
@@ -522,4 +523,10 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None):
                 f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                 f" {unexpected_keys}. "
             )
+
+    if strength != 1.0:
+        for module in transformer.modules():
+            if isinstance(module, BaseTunerLayer):
+                #print(f"Setting strength for {module}")
+                module.scale_layer(strength)
     return transformer
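
For context: PEFT's LoRA layers keep a per-adapter scaling factor (normally lora_alpha / r), and scale_layer(strength) multiplies that factor, so the low-rank update each layer adds is scaled by strength. Below is a minimal sketch of that effect with made-up tensor names and shapes; it is not code from this repository.

import torch

def lora_forward(x, W, A, B, lora_alpha, r, strength=1.0):
    # Base projection plus the LoRA update B @ A; the update is weighted by the
    # usual lora_alpha / r factor times the user-facing strength, which is what
    # scale_layer(strength) effectively applies to an already-loaded adapter.
    scaling = (lora_alpha / r) * strength
    return x @ W.T + scaling * (x @ A.T @ B.T)

With strength=1.0 the adapter behaves as trained; strength=0.0 would remove its contribution entirely.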

@@ -356,7 +356,7 @@ class DownloadAndLoadCogVideoModel:
                     transformer = merge_lora(transformer, lora["path"], lora["strength"])
                 else:
                     lora_sd = load_torch_file(lora["path"])
-                    transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"])
+                    transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"], strength=lora["strength"])
 
             if block_edit is not None:
                 transformer = remove_specific_blocks(transformer, block_edit)
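
On the node side, the user-set per-LoRA strength is now forwarded in the non-merge path instead of being dropped. A hypothetical call with the new argument; the path and dict values here are placeholders, not values from the repository:

# Hypothetical LoRA entry; only the "path", "name" and "strength" keys
# appear in the diff above, and the file path is illustrative.
lora = {"path": "loras/example_lora.safetensors", "name": "example_lora", "strength": 0.8}

lora_sd = load_torch_file(lora["path"])
transformer = load_lora_into_transformer(
    state_dict=lora_sd,
    transformer=transformer,
    adapter_name=lora["name"],
    strength=lora["strength"],  # rescales the injected adapter after loading
)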