LoRA loading

This commit is contained in:
kijai 2024-10-06 02:48:39 +03:00
parent 5fde34468e
commit d8a9f31de9
2 changed files with 9 additions and 5 deletions

View File

@ -512,7 +512,7 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None):
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
transformer = inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
@ -521,4 +521,5 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None):
print(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
)
return transformer

View File

@ -58,7 +58,7 @@ from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as Autoenco
from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil
from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
from .cogvideox_fun.lora_utils import merge_lora
from .cogvideox_fun.lora_utils import merge_lora, load_lora_into_transformer
from PIL import Image
import numpy as np
import json
@ -233,7 +233,8 @@ class CogVideoLoraSelect:
cog_lora = {
"path": folder_paths.get_full_path("cogvideox_loras", lora),
"strength": strength
"strength": strength,
"name": lora.split(".")[0],
}
return (cog_lora,)
@ -349,7 +350,9 @@ class DownloadAndLoadCogVideoModel:
if "fun" in model.lower():
transformer = merge_lora(transformer, lora["path"], lora["strength"])
else:
raise NotImplementedError("LoRA merging is currently only supported for Fun models")
lora_sd = load_torch_file(lora["path"])
transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"])
#raise NotImplementedError("LoRA merging is currently only supported for Fun models")
if block_edit is not None: