mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-01-23 20:24:25 +08:00
Add Tora GGUF
This commit is contained in:
parent
8ed031417f
commit
01bce7dfff
65
nodes.py
65
nodes.py
@ -326,7 +326,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
else:
|
||||
scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json')
|
||||
|
||||
if not os.path.exists(base_path):
|
||||
if not os.path.exists(base_path) or not os.path.exists(os.path.join(base_path, "transformer")):
|
||||
log.info(f"Downloading model to: {base_path}")
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
@ -467,6 +467,8 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
"CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors",
|
||||
"CogVideoX_5b_fun_1_1_Pose_GGUF_Q4_0.safetensors",
|
||||
"CogVideoX_5b_Interpolation_GGUF_Q4_0.safetensors",
|
||||
"CogVideoX_5b_Tora_GGUF_Q4_0.safetensors",
|
||||
|
||||
],
|
||||
),
|
||||
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
|
||||
@ -499,7 +501,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
if not os.path.exists(gguf_path):
|
||||
gguf_path = os.path.join(download_path, model)
|
||||
if not os.path.exists(gguf_path):
|
||||
if "I2V" in model or "1_1" in model or "Interpolation" in model:
|
||||
if "I2V" in model or "1_1" in model or "Interpolation" in model or "Tora" in model:
|
||||
repo_id = "Kijai/CogVideoX_GGUF"
|
||||
else:
|
||||
repo_id = "MinusZoneAI/ComfyUI-CogVideoX-MZ"
|
||||
@ -610,6 +612,63 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
vae.load_state_dict(vae_sd)
|
||||
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
|
||||
|
||||
if "Tora" in model:
|
||||
import torch.nn as nn
|
||||
from .tora.traj_module import MGF
|
||||
|
||||
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
|
||||
fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
|
||||
if not os.path.exists(fuser_path):
|
||||
log.info(f"Downloading Fuser model to: {fuser_path}")
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="kijai/CogVideoX-5b-Tora",
|
||||
allow_patterns=["*fuser.safetensors*"],
|
||||
local_dir=download_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
hidden_size = 3072
|
||||
num_layers = transformer.num_layers
|
||||
pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
|
||||
|
||||
fuser_sd = load_torch_file(fuser_path)
|
||||
pipe.transformer.fuser_list.load_state_dict(fuser_sd)
|
||||
for module in transformer.fuser_list:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.to(torch.float16)
|
||||
del fuser_sd
|
||||
|
||||
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
|
||||
if not os.path.exists(traj_extractor_path):
|
||||
log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="kijai/CogVideoX-5b-Tora",
|
||||
allow_patterns=["*traj_extractor.safetensors*"],
|
||||
local_dir=download_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
from .tora.traj_module import TrajExtractor
|
||||
traj_extractor = TrajExtractor(
|
||||
vae_downsize=(4, 8, 8),
|
||||
patch_size=2,
|
||||
nums_rb=2,
|
||||
cin=vae.config.latent_channels,
|
||||
channels=[128] * transformer.num_layers,
|
||||
sk=True,
|
||||
use_conv=False,
|
||||
)
|
||||
|
||||
traj_sd = load_torch_file(traj_extractor_path)
|
||||
traj_extractor.load_state_dict(traj_sd)
|
||||
traj_extractor.to(torch.float32).to(device)
|
||||
|
||||
pipe.traj_extractor = traj_extractor
|
||||
|
||||
if enable_sequential_cpu_offload:
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
@ -617,7 +676,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
"pipe": pipe,
|
||||
"dtype": vae_dtype,
|
||||
"base_path": model,
|
||||
"onediff": True if compile == "onediff" else False,
|
||||
"onediff": False,
|
||||
"cpu_offloading": enable_sequential_cpu_offload,
|
||||
"scheduler_config": scheduler_config,
|
||||
"model_name": model
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user