This commit is contained in:
kijai 2024-09-19 11:08:18 +03:00
parent 818e31d2d2
commit c3eae2edc1
2 changed files with 20 additions and 12 deletions

View File

@ -195,7 +195,7 @@
"Node name for S&R": "DownloadAndLoadCogVideoModel"
},
"widgets_values": [
"kijai/CogVideoX-Fun-pruned",
"kijai/CogVideoX-Fun-5b",
"bf16",
"disabled",
"disabled",

View File

@ -32,7 +32,8 @@ class DownloadAndLoadCogVideoModel:
"THUDM/CogVideoX-5b",
"THUDM/CogVideoX-5b-I2V",
"bertjiazheng/KoolCogVideoX-5b",
"kijai/CogVideoX-Fun-pruned"
"kijai/CogVideoX-Fun-2b",
"kijai/CogVideoX-Fun-5b",
],
),
@ -60,24 +61,32 @@ class DownloadAndLoadCogVideoModel:
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
if "Fun" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP")
if not os.path.exists(base_path):
download_path = os.path.join(folder_paths.models_dir, "CogVideo")
base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP")
repo_id = "kijai/CogVideoX-Fun-pruned"
download_path = os.path.join(folder_paths.models_dir, "CogVideo")
if "2b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP") # location of the official model
if not os.path.exists(base_path):
base_path = os.path.join(download_path, "CogVideoX-Fun-2b-InP")
elif "5b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") # location of the official model
if not os.path.exists(base_path):
base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP")
elif "2b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
download_path = base_path
repo_id = model
elif "5b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", (model.split("/")[-1]))
download_path = base_path
repo_id = model
if not os.path.exists(base_path):
log.info(f"Downloading model to: {base_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id=model,
repo_id=repo_id,
ignore_patterns=["*text_encoder*", "*tokenizer*"],
local_dir=download_path,
local_dir_use_symlinks=False,
@ -105,17 +114,16 @@ class DownloadAndLoadCogVideoModel:
if fp8_transformer == "fastmode":
from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(transformer, dtype)
if "Fun" in model:
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
else:
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
if "Fun" in model:
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
else:
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
pipe = CogVideoXPipeline(vae, transformer, scheduler)
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()