diff --git a/nodes/nodes.py b/nodes/nodes.py index 78e7c13..3afb26c 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1332,6 +1332,12 @@ https://huggingface.co/roborovski/superprompt-v1 from transformers import T5Tokenizer, T5ForConditionalGeneration checkpoint_path = os.path.join(script_directory, "models","superprompt-v1") + if not os.path.exists(checkpoint_path): + print(f"Downloading model to: {checkpoint_path}") + from huggingface_hub import snapshot_download + snapshot_download(repo_id="roborovski/superprompt-v1", + local_dir=checkpoint_path, + local_dir_use_symlinks=False) tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small", legacy=False) model = T5ForConditionalGeneration.from_pretrained(checkpoint_path, device_map=device)