From d8a9f31de9511063092e426285a668aaac1f8158 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 6 Oct 2024 02:48:39 +0300 Subject: [PATCH] lora loading --- cogvideox_fun/lora_utils.py | 5 +++-- nodes.py | 9 ++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cogvideox_fun/lora_utils.py b/cogvideox_fun/lora_utils.py index 3191c0c..ccb3f65 100644 --- a/cogvideox_fun/lora_utils.py +++ b/cogvideox_fun/lora_utils.py @@ -512,7 +512,7 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None): if adapter_name is None: adapter_name = get_adapter_name(transformer) - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + transformer = inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) if incompatible_keys is not None: # check only for unexpected keys @@ -521,4 +521,5 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None): print( f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " - ) \ No newline at end of file + ) + return transformer \ No newline at end of file diff --git a/nodes.py b/nodes.py index ac08b3f..fb4b949 100644 --- a/nodes.py +++ b/nodes.py @@ -58,7 +58,7 @@ from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as Autoenco from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control -from .cogvideox_fun.lora_utils import merge_lora +from .cogvideox_fun.lora_utils import merge_lora, load_lora_into_transformer from PIL import Image import numpy as np import json @@ -233,7 +233,8 @@ class CogVideoLoraSelect: cog_lora = { "path": folder_paths.get_full_path("cogvideox_loras", lora), - "strength": strength + "strength": strength, + "name": lora.split(".")[0], } return (cog_lora,) @@ -349,7 +350,9 @@ class DownloadAndLoadCogVideoModel: if "fun" in model.lower(): transformer = merge_lora(transformer, lora["path"], lora["strength"]) else: - raise NotImplementedError("LoRA merging is currently only supported for Fun models") + lora_sd = load_torch_file(lora["path"]) + transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"]) + #raise NotImplementedError("LoRA merging is currently only supported for Fun models") if block_edit is not None: