Support multiple LoRAs

commit 4c3fcd7b01 (parent dc4bb73d3f)
@@ -130,7 +130,7 @@
     "widgets_values": [
       "kijai/CogVideoX-5b-Tora",
       "bf16",
-      "fastmode",
+      "disabled",
       "disabled",
       false
     ]
@@ -476,26 +476,35 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
 
     return pipeline
 
-def load_lora_into_transformer(state_dict, transformer, adapter_name=None, strength=1.0):
-    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+def load_lora_into_transformer(lora, transformer):
+    from peft import LoraConfig, set_peft_model_state_dict
+    from peft.mapping import PEFT_TYPE_TO_TUNER_MAPPING
     from peft.tuners.tuners_utils import BaseTunerLayer
-    from diffusers.utils.peft_utils import get_peft_kwargs, get_adapter_name
+    from diffusers.utils.peft_utils import get_peft_kwargs
     from diffusers.utils.import_utils import is_peft_version
     from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
-    keys = list(state_dict.keys())
-    transformer_keys = [k for k in keys if k.startswith("transformer")]
-    state_dict = {
-        k.replace(f"transformer.", ""): v for k, v in state_dict.items() if k in transformer_keys
-    }
-    if len(state_dict.keys()) > 0:
+    state_dict_list = []
+    adapter_name_list = []
+    strength_list = []
+    lora_config_list = []
+
+    for l in lora:
+        state_dict = load_file(l["path"])
+        adapter_name_list.append(l["name"])
+        strength_list.append(l["strength"])
+
+        keys = list(state_dict.keys())
+        transformer_keys = [k for k in keys if k.startswith("transformer")]
+        state_dict = {
+            k.replace(f"transformer.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        }
+
         # check with first key if is not in peft format
         first_key = next(iter(state_dict.keys()))
         if "lora_A" not in first_key:
             state_dict = convert_unet_state_dict_to_peft(state_dict)
-        if adapter_name in getattr(transformer, "peft_config", {}):
-            raise ValueError(
-                f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-            )
         rank = {}
         for key, val in state_dict.items():
             if "lora_B" in key:
@@ -508,13 +517,18 @@ def load_lora_into_transformer(state_dict, transformer, adapter_name=None, stren
                     )
             else:
                 lora_config_kwargs.pop("use_dora")
-        lora_config = LoraConfig(**lora_config_kwargs)
-        # adapter_name
-        if adapter_name is None:
-            adapter_name = get_adapter_name(transformer)
-
-        transformer = inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+        lora_config_list.append(LoraConfig(**lora_config_kwargs))
+        state_dict_list.append(state_dict)
+
+    peft_models = []
+
+    for i in range(len(lora_config_list)):
+        tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[lora_config_list[i].peft_type]
+        peft_model = tuner_cls(transformer, lora_config_list[i], adapter_name=adapter_name_list[i])
+        incompatible_keys = set_peft_model_state_dict(peft_model.model, state_dict_list[i], adapter_name_list[i])
 
         if incompatible_keys is not None:
             # check only for unexpected keys
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
@@ -524,9 +538,20 @@
                     f" {unexpected_keys}. "
                 )
+        peft_models.append(peft_model)
 
-        if strength != 1.0:
+    if len(peft_models) > 1:
+        peft_models[0].add_weighted_adapter(
+            adapters=adapter_name_list,
+            weights=strength_list,
+            combination_type="linear",
+            adapter_name="combined_adapter"
+        )
+        peft_models[0].set_adapter("combined_adapter")
+    else:
+        if strength_list[0] != 1.0:
             for module in transformer.modules():
                 if isinstance(module, BaseTunerLayer):
                     #print(f"Setting strength for {module}")
-                    module.scale_layer(strength)
-    return transformer
+                    module.scale_layer(strength_list[0])
+    return peft_model.model
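For readers unfamiliar with the PEFT multi-adapter flow used above, here is a minimal, self-contained sketch of the same idea on a toy module: register each LoRA as a named adapter, then fold them into one weighted adapter with add_weighted_adapter and activate it. The module, adapter names, and strengths below are illustrative only; the repository code operates on the CogVideoX transformer instead.

import torch
from peft import LoraConfig, get_peft_model

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

base = Toy()

# One LoraConfig per LoRA; the "linear" combination below expects matching ranks.
style_cfg = LoraConfig(r=8, lora_alpha=8, target_modules=["proj"])
motion_cfg = LoraConfig(r=8, lora_alpha=8, target_modules=["proj"])

# Register both adapters on the same base weights.
model = get_peft_model(base, style_cfg, adapter_name="style")
model.add_adapter("motion", motion_cfg)

# Fold them into one weighted adapter and activate it, mirroring the
# add_weighted_adapter(..., combination_type="linear") call in the diff.
model.add_weighted_adapter(
    adapters=["style", "motion"],
    weights=[1.0, 0.5],
    combination_type="linear",
    adapter_name="combined_adapter",
)
model.set_adapter("combined_adapter")

print(model(torch.randn(1, 16)).shape)  # torch.Size([1, 16])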
nodes.py (28 lines changed)
@@ -224,6 +224,9 @@ class CogVideoLoraSelect:
                     {"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
                 "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
             },
+            "optional": {
+                "prev_lora":("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
+            }
         }
 
     RETURN_TYPES = ("COGLORA",)
@@ -231,15 +234,20 @@ class CogVideoLoraSelect:
     FUNCTION = "getlorapath"
     CATEGORY = "CogVideoWrapper"
 
-    def getlorapath(self, lora, strength):
+    def getlorapath(self, lora, strength, prev_lora=None):
+        cog_loras_list = []
+
         cog_lora = {
             "path": folder_paths.get_full_path("cogvideox_loras", lora),
             "strength": strength,
             "name": lora.split(".")[0],
         }
-        return (cog_lora,)
+        if prev_lora is not None:
+            cog_loras_list.extend(prev_lora)
+
+        cog_loras_list.append(cog_lora)
+        print(cog_loras_list)
+        return (cog_loras_list,)
 
 class DownloadAndLoadCogVideoModel:
     @classmethod
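Each selector node simply appends its own entry to whatever list arrives on prev_lora, so chaining nodes builds up the COGLORA list. A rough illustration of the resulting value when two selector nodes are wired together; the file names and adapter names are placeholders, not files shipped with the repository.

# Rough illustration (not repository code) of the COGLORA list built by
# chaining two CogVideoLoraSelect nodes; paths and names are placeholders.
first_node_output = [
    {"path": "ComfyUI/models/CogVideo/loras/style.safetensors", "strength": 1.0, "name": "style"},
]

# The second node receives the list above as prev_lora, extends its own
# list with it, then appends its own entry.
prev_lora = first_node_output
second_node_output = list(prev_lora)
second_node_output.append(
    {"path": "ComfyUI/models/CogVideo/loras/motion.safetensors", "strength": 0.5, "name": "motion"}
)

# Downstream nodes such as the model loader iterate over this list.
print(second_node_output)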
@@ -268,7 +276,7 @@ class DownloadAndLoadCogVideoModel:
             "precision": (["fp16", "fp32", "bf16"],
                 {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
             ),
-            "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs"}),
+            "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
             "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
             "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
             "pab_config": ("PAB_CONFIG", {"default": None}),
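The updated tooltip states that fastmode needs at least torch 2.4.0 built against cu124. A hedged sketch of the kind of runtime check a caller could perform before enabling it; this helper is illustrative and not part of the repository.

# Illustrative guard for fp8 fastmode, based on the tooltip's stated minimums
# (torch 2.4.0, CUDA 12.4); not code from the repository.
import torch
from packaging import version

def fp8_fastmode_supported() -> bool:
    torch_ok = version.parse(torch.__version__.split("+")[0]) >= version.parse("2.4.0")
    cuda_ok = torch.version.cuda is not None and version.parse(torch.version.cuda) >= version.parse("12.4")
    return torch_ok and cuda_ok

print(fp8_fastmode_supported())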
@@ -281,6 +289,7 @@ class DownloadAndLoadCogVideoModel:
     RETURN_NAMES = ("cogvideo_pipe", )
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
 
     def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None):
 
@@ -353,13 +362,14 @@ class DownloadAndLoadCogVideoModel:
 
         #LoRAs
         if lora is not None:
-            from .cogvideox_fun.lora_utils import merge_lora, load_lora_into_transformer
-            logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
+            from .lora_utils import merge_lora, load_lora_into_transformer
             if "fun" in model.lower():
-                transformer = merge_lora(transformer, lora["path"], lora["strength"])
+                for l in lora:
+                    logging.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
+                    transformer = merge_lora(transformer, l["path"], l["strength"])
             else:
-                lora_sd = load_torch_file(lora["path"])
-                transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"], strength=lora["strength"])
+                transformer = load_lora_into_transformer(lora, transformer)
 
         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)
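In the "fun" branch above, merge_lora folds each LoRA delta directly into the transformer weights instead of attaching PEFT adapters. A conceptual sketch of that kind of in-place merge on toy tensors; this is not repository code, and real implementations usually also scale by alpha / rank.

# Conceptual sketch of an in-place LoRA merge (toy tensors, not repository
# code): W <- W + strength * (lora_up @ lora_down).
import torch

def merge_delta_into_weight(weight, lora_down, lora_up, strength):
    return weight + strength * (lora_up @ lora_down)

weight = torch.zeros(4, 4)
lora_down = torch.randn(2, 4)  # often called lora_A
lora_up = torch.randn(4, 2)    # often called lora_B
merged = merge_delta_into_weight(weight, lora_down, lora_up, strength=0.5)
print(merged.shape)  # torch.Size([4, 4])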
@@ -472,7 +482,7 @@ class DownloadAndLoadCogVideoGGUFModel:
                 ],
             ),
             "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
-            "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}),
+            "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs, also requires torch 2.4.0 with cu124 minimum"}),
             "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
             "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
             },