diff --git a/nodes.py b/nodes.py index 9119a83..b7f2a6c 100644 --- a/nodes.py +++ b/nodes.py @@ -326,7 +326,7 @@ class DownloadAndLoadCogVideoModel: else: scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json') - if not os.path.exists(base_path): + if not os.path.exists(base_path) or not os.path.exists(os.path.join(base_path, "transformer")): log.info(f"Downloading model to: {base_path}") from huggingface_hub import snapshot_download @@ -467,6 +467,8 @@ class DownloadAndLoadCogVideoGGUFModel: "CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors", "CogVideoX_5b_fun_1_1_Pose_GGUF_Q4_0.safetensors", "CogVideoX_5b_Interpolation_GGUF_Q4_0.safetensors", + "CogVideoX_5b_Tora_GGUF_Q4_0.safetensors", + ], ), "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}), @@ -499,7 +501,7 @@ class DownloadAndLoadCogVideoGGUFModel: if not os.path.exists(gguf_path): gguf_path = os.path.join(download_path, model) if not os.path.exists(gguf_path): - if "I2V" in model or "1_1" in model or "Interpolation" in model: + if "I2V" in model or "1_1" in model or "Interpolation" in model or "Tora" in model: repo_id = "Kijai/CogVideoX_GGUF" else: repo_id = "MinusZoneAI/ComfyUI-CogVideoX-MZ" @@ -610,6 +612,63 @@ class DownloadAndLoadCogVideoGGUFModel: vae.load_state_dict(vae_sd) pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config) + if "Tora" in model: + import torch.nn as nn + from .tora.traj_module import MGF + + download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora") + fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors") + if not os.path.exists(fuser_path): + log.info(f"Downloading Fuser model to: {fuser_path}") + from huggingface_hub import snapshot_download + + snapshot_download( + repo_id="kijai/CogVideoX-5b-Tora", + allow_patterns=["*fuser.safetensors*"], + local_dir=download_path, + local_dir_use_symlinks=False, + ) + + hidden_size = 3072 + num_layers = transformer.num_layers + pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)]) + + fuser_sd = load_torch_file(fuser_path) + pipe.transformer.fuser_list.load_state_dict(fuser_sd) + for module in transformer.fuser_list: + for param in module.parameters(): + param.data = param.data.to(torch.float16) + del fuser_sd + + traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors") + if not os.path.exists(traj_extractor_path): + log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}") + from huggingface_hub import snapshot_download + + snapshot_download( + repo_id="kijai/CogVideoX-5b-Tora", + allow_patterns=["*traj_extractor.safetensors*"], + local_dir=download_path, + local_dir_use_symlinks=False, + ) + + from .tora.traj_module import TrajExtractor + traj_extractor = TrajExtractor( + vae_downsize=(4, 8, 8), + patch_size=2, + nums_rb=2, + cin=vae.config.latent_channels, + channels=[128] * transformer.num_layers, + sk=True, + use_conv=False, + ) + + traj_sd = load_torch_file(traj_extractor_path) + traj_extractor.load_state_dict(traj_sd) + traj_extractor.to(torch.float32).to(device) + + pipe.traj_extractor = traj_extractor + if enable_sequential_cpu_offload: pipe.enable_sequential_cpu_offload() @@ -617,7 +676,7 @@ class DownloadAndLoadCogVideoGGUFModel: "pipe": pipe, "dtype": vae_dtype, "base_path": model, - "onediff": True if compile == "onediff" else False, + "onediff": False, "cpu_offloading": enable_sequential_cpu_offload, "scheduler_config": scheduler_config, "model_name": model