support extra_model_paths

This commit is contained in:
kijai 2024-10-06 12:39:57 +03:00
parent af2a568c8e
commit 502ebab83f

View File

@ -69,6 +69,11 @@ log = logging.getLogger(__name__)
script_directory = os.path.dirname(os.path.abspath(__file__))
if not "CogVideo" in folder_paths.folder_names_and_paths:
folder_paths.add_model_folder_path("CogVideo", os.path.join(folder_paths.models_dir, "CogVideo"))
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
class PABConfig:
def __init__(
self,
@ -210,9 +215,6 @@ class CogVideoTransformerEdit:
log.info(f"Blocks selected for removal: {blocks_to_remove}")
return (blocks_to_remove,)
folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))
class CogVideoLoraSelect:
@classmethod
def INPUT_TYPES(s):
@ -287,7 +289,8 @@ class DownloadAndLoadCogVideoModel:
mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
download_path = os.path.join(folder_paths.models_dir, "CogVideo")
download_path = folder_paths.get_folder_paths("CogVideo")[0]
if "Fun" in model:
if not "1.1" in model:
repo_id = "kijai/CogVideoX-Fun-pruned"
@ -307,11 +310,11 @@ class DownloadAndLoadCogVideoModel:
download_path = base_path
elif "2b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
base_path = os.path.join(download_path, "CogVideo2B")
download_path = base_path
repo_id = model
elif "5b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", (model.split("/")[-1]))
base_path = os.path.join(download_path, (model.split("/")[-1]))
download_path = base_path
repo_id = model
@ -345,6 +348,7 @@ class DownloadAndLoadCogVideoModel:
transformer = transformer.to(dtype).to(offload_device)
#LoRAs
if lora is not None:
from .cogvideox_fun.lora_utils import merge_lora, load_lora_into_transformer
logging.info(f"Merging LoRA weights from {lora['path']} with strength {lora['strength']}")
@ -353,12 +357,11 @@ class DownloadAndLoadCogVideoModel:
else:
lora_sd = load_torch_file(lora["path"])
transformer = load_lora_into_transformer(state_dict=lora_sd, transformer=transformer, adapter_name=lora["name"])
#raise NotImplementedError("LoRA merging is currently only supported for Fun models")
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
#fp8
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
if "2b" in model:
for name, param in transformer.named_parameters():