mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
lora loading
This commit is contained in:
parent
5fde34468e
commit
d8a9f31de9
@ -512,7 +512,7 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None):
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(transformer)
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
transformer = inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
@ -521,4 +521,5 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None):
|
||||
print(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
)
|
||||
return transformer
|
||||
9
nodes.py
9
nodes.py
@ -58,7 +58,7 @@ from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as Autoenco
|
||||
from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil
|
||||
from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
|
||||
from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control
|
||||
from .cogvideox_fun.lora_utils import merge_lora
|
||||
from .cogvideox_fun.lora_utils import merge_lora, load_lora_into_transformer
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import json
|
||||
@ -233,7 +233,8 @@ class CogVideoLoraSelect:
|
||||
|
||||
cog_lora = {
|
||||
"path": folder_paths.get_full_path("cogvideox_loras", lora),
|
||||
"strength": strength
|
||||
"strength": strength,
|
||||
"name": lora.split(".")[0],
|
||||
}
|
||||
|
||||
return (cog_lora,)
|
||||
@ -349,7 +350,9 @@ class DownloadAndLoadCogVideoModel:
|
||||
if "fun" in model.lower():
|
||||
transformer = merge_lora(transformer, lora["path"], lora["strength"])
|
||||
else:
|
||||
raise NotImplementedError("LoRA merging is currently only supported for Fun models")
|
||||
lora_sd = load_torch_file(lora["path"])
|
||||
transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"])
|
||||
#raise NotImplementedError("LoRA merging is currently only supported for Fun models")
|
||||
|
||||
|
||||
if block_edit is not None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user