mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
support extra_model_paths
This commit is contained in:
parent
af2a568c8e
commit
502ebab83f
19
nodes.py
19
nodes.py
@ -69,6 +69,11 @@ log = logging.getLogger(__name__)
|
||||
|
||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if not "CogVideo" in folder_paths.folder_names_and_paths:
|
||||
folder_paths.add_model_folder_path("CogVideo", os.path.join(folder_paths.models_dir, "CogVideo"))
|
||||
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
|
||||
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
|
||||
|
||||
class PABConfig:
|
||||
def __init__(
|
||||
self,
|
||||
@ -210,9 +215,6 @@ class CogVideoTransformerEdit:
|
||||
log.info(f"Blocks selected for removal: {blocks_to_remove}")
|
||||
return (blocks_to_remove,)
|
||||
|
||||
|
||||
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
|
||||
|
||||
class CogVideoLoraSelect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -287,7 +289,8 @@ class DownloadAndLoadCogVideoModel:
|
||||
mm.soft_empty_cache()
|
||||
|
||||
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||
download_path = os.path.join(folder_paths.models_dir, "CogVideo")
|
||||
download_path = folder_paths.get_folder_paths("CogVideo")[0]
|
||||
|
||||
if "Fun" in model:
|
||||
if not "1.1" in model:
|
||||
repo_id = "kijai/CogVideoX-Fun-pruned"
|
||||
@ -307,11 +310,11 @@ class DownloadAndLoadCogVideoModel:
|
||||
download_path = base_path
|
||||
|
||||
elif "2b" in model:
|
||||
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
|
||||
base_path = os.path.join(download_path, "CogVideo2B")
|
||||
download_path = base_path
|
||||
repo_id = model
|
||||
elif "5b" in model:
|
||||
base_path = os.path.join(folder_paths.models_dir, "CogVideo", (model.split("/")[-1]))
|
||||
base_path = os.path.join(download_path, (model.split("/")[-1]))
|
||||
download_path = base_path
|
||||
repo_id = model
|
||||
|
||||
@ -345,6 +348,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
|
||||
transformer = transformer.to(dtype).to(offload_device)
|
||||
|
||||
#LoRAs
|
||||
if lora is not None:
|
||||
from .cogvideox_fun.lora_utils import merge_lora, load_lora_into_transformer
|
||||
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
|
||||
@ -353,12 +357,11 @@ class DownloadAndLoadCogVideoModel:
|
||||
else:
|
||||
lora_sd = load_torch_file(lora["path"])
|
||||
transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"])
|
||||
#raise NotImplementedError("LoRA merging is currently only supported for Fun models")
|
||||
|
||||
|
||||
if block_edit is not None:
|
||||
transformer = remove_specific_blocks(transformer, block_edit)
|
||||
|
||||
#fp8
|
||||
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
|
||||
if "2b" in model:
|
||||
for name, param in transformer.named_parameters():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user