mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
support other tora model
This commit is contained in:
parent
eaaa0f6e1a
commit
8c5e4f812d
@ -971,6 +971,7 @@ class DownloadAndLoadToraModel:
|
|||||||
"model": (
|
"model": (
|
||||||
[
|
[
|
||||||
"kijai/CogVideoX-5b-Tora",
|
"kijai/CogVideoX-5b-Tora",
|
||||||
|
"kijai/CogVideoX-5b-Tora-I2V",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
@ -1000,14 +1001,17 @@ class DownloadAndLoadToraModel:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
|
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
|
||||||
fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
|
|
||||||
|
|
||||||
|
fuser_model = "fuser.safetensors" if not "I2V" in model else "fuser_I2V.safetensors"
|
||||||
|
fuser_path = os.path.join(download_path, "fuser", fuser_model)
|
||||||
if not os.path.exists(fuser_path):
|
if not os.path.exists(fuser_path):
|
||||||
log.info(f"Downloading Fuser model to: {fuser_path}")
|
log.info(f"Downloading Fuser model to: {fuser_path}")
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=model,
|
repo_id=model,
|
||||||
allow_patterns=["*fuser.safetensors*"],
|
allow_patterns=[fuser_model],
|
||||||
local_dir=download_path,
|
local_dir=download_path,
|
||||||
local_dir_use_symlinks=False,
|
local_dir_use_symlinks=False,
|
||||||
)
|
)
|
||||||
@ -1029,14 +1033,15 @@ class DownloadAndLoadToraModel:
|
|||||||
param.data = param.data.to(torch.bfloat16).to(device)
|
param.data = param.data.to(torch.bfloat16).to(device)
|
||||||
del fuser_sd
|
del fuser_sd
|
||||||
|
|
||||||
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
|
traj_extractor_model = "traj_extractor.safetensors" if not "I2V" in model else "traj_extractor_I2V.safetensors"
|
||||||
|
traj_extractor_path = os.path.join(download_path, "traj_extractor", traj_extractor_model)
|
||||||
if not os.path.exists(traj_extractor_path):
|
if not os.path.exists(traj_extractor_path):
|
||||||
log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
|
log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id="kijai/CogVideoX-5b-Tora",
|
repo_id="kijai/CogVideoX-5b-Tora",
|
||||||
allow_patterns=["*traj_extractor.safetensors*"],
|
allow_patterns=[traj_extractor_model],
|
||||||
local_dir=download_path,
|
local_dir=download_path,
|
||||||
local_dir_use_symlinks=False,
|
local_dir_use_symlinks=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -814,7 +814,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(i, latents.detach()[-1], None, num_inference_steps)
|
callback(i, (latents - noise_pred * (t / 1000)).detach()[0], None, num_inference_steps)
|
||||||
else:
|
else:
|
||||||
comfy_pbar.update(1)
|
comfy_pbar.update(1)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-cogvideoxwrapper"
|
name = "comfyui-cogvideoxwrapper"
|
||||||
description = "Diffusers wrapper for CogVideoX -models: [a/https://github.com/THUDM/CogVideo](https://github.com/THUDM/CogVideo)"
|
description = "Diffusers wrapper for CogVideoX -models: [a/https://github.com/THUDM/CogVideo](https://github.com/THUDM/CogVideo)"
|
||||||
version = "1.5.0"
|
version = "1.5.1"
|
||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
dependencies = ["huggingface_hub", "diffusers>=0.31.0", "accelerate>=0.33.0"]
|
dependencies = ["huggingface_hub", "diffusers>=0.31.0", "accelerate>=0.33.0"]
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user