Update model_loading.py

This commit is contained in:
Jukka Seppänen 2024-11-09 04:16:46 +02:00
parent b563994afc
commit 9a64e1ae5e

View File

@ -144,6 +144,11 @@ class DownloadAndLoadCogVideoModel:
base_path = os.path.join(download_path, "CogVideo2B")
download_path = base_path
repo_id = model
elif "1.5-T2V" in model:
base_path = os.path.join(download_path, "CogVideoX-5b-1.5")
download_path = base_path
transformer_path = os.path.join(base_path, "transformer_T2V")
repo_id = "kijai/CogVideoX-5b-1.5"
else:
base_path = os.path.join(download_path, (model.split("/")[-1]))
download_path = base_path
@ -172,6 +177,8 @@ class DownloadAndLoadCogVideoModel:
transformer = CogVideoXTransformer3DModelFunPAB.from_pretrained(base_path, subfolder="transformer")
else:
transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder="transformer")
elif "1.5-T2V" in model:
transformer = CogVideoXTransformer3DModel.from_pretrained(transformer_path)
else:
if pab_config is not None:
transformer = CogVideoXTransformer3DModelPAB.from_pretrained(base_path, subfolder="transformer")