diff --git a/README.md b/README.md index aa9e3a1c..c41ef77d 100644 --- a/README.md +++ b/README.md @@ -315,6 +315,8 @@ NODE_CLASS_MAPPINGS.update({ * When you create the `pip_overrides.json` file, it changes the installation of specific pip packages to installations defined by the user. * Please refer to the `pip_overrides.json.template` file. +* Use `aria2` as downloader + * [howto](docs/en/use_aria2.md) ## Scanner When you run the `scan.sh` script: diff --git a/docs/en/use_aria2.md b/docs/en/use_aria2.md new file mode 100644 index 00000000..10a7c6dd --- /dev/null +++ b/docs/en/use_aria2.md @@ -0,0 +1,40 @@ +# Use `aria2` as downloader + +Two environment variables are needed to use `aria2` as the downloader. + +```bash +export COMFYUI_MANAGER_ARIA2_SERVER=http://127.0.0.1:6800 +export COMFYUI_MANAGER_ARIA2_SECRET=__YOU_MUST_CHANGE_IT__ +``` + +An example `docker-compose.yml` + +```yaml +services: + + aria2: + container_name: aria2 + image: p3terx/aria2-pro + environment: + - PUID=1000 + - PGID=1000 + - UMASK_SET=022 + - RPC_SECRET=__YOU_MUST_CHANGE_IT__ + - RPC_PORT=5080 + - DISK_CACHE=64M + - IPV6_MODE=false + - UPDATE_TRACKERS=false + - CUSTOM_TRACKER_URL= + volumes: + - ./config:/config + - ./downloads:/downloads + - ~/ComfyUI/models:/models + - ~/ComfyUI/custom_nodes:/custom_nodes + ports: + - 6800:6800 + restart: unless-stopped + logging: + driver: json-file + options: + max-size: 1m +``` diff --git a/glob/manager_downloader.py b/glob/manager_downloader.py new file mode 100644 index 00000000..56bf8fa6 --- /dev/null +++ b/glob/manager_downloader.py @@ -0,0 +1,67 @@ +import os + +aria2 = os.getenv('COMFYUI_MANAGER_ARIA2_SERVER') +HF_ENDPOINT = os.getenv('HF_ENDPOINT') + +if aria2 is not None: + secret = os.getenv('COMFYUI_MANAGER_ARIA2_SECRET') + host, port = aria2.split(':') + import aria2p + + aria2 = aria2p.API(aria2p.Client(host=host, port=port, secret=secret)) + + +def download_url(model_url: str, model_dir: str, filename: str): + if aria2: + return aria2_download_url(model_url, model_dir, filename) + else: + from torchvision.datasets.utils import download_url as torchvision_download_url + + return torchvision_download_url(model_url, model_dir, filename) + + +def aria2_find_task(dir: str, filename: str): + target = os.path.join(dir, filename) + + downloads = aria2.get_downloads() + + for download in downloads: + for file in download.files: + if file.is_metadata: + continue + if str(file.path) == target: + return download + + +def aria2_download_url(model_url: str, model_dir: str, filename: str): + import manager_core as core + import tqdm + import time + + if model_dir.startswith(core.comfy_path): + model_dir = model_dir[len(core.comfy_path) :] + + if HF_ENDPOINT: + model_url = model_url.replace('https://huggingface.co', HF_ENDPOINT) + + download_dir = model_dir if model_dir.startswith('/') else os.path.join('/models', model_dir) + + download = aria2_find_task(download_dir, filename) + if download is None: + options = {'dir': download_dir, 'out': filename} + download = aria2.add(model_url, options)[0] + + if download.is_active: + with tqdm.tqdm( + total=download.total_length, + bar_format='{l_bar}{bar}{r_bar}', + desc=filename, + unit='B', + unit_scale=True, + ) as progress_bar: + while download.is_active: + if progress_bar.total == 0 and download.total_length != 0: + progress_bar.reset(download.total_length) + progress_bar.update(download.completed_length - progress_bar.n) + time.sleep(1) + download.update() diff --git a/glob/manager_server.py b/glob/manager_server.py index 0c9c67dd..edbe287c 100644 --- a/glob/manager_server.py +++ b/glob/manager_server.py @@ -106,7 +106,7 @@ core.manager_funcs = ManagerFuncsInComfyUI() sys.path.append('../..') -from torchvision.datasets.utils import download_url +from manager_downloader import download_url core.comfy_path = os.path.dirname(folder_paths.__file__) core.js_path = os.path.join(core.comfy_path, "web", "extensions")