This commit is contained in:
kijai 2024-09-19 11:08:18 +03:00
parent 818e31d2d2
commit c3eae2edc1
2 changed files with 20 additions and 12 deletions

View File

@ -195,7 +195,7 @@
"Node name for S&R": "DownloadAndLoadCogVideoModel" "Node name for S&R": "DownloadAndLoadCogVideoModel"
}, },
"widgets_values": [ "widgets_values": [
"kijai/CogVideoX-Fun-pruned", "kijai/CogVideoX-Fun-5b",
"bf16", "bf16",
"disabled", "disabled",
"disabled", "disabled",

View File

@ -32,7 +32,8 @@ class DownloadAndLoadCogVideoModel:
"THUDM/CogVideoX-5b", "THUDM/CogVideoX-5b",
"THUDM/CogVideoX-5b-I2V", "THUDM/CogVideoX-5b-I2V",
"bertjiazheng/KoolCogVideoX-5b", "bertjiazheng/KoolCogVideoX-5b",
"kijai/CogVideoX-Fun-pruned" "kijai/CogVideoX-Fun-2b",
"kijai/CogVideoX-Fun-5b",
], ],
), ),
@ -60,24 +61,32 @@ class DownloadAndLoadCogVideoModel:
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
if "Fun" in model: if "Fun" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") repo_id = "kijai/CogVideoX-Fun-pruned"
if not os.path.exists(base_path): download_path = os.path.join(folder_paths.models_dir, "CogVideo")
download_path = os.path.join(folder_paths.models_dir, "CogVideo") if "2b" in model:
base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP") base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP") # location of the official model
if not os.path.exists(base_path):
base_path = os.path.join(download_path, "CogVideoX-Fun-2b-InP")
elif "5b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") # location of the official model
if not os.path.exists(base_path):
base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP")
elif "2b" in model: elif "2b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B") base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
download_path = base_path download_path = base_path
repo_id = model
elif "5b" in model: elif "5b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", (model.split("/")[-1])) base_path = os.path.join(folder_paths.models_dir, "CogVideo", (model.split("/")[-1]))
download_path = base_path download_path = base_path
repo_id = model
if not os.path.exists(base_path): if not os.path.exists(base_path):
log.info(f"Downloading model to: {base_path}") log.info(f"Downloading model to: {base_path}")
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
snapshot_download( snapshot_download(
repo_id=model, repo_id=repo_id,
ignore_patterns=["*text_encoder*", "*tokenizer*"], ignore_patterns=["*text_encoder*", "*tokenizer*"],
local_dir=download_path, local_dir=download_path,
local_dir_use_symlinks=False, local_dir_use_symlinks=False,
@ -106,16 +115,15 @@ class DownloadAndLoadCogVideoModel:
from .fp8_optimization import convert_fp8_linear from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(transformer, dtype) convert_fp8_linear(transformer, dtype)
if "Fun" in model:
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
else:
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler") scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
if "Fun" in model: if "Fun" in model:
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler) pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
else: else:
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
pipe = CogVideoXPipeline(vae, transformer, scheduler) pipe = CogVideoXPipeline(vae, transformer, scheduler)
if enable_sequential_cpu_offload: if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload() pipe.enable_sequential_cpu_offload()