support other tora model

This commit is contained in:
kijai 2025-01-14 14:39:44 +02:00
parent eaaa0f6e1a
commit 8c5e4f812d
3 changed files with 11 additions and 6 deletions

View File

@ -971,6 +971,7 @@ class DownloadAndLoadToraModel:
"model": (
[
"kijai/CogVideoX-5b-Tora",
"kijai/CogVideoX-5b-Tora-I2V",
],
),
},
@ -1000,14 +1001,17 @@ class DownloadAndLoadToraModel:
pass
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
fuser_model = "fuser.safetensors" if not "I2V" in model else "fuser_I2V.safetensors"
fuser_path = os.path.join(download_path, "fuser", fuser_model)
if not os.path.exists(fuser_path):
log.info(f"Downloading Fuser model to: {fuser_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id=model,
allow_patterns=["*fuser.safetensors*"],
allow_patterns=[fuser_model],
local_dir=download_path,
local_dir_use_symlinks=False,
)
@ -1029,14 +1033,15 @@ class DownloadAndLoadToraModel:
param.data = param.data.to(torch.bfloat16).to(device)
del fuser_sd
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
traj_extractor_model = "traj_extractor.safetensors" if not "I2V" in model else "traj_extractor_I2V.safetensors"
traj_extractor_path = os.path.join(download_path, "traj_extractor", traj_extractor_model)
if not os.path.exists(traj_extractor_path):
log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="kijai/CogVideoX-5b-Tora",
allow_patterns=["*traj_extractor.safetensors*"],
allow_patterns=[traj_extractor_model],
local_dir=download_path,
local_dir_use_symlinks=False,
)

View File

@ -814,7 +814,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None:
callback(i, latents.detach()[-1], None, num_inference_steps)
callback(i, (latents - noise_pred * (t / 1000)).detach()[0], None, num_inference_steps)
else:
comfy_pbar.update(1)

View File

@ -1,7 +1,7 @@
[project]
name = "comfyui-cogvideoxwrapper"
description = "Diffusers wrapper for CogVideoX -models: [a/https://github.com/THUDM/CogVideo](https://github.com/THUDM/CogVideo)"
version = "1.5.0"
version = "1.5.1"
license = {file = "LICENSE"}
dependencies = ["huggingface_hub", "diffusers>=0.31.0", "accelerate>=0.33.0"]