diff --git a/nodes.py b/nodes.py index 7bd8186..6c8d46b 100644 --- a/nodes.py +++ b/nodes.py @@ -16,7 +16,13 @@ class DownloadAndLoadCogVideoModel: def INPUT_TYPES(s): return { "required": { - + "model": ( + [ + "THUDM/CogVideoX-2b", + "THUDM/CogVideoX-5b", + ], + {"default": "THUDM/CogVideoX-2b"}, + ), }, "optional": { "precision": ( @@ -35,7 +41,7 @@ class DownloadAndLoadCogVideoModel: FUNCTION = "loadmodel" CATEGORY = "CogVideoWrapper" - def loadmodel(self, precision): + def loadmodel(self, model, precision): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() @@ -49,7 +55,7 @@ class DownloadAndLoadCogVideoModel: from huggingface_hub import snapshot_download snapshot_download( - repo_id="THUDM/CogVideoX-2b", + repo_id=model, ignore_patterns=["*text_encoder*"], local_dir=base_path, local_dir_use_symlinks=False,