Add Tora GGUF

This commit is contained in:
kijai 2024-10-21 03:49:29 +03:00
parent 8ed031417f
commit 01bce7dfff

View File

@ -326,7 +326,7 @@ class DownloadAndLoadCogVideoModel:
else:
scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json')
if not os.path.exists(base_path):
if not os.path.exists(base_path) or not os.path.exists(os.path.join(base_path, "transformer")):
log.info(f"Downloading model to: {base_path}")
from huggingface_hub import snapshot_download
@ -467,6 +467,8 @@ class DownloadAndLoadCogVideoGGUFModel:
"CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_1_1_Pose_GGUF_Q4_0.safetensors",
"CogVideoX_5b_Interpolation_GGUF_Q4_0.safetensors",
"CogVideoX_5b_Tora_GGUF_Q4_0.safetensors",
],
),
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
@ -499,7 +501,7 @@ class DownloadAndLoadCogVideoGGUFModel:
if not os.path.exists(gguf_path):
gguf_path = os.path.join(download_path, model)
if not os.path.exists(gguf_path):
if "I2V" in model or "1_1" in model or "Interpolation" in model:
if "I2V" in model or "1_1" in model or "Interpolation" in model or "Tora" in model:
repo_id = "Kijai/CogVideoX_GGUF"
else:
repo_id = "MinusZoneAI/ComfyUI-CogVideoX-MZ"
@ -610,6 +612,63 @@ class DownloadAndLoadCogVideoGGUFModel:
vae.load_state_dict(vae_sd)
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
if "Tora" in model:
import torch.nn as nn
from .tora.traj_module import MGF
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
if not os.path.exists(fuser_path):
log.info(f"Downloading Fuser model to: {fuser_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="kijai/CogVideoX-5b-Tora",
allow_patterns=["*fuser.safetensors*"],
local_dir=download_path,
local_dir_use_symlinks=False,
)
hidden_size = 3072
num_layers = transformer.num_layers
pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
fuser_sd = load_torch_file(fuser_path)
pipe.transformer.fuser_list.load_state_dict(fuser_sd)
for module in transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(torch.float16)
del fuser_sd
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
if not os.path.exists(traj_extractor_path):
log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="kijai/CogVideoX-5b-Tora",
allow_patterns=["*traj_extractor.safetensors*"],
local_dir=download_path,
local_dir_use_symlinks=False,
)
from .tora.traj_module import TrajExtractor
traj_extractor = TrajExtractor(
vae_downsize=(4, 8, 8),
patch_size=2,
nums_rb=2,
cin=vae.config.latent_channels,
channels=[128] * transformer.num_layers,
sk=True,
use_conv=False,
)
traj_sd = load_torch_file(traj_extractor_path)
traj_extractor.load_state_dict(traj_sd)
traj_extractor.to(torch.float32).to(device)
pipe.traj_extractor = traj_extractor
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
@ -617,7 +676,7 @@ class DownloadAndLoadCogVideoGGUFModel:
"pipe": pipe,
"dtype": vae_dtype,
"base_path": model,
"onediff": True if compile == "onediff" else False,
"onediff": False,
"cpu_offloading": enable_sequential_cpu_offload,
"scheduler_config": scheduler_config,
"model_name": model