mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2026-06-22 07:36:58 +08:00
Compare commits
No commits in common. "master" and "latest" have entirely different histories.
3
.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat
Executable file
3
.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||||
|
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
|
||||||
|
pause
|
||||||
@ -1,2 +1,2 @@
|
|||||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention
|
||||||
pause
|
pause
|
||||||
@ -1,9 +1,6 @@
|
|||||||
import pygit2
|
import pygit2
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import sys
|
import sys
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import filecmp
|
|
||||||
|
|
||||||
def pull(repo, remote_name='origin', branch='master'):
|
def pull(repo, remote_name='origin', branch='master'):
|
||||||
for remote in repo.remotes:
|
for remote in repo.remotes:
|
||||||
@ -28,137 +25,41 @@ def pull(repo, remote_name='origin', branch='master'):
|
|||||||
|
|
||||||
if repo.index.conflicts is not None:
|
if repo.index.conflicts is not None:
|
||||||
for conflict in repo.index.conflicts:
|
for conflict in repo.index.conflicts:
|
||||||
print('Conflicts found in:', conflict[0].path) # noqa: T201
|
print('Conflicts found in:', conflict[0].path)
|
||||||
raise AssertionError('Conflicts, ahhhhh!!')
|
raise AssertionError('Conflicts, ahhhhh!!')
|
||||||
|
|
||||||
user = repo.default_signature
|
user = repo.default_signature
|
||||||
tree = repo.index.write_tree()
|
tree = repo.index.write_tree()
|
||||||
repo.create_commit('HEAD',
|
commit = repo.create_commit('HEAD',
|
||||||
user,
|
user,
|
||||||
user,
|
user,
|
||||||
'Merge!',
|
'Merge!',
|
||||||
tree,
|
tree,
|
||||||
[repo.head.target, remote_master_id])
|
[repo.head.target, remote_master_id])
|
||||||
# We need to do this or git CLI will think we are still merging.
|
# We need to do this or git CLI will think we are still merging.
|
||||||
repo.state_cleanup()
|
repo.state_cleanup()
|
||||||
else:
|
else:
|
||||||
raise AssertionError('Unknown merge analysis result')
|
raise AssertionError('Unknown merge analysis result')
|
||||||
|
|
||||||
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
|
|
||||||
repo_path = str(sys.argv[1])
|
repo = pygit2.Repository(str(sys.argv[1]))
|
||||||
repo = pygit2.Repository(repo_path)
|
|
||||||
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
||||||
try:
|
try:
|
||||||
print("stashing current changes") # noqa: T201
|
print("stashing current changes")
|
||||||
repo.stash(ident)
|
repo.stash(ident)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print("nothing to stash") # noqa: T201
|
print("nothing to stash")
|
||||||
except:
|
|
||||||
print("Could not stash, cleaning index and trying again.") # noqa: T201
|
|
||||||
repo.state_cleanup()
|
|
||||||
repo.index.read_tree(repo.head.peel().tree)
|
|
||||||
repo.index.write()
|
|
||||||
try:
|
|
||||||
repo.stash(ident)
|
|
||||||
except KeyError:
|
|
||||||
print("nothing to stash.") # noqa: T201
|
|
||||||
|
|
||||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
print("creating backup branch: {}".format(backup_branch_name))
|
||||||
try:
|
repo.branches.local.create(backup_branch_name, repo.head.peel())
|
||||||
repo.branches.local.create(backup_branch_name, repo.head.peel())
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
print("checking out master branch") # noqa: T201
|
print("checking out master branch")
|
||||||
branch = repo.lookup_branch('master')
|
branch = repo.lookup_branch('master')
|
||||||
if branch is None:
|
ref = repo.lookup_reference(branch.name)
|
||||||
try:
|
repo.checkout(ref)
|
||||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
|
||||||
except:
|
|
||||||
print("fetching.") # noqa: T201
|
|
||||||
for remote in repo.remotes:
|
|
||||||
if remote.name == "origin":
|
|
||||||
remote.fetch()
|
|
||||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
|
||||||
repo.checkout(ref)
|
|
||||||
branch = repo.lookup_branch('master')
|
|
||||||
if branch is None:
|
|
||||||
repo.create_branch('master', repo.get(ref.target))
|
|
||||||
else:
|
|
||||||
ref = repo.lookup_reference(branch.name)
|
|
||||||
repo.checkout(ref)
|
|
||||||
|
|
||||||
print("pulling latest changes") # noqa: T201
|
print("pulling latest changes")
|
||||||
pull(repo)
|
pull(repo)
|
||||||
|
|
||||||
if "--stable" in sys.argv:
|
print("Done!")
|
||||||
def latest_tag(repo):
|
|
||||||
versions = []
|
|
||||||
for k in repo.references:
|
|
||||||
try:
|
|
||||||
prefix = "refs/tags/v"
|
|
||||||
if k.startswith(prefix):
|
|
||||||
version = list(map(int, k[len(prefix):].split(".")))
|
|
||||||
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
versions.sort()
|
|
||||||
if len(versions) > 0:
|
|
||||||
return versions[-1][1]
|
|
||||||
return None
|
|
||||||
latest_tag = latest_tag(repo)
|
|
||||||
if latest_tag is not None:
|
|
||||||
repo.checkout(latest_tag)
|
|
||||||
|
|
||||||
print("Done!") # noqa: T201
|
|
||||||
|
|
||||||
self_update = True
|
|
||||||
if len(sys.argv) > 2:
|
|
||||||
self_update = '--skip_self_update' not in sys.argv
|
|
||||||
|
|
||||||
update_py_path = os.path.realpath(__file__)
|
|
||||||
repo_update_py_path = os.path.join(repo_path, ".ci/update_windows/update.py")
|
|
||||||
|
|
||||||
cur_path = os.path.dirname(update_py_path)
|
|
||||||
|
|
||||||
|
|
||||||
req_path = os.path.join(cur_path, "current_requirements.txt")
|
|
||||||
repo_req_path = os.path.join(repo_path, "requirements.txt")
|
|
||||||
|
|
||||||
|
|
||||||
def files_equal(file1, file2):
|
|
||||||
try:
|
|
||||||
return filecmp.cmp(file1, file2, shallow=False)
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def file_size(f):
|
|
||||||
try:
|
|
||||||
return os.path.getsize(f)
|
|
||||||
except:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
if self_update and not files_equal(update_py_path, repo_update_py_path) and file_size(repo_update_py_path) > 10:
|
|
||||||
shutil.copy(repo_update_py_path, os.path.join(cur_path, "update_new.py"))
|
|
||||||
exit()
|
|
||||||
|
|
||||||
if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
|
|
||||||
import subprocess
|
|
||||||
try:
|
|
||||||
subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_req_path])
|
|
||||||
shutil.copy(repo_req_path, req_path)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
|
|
||||||
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not file_size(stable_update_script_to) > 10:
|
|
||||||
shutil.copy(stable_update_script, stable_update_script_to)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,2 @@
|
|||||||
@echo off
|
|
||||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||||
if exist update_new.py (
|
pause
|
||||||
move /y update_new.py update.py
|
|
||||||
echo Running updater again since it got updated.
|
|
||||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update
|
|
||||||
)
|
|
||||||
if "%~1"=="" pause
|
|
||||||
|
|||||||
3
.ci/update_windows/update_comfyui_and_python_dependencies.bat
Executable file
3
.ci/update_windows/update_comfyui_and_python_dependencies.bat
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||||
|
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 xformers -r ../ComfyUI/requirements.txt pygit2
|
||||||
|
pause
|
||||||
@ -1,8 +0,0 @@
|
|||||||
@echo off
|
|
||||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
|
|
||||||
if exist update_new.py (
|
|
||||||
move /y update_new.py update.py
|
|
||||||
echo Running updater again since it got updated.
|
|
||||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
|
|
||||||
)
|
|
||||||
if "%~1"=="" pause
|
|
||||||
11
.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat
Executable file
11
.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat
Executable file
@ -0,0 +1,11 @@
|
|||||||
|
@echo off
|
||||||
|
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||||
|
echo
|
||||||
|
echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
|
||||||
|
echo You should not be running this anyways unless you really have to
|
||||||
|
echo
|
||||||
|
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||||
|
echo
|
||||||
|
pause
|
||||||
|
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 xformers -r ../ComfyUI/requirements.txt pygit2
|
||||||
|
pause
|
||||||
@ -1,28 +0,0 @@
|
|||||||
As of the time of writing this you need this driver for best results:
|
|
||||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
|
||||||
|
|
||||||
HOW TO RUN:
|
|
||||||
|
|
||||||
If you have a AMD gpu:
|
|
||||||
|
|
||||||
run_amd_gpu.bat
|
|
||||||
|
|
||||||
If you have memory issues you can try disabling the smart memory management by running comfyui with:
|
|
||||||
|
|
||||||
run_amd_gpu_disable_smart_memory.bat
|
|
||||||
|
|
||||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
|
||||||
|
|
||||||
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
|
|
||||||
|
|
||||||
|
|
||||||
RECOMMENDED WAY TO UPDATE:
|
|
||||||
To update the ComfyUI code: update\update_comfyui.bat
|
|
||||||
|
|
||||||
|
|
||||||
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
|
|
||||||
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
|
|
||||||
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -4,9 +4,6 @@ if you have a NVIDIA gpu:
|
|||||||
|
|
||||||
run_nvidia_gpu.bat
|
run_nvidia_gpu.bat
|
||||||
|
|
||||||
if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality):
|
|
||||||
|
|
||||||
run_nvidia_gpu_fast_fp16_accumulation.bat
|
|
||||||
|
|
||||||
|
|
||||||
To run it in slow CPU mode:
|
To run it in slow CPU mode:
|
||||||
@ -17,7 +14,7 @@ run_cpu.bat
|
|||||||
|
|
||||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||||
|
|
||||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
|
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
|
||||||
|
|
||||||
|
|
||||||
RECOMMENDED WAY TO UPDATE:
|
RECOMMENDED WAY TO UPDATE:
|
||||||
@ -1,2 +0,0 @@
|
|||||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
|
|
||||||
pause
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
|
||||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
|
||||||
pause
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
|
||||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
|
||||||
pause
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
|
||||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
|
||||||
pause
|
|
||||||
3
.gitattributes
vendored
3
.gitattributes
vendored
@ -1,3 +0,0 @@
|
|||||||
/web/assets/** linguist-generated
|
|
||||||
/web/** linguist-vendored
|
|
||||||
comfy_api_nodes/apis/__init__.py linguist-generated
|
|
||||||
58
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
58
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -1,58 +0,0 @@
|
|||||||
name: Bug Report
|
|
||||||
description: "Something is broken inside of ComfyUI. (Do not use this if you're just having issues and need help, or if the issue relates to a custom node)"
|
|
||||||
labels: ["Potential Bug"]
|
|
||||||
body:
|
|
||||||
- type: markdown
|
|
||||||
attributes:
|
|
||||||
value: |
|
|
||||||
Before submitting a **Bug Report**, please ensure the following:
|
|
||||||
|
|
||||||
- **1:** You are running the latest version of ComfyUI.
|
|
||||||
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
|
|
||||||
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
|
||||||
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
|
|
||||||
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
|
||||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
|
||||||
|
|
||||||
## Very Important
|
|
||||||
|
|
||||||
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
|
||||||
- type: checkboxes
|
|
||||||
id: custom-nodes-test
|
|
||||||
attributes:
|
|
||||||
label: Custom Node Testing
|
|
||||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
|
||||||
options:
|
|
||||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
|
||||||
required: false
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Expected Behavior
|
|
||||||
description: "What you expected to happen."
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Actual Behavior
|
|
||||||
description: "What actually happened. Please include a screenshot of the issue if possible."
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Steps to Reproduce
|
|
||||||
description: "Describe how to reproduce the issue. Please be sure to attach a workflow JSON or PNG, ideally one that doesn't require custom nodes to test. If the bug open happens when certain custom nodes are used, most likely that custom node is what has the bug rather than ComfyUI, in which case it should be reported to the node's author."
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Debug Logs
|
|
||||||
description: "Please copy the output from your terminal logs here."
|
|
||||||
render: powershell
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Other
|
|
||||||
description: "Any other additional information you think might be helpful."
|
|
||||||
validations:
|
|
||||||
required: false
|
|
||||||
11
.github/ISSUE_TEMPLATE/config.yml
vendored
11
.github/ISSUE_TEMPLATE/config.yml
vendored
@ -1,11 +0,0 @@
|
|||||||
blank_issues_enabled: true
|
|
||||||
contact_links:
|
|
||||||
- name: ComfyUI Frontend Issues
|
|
||||||
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
|
|
||||||
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
|
|
||||||
- name: ComfyUI Matrix Space
|
|
||||||
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
|
||||||
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
|
||||||
- name: Comfy Org Discord
|
|
||||||
url: https://discord.gg/comfyorg
|
|
||||||
about: The Comfy Org Discord is available for support and general discussion related to ComfyUI.
|
|
||||||
32
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
32
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
@ -1,32 +0,0 @@
|
|||||||
name: Feature Request
|
|
||||||
description: "You have an idea for something new you would like to see added to ComfyUI's core."
|
|
||||||
labels: [ "Feature" ]
|
|
||||||
body:
|
|
||||||
- type: markdown
|
|
||||||
attributes:
|
|
||||||
value: |
|
|
||||||
Before submitting a **Feature Request**, please ensure the following:
|
|
||||||
|
|
||||||
**1:** You are running the latest version of ComfyUI.
|
|
||||||
**2:** You have looked to make sure there is not already a feature that does what you need, and there is not already a Feature Request listed for the same idea.
|
|
||||||
**3:** This is something that makes sense to add to ComfyUI Core, and wouldn't make more sense as a custom node.
|
|
||||||
|
|
||||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Feature Idea
|
|
||||||
description: "Describe the feature you want to see."
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Existing Solutions
|
|
||||||
description: "Please search through available custom nodes / extensions to see if there are existing custom solutions for this. If so, please link the options you found here as a reference."
|
|
||||||
validations:
|
|
||||||
required: false
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Other
|
|
||||||
description: "Any other additional information you think might be helpful."
|
|
||||||
validations:
|
|
||||||
required: false
|
|
||||||
40
.github/ISSUE_TEMPLATE/user-support.yml
vendored
40
.github/ISSUE_TEMPLATE/user-support.yml
vendored
@ -1,40 +0,0 @@
|
|||||||
name: User Support
|
|
||||||
description: "Use this if you need help with something, or you're experiencing an issue."
|
|
||||||
labels: [ "User Support" ]
|
|
||||||
body:
|
|
||||||
- type: markdown
|
|
||||||
attributes:
|
|
||||||
value: |
|
|
||||||
Before submitting a **User Report** issue, please ensure the following:
|
|
||||||
|
|
||||||
**1:** You are running the latest version of ComfyUI.
|
|
||||||
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
|
|
||||||
|
|
||||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
|
||||||
- type: checkboxes
|
|
||||||
id: custom-nodes-test
|
|
||||||
attributes:
|
|
||||||
label: Custom Node Testing
|
|
||||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
|
||||||
options:
|
|
||||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
|
||||||
required: false
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Your question
|
|
||||||
description: "Post your question here. Please be as detailed as possible."
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Logs
|
|
||||||
description: "If your question relates to an issue you're experiencing, please go to `Server` -> `Logs` -> potentially set `View Type` to `Debug` as well, then copypaste all the text into here."
|
|
||||||
render: powershell
|
|
||||||
validations:
|
|
||||||
required: false
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: Other
|
|
||||||
description: "Any other additional information you think might be helpful."
|
|
||||||
validations:
|
|
||||||
required: false
|
|
||||||
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
@ -1,21 +0,0 @@
|
|||||||
<!-- API_NODE_PR_CHECKLIST: do not remove -->
|
|
||||||
|
|
||||||
## API Node PR Checklist
|
|
||||||
|
|
||||||
### Scope
|
|
||||||
- [ ] **Is API Node Change**
|
|
||||||
|
|
||||||
### Pricing & Billing
|
|
||||||
- [ ] **Need pricing update**
|
|
||||||
- [ ] **No pricing update**
|
|
||||||
|
|
||||||
If **Need pricing update**:
|
|
||||||
- [ ] Metronome rate cards updated
|
|
||||||
- [ ] Auto‑billing tests updated and passing
|
|
||||||
|
|
||||||
### QA
|
|
||||||
- [ ] **QA done**
|
|
||||||
- [ ] **QA not required**
|
|
||||||
|
|
||||||
### Comms
|
|
||||||
- [ ] Informed **Kosinkadink**
|
|
||||||
58
.github/workflows/api-node-template.yml
vendored
58
.github/workflows/api-node-template.yml
vendored
@ -1,58 +0,0 @@
|
|||||||
name: Append API Node PR template
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request_target:
|
|
||||||
types: [opened, reopened, synchronize, ready_for_review]
|
|
||||||
paths:
|
|
||||||
- 'comfy_api_nodes/**' # only run if these files changed
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
pull-requests: write
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
inject:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Ensure template exists and append to PR body
|
|
||||||
uses: actions/github-script@v7
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const { owner, repo } = context.repo;
|
|
||||||
const number = context.payload.pull_request.number;
|
|
||||||
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
|
|
||||||
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
|
|
||||||
|
|
||||||
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
|
|
||||||
|
|
||||||
let templateText;
|
|
||||||
try {
|
|
||||||
const res = await github.rest.repos.getContent({
|
|
||||||
owner,
|
|
||||||
repo,
|
|
||||||
path: templatePath,
|
|
||||||
ref: pr.base.ref
|
|
||||||
});
|
|
||||||
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
|
|
||||||
templateText = buf.toString('utf8');
|
|
||||||
} catch (e) {
|
|
||||||
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enforce the presence of the marker inside the template (for idempotence)
|
|
||||||
if (!templateText.includes(marker)) {
|
|
||||||
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the PR already contains the marker, do not append again.
|
|
||||||
const body = pr.body || '';
|
|
||||||
if (body.includes(marker)) {
|
|
||||||
core.info('Template already present in PR body; nothing to inject.');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
|
|
||||||
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
|
|
||||||
core.notice('API Node template appended to PR description.');
|
|
||||||
40
.github/workflows/check-line-endings.yml
vendored
40
.github/workflows/check-line-endings.yml
vendored
@ -1,40 +0,0 @@
|
|||||||
name: Check for Windows Line Endings
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
branches: ['*'] # Trigger on all pull requests to any branch
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
check-line-endings:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0 # Fetch all history to compare changes
|
|
||||||
|
|
||||||
- name: Check for Windows line endings (CRLF)
|
|
||||||
run: |
|
|
||||||
# Get the list of changed files in the PR
|
|
||||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
|
||||||
|
|
||||||
# Flag to track if CRLF is found
|
|
||||||
CRLF_FOUND=false
|
|
||||||
|
|
||||||
# Loop through each changed file
|
|
||||||
for FILE in $CHANGED_FILES; do
|
|
||||||
# Check if the file exists and is a text file
|
|
||||||
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
|
|
||||||
# Check for CRLF line endings
|
|
||||||
if grep -UP '\r$' "$FILE"; then
|
|
||||||
echo "Error: Windows line endings (CRLF) detected in $FILE"
|
|
||||||
CRLF_FOUND=true
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# Exit with error if CRLF was found
|
|
||||||
if [ "$CRLF_FOUND" = true ]; then
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
53
.github/workflows/pullrequest-ci-run.yml
vendored
53
.github/workflows/pullrequest-ci-run.yml
vendored
@ -1,53 +0,0 @@
|
|||||||
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
|
|
||||||
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
|
|
||||||
name: Pull Request CI Workflow Runs
|
|
||||||
on:
|
|
||||||
pull_request_target:
|
|
||||||
types: [labeled]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
pr-test-stable:
|
|
||||||
if: ${{ github.event.label.name == 'Run-CI-Test' }}
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [macos, linux, windows]
|
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
|
||||||
cuda_version: ["12.1"]
|
|
||||||
torch_version: ["stable"]
|
|
||||||
include:
|
|
||||||
- os: macos
|
|
||||||
runner_label: [self-hosted, macOS]
|
|
||||||
flags: "--use-pytorch-cross-attention"
|
|
||||||
- os: linux
|
|
||||||
runner_label: [self-hosted, Linux]
|
|
||||||
flags: ""
|
|
||||||
- os: windows
|
|
||||||
runner_label: [self-hosted, Windows]
|
|
||||||
flags: ""
|
|
||||||
runs-on: ${{ matrix.runner_label }}
|
|
||||||
steps:
|
|
||||||
- name: Test Workflows
|
|
||||||
uses: comfy-org/comfy-action@main
|
|
||||||
with:
|
|
||||||
os: ${{ matrix.os }}
|
|
||||||
python_version: ${{ matrix.python_version }}
|
|
||||||
torch_version: ${{ matrix.torch_version }}
|
|
||||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
|
||||||
comfyui_flags: ${{ matrix.flags }}
|
|
||||||
use_prior_commit: 'true'
|
|
||||||
comment:
|
|
||||||
if: ${{ github.event.label.name == 'Run-CI-Test' }}
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
pull-requests: write
|
|
||||||
steps:
|
|
||||||
- uses: actions/github-script@v6
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
github.rest.issues.createComment({
|
|
||||||
issue_number: context.issue.number,
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
|
|
||||||
})
|
|
||||||
78
.github/workflows/release-stable-all.yml
vendored
78
.github/workflows/release-stable-all.yml
vendored
@ -1,78 +0,0 @@
|
|||||||
name: "Release Stable All Portable Versions"
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
git_tag:
|
|
||||||
description: 'Git tag'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
release_nvidia_default:
|
|
||||||
permissions:
|
|
||||||
contents: "write"
|
|
||||||
packages: "write"
|
|
||||||
pull-requests: "read"
|
|
||||||
name: "Release NVIDIA Default (cu130)"
|
|
||||||
uses: ./.github/workflows/stable-release.yml
|
|
||||||
with:
|
|
||||||
git_tag: ${{ inputs.git_tag }}
|
|
||||||
cache_tag: "cu130"
|
|
||||||
python_minor: "13"
|
|
||||||
python_patch: "9"
|
|
||||||
rel_name: "nvidia"
|
|
||||||
rel_extra_name: ""
|
|
||||||
test_release: true
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
release_nvidia_cu128:
|
|
||||||
permissions:
|
|
||||||
contents: "write"
|
|
||||||
packages: "write"
|
|
||||||
pull-requests: "read"
|
|
||||||
name: "Release NVIDIA cu128"
|
|
||||||
uses: ./.github/workflows/stable-release.yml
|
|
||||||
with:
|
|
||||||
git_tag: ${{ inputs.git_tag }}
|
|
||||||
cache_tag: "cu128"
|
|
||||||
python_minor: "12"
|
|
||||||
python_patch: "10"
|
|
||||||
rel_name: "nvidia"
|
|
||||||
rel_extra_name: "_cu128"
|
|
||||||
test_release: true
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
release_nvidia_cu126:
|
|
||||||
permissions:
|
|
||||||
contents: "write"
|
|
||||||
packages: "write"
|
|
||||||
pull-requests: "read"
|
|
||||||
name: "Release NVIDIA cu126"
|
|
||||||
uses: ./.github/workflows/stable-release.yml
|
|
||||||
with:
|
|
||||||
git_tag: ${{ inputs.git_tag }}
|
|
||||||
cache_tag: "cu126"
|
|
||||||
python_minor: "12"
|
|
||||||
python_patch: "10"
|
|
||||||
rel_name: "nvidia"
|
|
||||||
rel_extra_name: "_cu126"
|
|
||||||
test_release: true
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
release_amd_rocm:
|
|
||||||
permissions:
|
|
||||||
contents: "write"
|
|
||||||
packages: "write"
|
|
||||||
pull-requests: "read"
|
|
||||||
name: "Release AMD ROCm 7.1.1"
|
|
||||||
uses: ./.github/workflows/stable-release.yml
|
|
||||||
with:
|
|
||||||
git_tag: ${{ inputs.git_tag }}
|
|
||||||
cache_tag: "rocm711"
|
|
||||||
python_minor: "12"
|
|
||||||
python_patch: "10"
|
|
||||||
rel_name: "amd"
|
|
||||||
rel_extra_name: ""
|
|
||||||
test_release: false
|
|
||||||
secrets: inherit
|
|
||||||
108
.github/workflows/release-webhook.yml
vendored
108
.github/workflows/release-webhook.yml
vendored
@ -1,108 +0,0 @@
|
|||||||
name: Release Webhook
|
|
||||||
|
|
||||||
on:
|
|
||||||
release:
|
|
||||||
types: [published]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
send-webhook:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Send release webhook
|
|
||||||
env:
|
|
||||||
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
|
|
||||||
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
|
|
||||||
run: |
|
|
||||||
# Generate UUID for delivery ID
|
|
||||||
DELIVERY_ID=$(uuidgen)
|
|
||||||
HOOK_ID="release-webhook-$(date +%s)"
|
|
||||||
|
|
||||||
# Create webhook payload matching GitHub release webhook format
|
|
||||||
PAYLOAD=$(cat <<EOF
|
|
||||||
{
|
|
||||||
"action": "published",
|
|
||||||
"release": {
|
|
||||||
"id": ${{ github.event.release.id }},
|
|
||||||
"node_id": "${{ github.event.release.node_id }}",
|
|
||||||
"url": "${{ github.event.release.url }}",
|
|
||||||
"html_url": "${{ github.event.release.html_url }}",
|
|
||||||
"assets_url": "${{ github.event.release.assets_url }}",
|
|
||||||
"upload_url": "${{ github.event.release.upload_url }}",
|
|
||||||
"tag_name": "${{ github.event.release.tag_name }}",
|
|
||||||
"target_commitish": "${{ github.event.release.target_commitish }}",
|
|
||||||
"name": ${{ toJSON(github.event.release.name) }},
|
|
||||||
"body": ${{ toJSON(github.event.release.body) }},
|
|
||||||
"draft": ${{ github.event.release.draft }},
|
|
||||||
"prerelease": ${{ github.event.release.prerelease }},
|
|
||||||
"created_at": "${{ github.event.release.created_at }}",
|
|
||||||
"published_at": "${{ github.event.release.published_at }}",
|
|
||||||
"author": {
|
|
||||||
"login": "${{ github.event.release.author.login }}",
|
|
||||||
"id": ${{ github.event.release.author.id }},
|
|
||||||
"node_id": "${{ github.event.release.author.node_id }}",
|
|
||||||
"avatar_url": "${{ github.event.release.author.avatar_url }}",
|
|
||||||
"url": "${{ github.event.release.author.url }}",
|
|
||||||
"html_url": "${{ github.event.release.author.html_url }}",
|
|
||||||
"type": "${{ github.event.release.author.type }}",
|
|
||||||
"site_admin": ${{ github.event.release.author.site_admin }}
|
|
||||||
},
|
|
||||||
"tarball_url": "${{ github.event.release.tarball_url }}",
|
|
||||||
"zipball_url": "${{ github.event.release.zipball_url }}",
|
|
||||||
"assets": ${{ toJSON(github.event.release.assets) }}
|
|
||||||
},
|
|
||||||
"repository": {
|
|
||||||
"id": ${{ github.event.repository.id }},
|
|
||||||
"node_id": "${{ github.event.repository.node_id }}",
|
|
||||||
"name": "${{ github.event.repository.name }}",
|
|
||||||
"full_name": "${{ github.event.repository.full_name }}",
|
|
||||||
"private": ${{ github.event.repository.private }},
|
|
||||||
"owner": {
|
|
||||||
"login": "${{ github.event.repository.owner.login }}",
|
|
||||||
"id": ${{ github.event.repository.owner.id }},
|
|
||||||
"node_id": "${{ github.event.repository.owner.node_id }}",
|
|
||||||
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
|
|
||||||
"url": "${{ github.event.repository.owner.url }}",
|
|
||||||
"html_url": "${{ github.event.repository.owner.html_url }}",
|
|
||||||
"type": "${{ github.event.repository.owner.type }}",
|
|
||||||
"site_admin": ${{ github.event.repository.owner.site_admin }}
|
|
||||||
},
|
|
||||||
"html_url": "${{ github.event.repository.html_url }}",
|
|
||||||
"clone_url": "${{ github.event.repository.clone_url }}",
|
|
||||||
"git_url": "${{ github.event.repository.git_url }}",
|
|
||||||
"ssh_url": "${{ github.event.repository.ssh_url }}",
|
|
||||||
"url": "${{ github.event.repository.url }}",
|
|
||||||
"created_at": "${{ github.event.repository.created_at }}",
|
|
||||||
"updated_at": "${{ github.event.repository.updated_at }}",
|
|
||||||
"pushed_at": "${{ github.event.repository.pushed_at }}",
|
|
||||||
"default_branch": "${{ github.event.repository.default_branch }}",
|
|
||||||
"fork": ${{ github.event.repository.fork }}
|
|
||||||
},
|
|
||||||
"sender": {
|
|
||||||
"login": "${{ github.event.sender.login }}",
|
|
||||||
"id": ${{ github.event.sender.id }},
|
|
||||||
"node_id": "${{ github.event.sender.node_id }}",
|
|
||||||
"avatar_url": "${{ github.event.sender.avatar_url }}",
|
|
||||||
"url": "${{ github.event.sender.url }}",
|
|
||||||
"html_url": "${{ github.event.sender.html_url }}",
|
|
||||||
"type": "${{ github.event.sender.type }}",
|
|
||||||
"site_admin": ${{ github.event.sender.site_admin }}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
EOF
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate HMAC-SHA256 signature
|
|
||||||
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
|
|
||||||
|
|
||||||
# Send webhook with required headers
|
|
||||||
curl -X POST "$WEBHOOK_URL" \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-H "X-GitHub-Event: release" \
|
|
||||||
-H "X-GitHub-Delivery: $DELIVERY_ID" \
|
|
||||||
-H "X-GitHub-Hook-ID: $HOOK_ID" \
|
|
||||||
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
|
|
||||||
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
|
|
||||||
-d "$PAYLOAD" \
|
|
||||||
--fail --silent --show-error
|
|
||||||
|
|
||||||
echo "✅ Release webhook sent successfully"
|
|
||||||
48
.github/workflows/ruff.yml
vendored
48
.github/workflows/ruff.yml
vendored
@ -1,48 +0,0 @@
|
|||||||
name: Python Linting
|
|
||||||
|
|
||||||
on: [push, pull_request]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
ruff:
|
|
||||||
name: Run Ruff
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v2
|
|
||||||
with:
|
|
||||||
python-version: 3.x
|
|
||||||
|
|
||||||
- name: Install Ruff
|
|
||||||
run: pip install ruff
|
|
||||||
|
|
||||||
- name: Run Ruff
|
|
||||||
run: ruff check .
|
|
||||||
|
|
||||||
pylint:
|
|
||||||
name: Run Pylint
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: '3.12'
|
|
||||||
|
|
||||||
- name: Install requirements
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
pip install -r requirements.txt
|
|
||||||
|
|
||||||
- name: Install Pylint
|
|
||||||
run: pip install pylint
|
|
||||||
|
|
||||||
- name: Run Pylint
|
|
||||||
run: pylint comfy_api_nodes
|
|
||||||
170
.github/workflows/stable-release.yml
vendored
170
.github/workflows/stable-release.yml
vendored
@ -1,170 +0,0 @@
|
|||||||
|
|
||||||
name: "Release Stable Version"
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_call:
|
|
||||||
inputs:
|
|
||||||
git_tag:
|
|
||||||
description: 'Git tag'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
cache_tag:
|
|
||||||
description: 'Cached dependencies tag'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "cu129"
|
|
||||||
python_minor:
|
|
||||||
description: 'Python minor version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "13"
|
|
||||||
python_patch:
|
|
||||||
description: 'Python patch version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "6"
|
|
||||||
rel_name:
|
|
||||||
description: 'Release name'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "nvidia"
|
|
||||||
rel_extra_name:
|
|
||||||
description: 'Release extra name'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
default: ""
|
|
||||||
test_release:
|
|
||||||
description: 'Test Release'
|
|
||||||
required: true
|
|
||||||
type: boolean
|
|
||||||
default: true
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
git_tag:
|
|
||||||
description: 'Git tag'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
cache_tag:
|
|
||||||
description: 'Cached dependencies tag'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "cu129"
|
|
||||||
python_minor:
|
|
||||||
description: 'Python minor version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "13"
|
|
||||||
python_patch:
|
|
||||||
description: 'Python patch version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "6"
|
|
||||||
rel_name:
|
|
||||||
description: 'Release name'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "nvidia"
|
|
||||||
rel_extra_name:
|
|
||||||
description: 'Release extra name'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
default: ""
|
|
||||||
test_release:
|
|
||||||
description: 'Test Release'
|
|
||||||
required: true
|
|
||||||
type: boolean
|
|
||||||
default: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
package_comfy_windows:
|
|
||||||
permissions:
|
|
||||||
contents: "write"
|
|
||||||
packages: "write"
|
|
||||||
pull-requests: "read"
|
|
||||||
runs-on: windows-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
ref: ${{ inputs.git_tag }}
|
|
||||||
fetch-depth: 150
|
|
||||||
persist-credentials: false
|
|
||||||
- uses: actions/cache/restore@v4
|
|
||||||
id: cache
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
${{ inputs.cache_tag }}_python_deps.tar
|
|
||||||
update_comfyui_and_python_dependencies.bat
|
|
||||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
|
||||||
- shell: bash
|
|
||||||
run: |
|
|
||||||
mv ${{ inputs.cache_tag }}_python_deps.tar ../
|
|
||||||
mv update_comfyui_and_python_dependencies.bat ../
|
|
||||||
cd ..
|
|
||||||
tar xf ${{ inputs.cache_tag }}_python_deps.tar
|
|
||||||
pwd
|
|
||||||
ls
|
|
||||||
|
|
||||||
- shell: bash
|
|
||||||
run: |
|
|
||||||
cd ..
|
|
||||||
cp -r ComfyUI ComfyUI_copy
|
|
||||||
curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
|
|
||||||
unzip python_embeded.zip -d python_embeded
|
|
||||||
cd python_embeded
|
|
||||||
echo ${{ env.MINOR_VERSION }}
|
|
||||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
|
||||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
|
||||||
./python.exe get-pip.py
|
|
||||||
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
|
||||||
|
|
||||||
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
|
||||||
./python.exe -s -m pip install -r requirements_comfyui.txt
|
|
||||||
rm requirements_comfyui.txt
|
|
||||||
|
|
||||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
|
||||||
|
|
||||||
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
|
|
||||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
|
||||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
|
||||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
|
||||||
fi
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
|
||||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
|
||||||
|
|
||||||
mkdir ComfyUI_windows_portable
|
|
||||||
mv python_embeded ComfyUI_windows_portable
|
|
||||||
mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
|
|
||||||
|
|
||||||
cd ComfyUI_windows_portable
|
|
||||||
|
|
||||||
mkdir update
|
|
||||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
|
||||||
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
|
|
||||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
|
||||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
|
||||||
|
|
||||||
- shell: bash
|
|
||||||
if: ${{ inputs.test_release }}
|
|
||||||
run: |
|
|
||||||
cd ..
|
|
||||||
cd ComfyUI_windows_portable
|
|
||||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
|
||||||
|
|
||||||
python_embeded/python.exe -s ./update/update.py ComfyUI/
|
|
||||||
|
|
||||||
ls
|
|
||||||
|
|
||||||
- name: Upload binaries to release
|
|
||||||
uses: softprops/action-gh-release@v2
|
|
||||||
with:
|
|
||||||
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
|
||||||
tag_name: ${{ inputs.git_tag }}
|
|
||||||
draft: true
|
|
||||||
overwrite_files: true
|
|
||||||
21
.github/workflows/stale-issues.yml
vendored
21
.github/workflows/stale-issues.yml
vendored
@ -1,21 +0,0 @@
|
|||||||
name: 'Close stale issues'
|
|
||||||
on:
|
|
||||||
schedule:
|
|
||||||
# Run daily at 430 am PT
|
|
||||||
- cron: '30 11 * * *'
|
|
||||||
permissions:
|
|
||||||
issues: write
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
stale:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/stale@v9
|
|
||||||
with:
|
|
||||||
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
|
|
||||||
days-before-stale: 30
|
|
||||||
days-before-close: 7
|
|
||||||
stale-issue-label: 'Stale'
|
|
||||||
only-labels: 'User Support'
|
|
||||||
exempt-all-assignees: true
|
|
||||||
exempt-all-milestones: true
|
|
||||||
31
.github/workflows/test-build.yml
vendored
31
.github/workflows/test-build.yml
vendored
@ -1,31 +0,0 @@
|
|||||||
name: Build package
|
|
||||||
|
|
||||||
#
|
|
||||||
# This workflow is a test of the python package build.
|
|
||||||
# Install Python dependencies across different Python versions.
|
|
||||||
#
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
paths:
|
|
||||||
- "requirements.txt"
|
|
||||||
- ".github/workflows/test-build.yml"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
name: Build Test
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: ${{ matrix.python-version }}
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install -r requirements.txt
|
|
||||||
99
.github/workflows/test-ci.yml
vendored
99
.github/workflows/test-ci.yml
vendored
@ -1,99 +0,0 @@
|
|||||||
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
|
|
||||||
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
|
|
||||||
name: Full Comfy CI Workflow Runs
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
- release/**
|
|
||||||
paths-ignore:
|
|
||||||
- 'app/**'
|
|
||||||
- 'input/**'
|
|
||||||
- 'output/**'
|
|
||||||
- 'notebooks/**'
|
|
||||||
- 'script_examples/**'
|
|
||||||
- '.github/**'
|
|
||||||
- 'web/**'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test-stable:
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
# os: [macos, linux, windows]
|
|
||||||
# os: [macos, linux]
|
|
||||||
os: [linux]
|
|
||||||
python_version: ["3.10", "3.11", "3.12"]
|
|
||||||
cuda_version: ["12.1"]
|
|
||||||
torch_version: ["stable"]
|
|
||||||
include:
|
|
||||||
# - os: macos
|
|
||||||
# runner_label: [self-hosted, macOS]
|
|
||||||
# flags: "--use-pytorch-cross-attention"
|
|
||||||
- os: linux
|
|
||||||
runner_label: [self-hosted, Linux]
|
|
||||||
flags: ""
|
|
||||||
# - os: windows
|
|
||||||
# runner_label: [self-hosted, Windows]
|
|
||||||
# flags: ""
|
|
||||||
runs-on: ${{ matrix.runner_label }}
|
|
||||||
steps:
|
|
||||||
- name: Test Workflows
|
|
||||||
uses: comfy-org/comfy-action@main
|
|
||||||
with:
|
|
||||||
os: ${{ matrix.os }}
|
|
||||||
python_version: ${{ matrix.python_version }}
|
|
||||||
torch_version: ${{ matrix.torch_version }}
|
|
||||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
|
||||||
comfyui_flags: ${{ matrix.flags }}
|
|
||||||
|
|
||||||
# test-win-nightly:
|
|
||||||
# strategy:
|
|
||||||
# fail-fast: true
|
|
||||||
# matrix:
|
|
||||||
# os: [windows]
|
|
||||||
# python_version: ["3.9", "3.10", "3.11", "3.12"]
|
|
||||||
# cuda_version: ["12.1"]
|
|
||||||
# torch_version: ["nightly"]
|
|
||||||
# include:
|
|
||||||
# - os: windows
|
|
||||||
# runner_label: [self-hosted, Windows]
|
|
||||||
# flags: ""
|
|
||||||
# runs-on: ${{ matrix.runner_label }}
|
|
||||||
# steps:
|
|
||||||
# - name: Test Workflows
|
|
||||||
# uses: comfy-org/comfy-action@main
|
|
||||||
# with:
|
|
||||||
# os: ${{ matrix.os }}
|
|
||||||
# python_version: ${{ matrix.python_version }}
|
|
||||||
# torch_version: ${{ matrix.torch_version }}
|
|
||||||
# google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
|
||||||
# comfyui_flags: ${{ matrix.flags }}
|
|
||||||
|
|
||||||
test-unix-nightly:
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
# os: [macos, linux]
|
|
||||||
os: [linux]
|
|
||||||
python_version: ["3.11"]
|
|
||||||
cuda_version: ["12.1"]
|
|
||||||
torch_version: ["nightly"]
|
|
||||||
include:
|
|
||||||
# - os: macos
|
|
||||||
# runner_label: [self-hosted, macOS]
|
|
||||||
# flags: "--use-pytorch-cross-attention"
|
|
||||||
- os: linux
|
|
||||||
runner_label: [self-hosted, Linux]
|
|
||||||
flags: ""
|
|
||||||
runs-on: ${{ matrix.runner_label }}
|
|
||||||
steps:
|
|
||||||
- name: Test Workflows
|
|
||||||
uses: comfy-org/comfy-action@main
|
|
||||||
with:
|
|
||||||
os: ${{ matrix.os }}
|
|
||||||
python_version: ${{ matrix.python_version }}
|
|
||||||
torch_version: ${{ matrix.torch_version }}
|
|
||||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
|
||||||
comfyui_flags: ${{ matrix.flags }}
|
|
||||||
30
.github/workflows/test-execution.yml
vendored
30
.github/workflows/test-execution.yml
vendored
@ -1,30 +0,0 @@
|
|||||||
name: Execution Tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ main, master, release/** ]
|
|
||||||
pull_request:
|
|
||||||
branches: [ main, master, release/** ]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
continue-on-error: true
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: '3.12'
|
|
||||||
- name: Install requirements
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
pip install -r requirements.txt
|
|
||||||
pip install -r tests-unit/requirements.txt
|
|
||||||
- name: Run Execution Tests
|
|
||||||
run: |
|
|
||||||
python -m pytest tests/execution -v --skip-timing-checks
|
|
||||||
45
.github/workflows/test-launch.yml
vendored
45
.github/workflows/test-launch.yml
vendored
@ -1,45 +0,0 @@
|
|||||||
name: Test server launches without errors
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ main, master, release/** ]
|
|
||||||
pull_request:
|
|
||||||
branches: [ main, master, release/** ]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout ComfyUI
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
repository: "comfyanonymous/ComfyUI"
|
|
||||||
path: "ComfyUI"
|
|
||||||
- uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
- name: Install requirements
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
pip install -r requirements.txt
|
|
||||||
pip install wait-for-it
|
|
||||||
working-directory: ComfyUI
|
|
||||||
- name: Start ComfyUI server
|
|
||||||
run: |
|
|
||||||
python main.py --cpu 2>&1 | tee console_output.log &
|
|
||||||
wait-for-it --service 127.0.0.1:8188 -t 30
|
|
||||||
working-directory: ComfyUI
|
|
||||||
- name: Check for unhandled exceptions in server log
|
|
||||||
run: |
|
|
||||||
if grep -qE "Exception|Error" console_output.log; then
|
|
||||||
echo "Unhandled exception/error found in server log."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
working-directory: ComfyUI
|
|
||||||
- uses: actions/upload-artifact@v4
|
|
||||||
if: always()
|
|
||||||
with:
|
|
||||||
name: console-output
|
|
||||||
path: ComfyUI/console_output.log
|
|
||||||
retention-days: 30
|
|
||||||
30
.github/workflows/test-unit.yml
vendored
30
.github/workflows/test-unit.yml
vendored
@ -1,30 +0,0 @@
|
|||||||
name: Unit Tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ main, master, release/** ]
|
|
||||||
pull_request:
|
|
||||||
branches: [ main, master, release/** ]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-2022, macos-latest]
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
continue-on-error: true
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: '3.12'
|
|
||||||
- name: Install requirements
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
pip install -r requirements.txt
|
|
||||||
- name: Run Unit Tests
|
|
||||||
run: |
|
|
||||||
pip install -r tests-unit/requirements.txt
|
|
||||||
python -m pytest tests-unit
|
|
||||||
56
.github/workflows/update-api-stubs.yml
vendored
56
.github/workflows/update-api-stubs.yml
vendored
@ -1,56 +0,0 @@
|
|||||||
name: Generate Pydantic Stubs from api.comfy.org
|
|
||||||
|
|
||||||
on:
|
|
||||||
schedule:
|
|
||||||
- cron: '0 0 * * 1'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
generate-models:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install 'datamodel-code-generator[http]'
|
|
||||||
npm install @redocly/cli
|
|
||||||
|
|
||||||
- name: Download OpenAPI spec
|
|
||||||
run: |
|
|
||||||
curl -o openapi.yaml https://api.comfy.org/openapi
|
|
||||||
|
|
||||||
- name: Filter OpenAPI spec with Redocly
|
|
||||||
run: |
|
|
||||||
npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
|
|
||||||
|
|
||||||
- name: Generate API models
|
|
||||||
run: |
|
|
||||||
datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
|
|
||||||
|
|
||||||
- name: Check for changes
|
|
||||||
id: git-check
|
|
||||||
run: |
|
|
||||||
git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
- name: Create Pull Request
|
|
||||||
if: steps.git-check.outputs.changes == 'true'
|
|
||||||
uses: peter-evans/create-pull-request@v5
|
|
||||||
with:
|
|
||||||
commit-message: 'chore: update API models from OpenAPI spec'
|
|
||||||
title: 'Update API models from api.comfy.org'
|
|
||||||
body: |
|
|
||||||
This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
|
|
||||||
|
|
||||||
Generated automatically by the a Github workflow.
|
|
||||||
branch: update-api-stubs
|
|
||||||
delete-branch: true
|
|
||||||
base: master
|
|
||||||
59
.github/workflows/update-version.yml
vendored
59
.github/workflows/update-version.yml
vendored
@ -1,59 +0,0 @@
|
|||||||
name: Update Version File
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- "pyproject.toml"
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
- release/**
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
update-version:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
# Don't run on fork PRs
|
|
||||||
if: github.event.pull_request.head.repo.full_name == github.repository
|
|
||||||
permissions:
|
|
||||||
pull-requests: write
|
|
||||||
contents: write
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: "3.11"
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
|
|
||||||
- name: Update comfyui_version.py
|
|
||||||
run: |
|
|
||||||
# Read version from pyproject.toml and update comfyui_version.py
|
|
||||||
python -c '
|
|
||||||
import tomllib
|
|
||||||
|
|
||||||
# Read version from pyproject.toml
|
|
||||||
with open("pyproject.toml", "rb") as f:
|
|
||||||
config = tomllib.load(f)
|
|
||||||
version = config["project"]["version"]
|
|
||||||
|
|
||||||
# Write version to comfyui_version.py
|
|
||||||
with open("comfyui_version.py", "w") as f:
|
|
||||||
f.write("# This file is automatically generated by the build process when version is\n")
|
|
||||||
f.write("# updated in pyproject.toml.\n")
|
|
||||||
f.write(f"__version__ = \"{version}\"\n")
|
|
||||||
'
|
|
||||||
|
|
||||||
- name: Commit changes
|
|
||||||
run: |
|
|
||||||
git config --local user.name "github-actions"
|
|
||||||
git config --local user.email "github-actions@github.com"
|
|
||||||
git fetch origin ${{ github.head_ref }}
|
|
||||||
git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }}
|
|
||||||
git add comfyui_version.py
|
|
||||||
git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml"
|
|
||||||
git push origin HEAD:${{ github.head_ref }}
|
|
||||||
71
.github/workflows/windows_release_cu118_dependencies.yml
vendored
Normal file
71
.github/workflows/windows_release_cu118_dependencies.yml
vendored
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
name: "Windows Release cu118 dependencies"
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
# push:
|
||||||
|
# branches:
|
||||||
|
# - master
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build_dependencies:
|
||||||
|
env:
|
||||||
|
# you need at least cuda 5.0 for some of the stuff compiled here.
|
||||||
|
TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6 8.9"
|
||||||
|
FORCE_CUDA: 1
|
||||||
|
MAX_JOBS: 1 # will crash otherwise
|
||||||
|
DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
|
||||||
|
XFORMERS_BUILD_TYPE: "Release"
|
||||||
|
runs-on: windows-latest
|
||||||
|
steps:
|
||||||
|
- name: Cache Built Dependencies
|
||||||
|
uses: actions/cache@v3
|
||||||
|
id: cache-cu118_python_stuff
|
||||||
|
with:
|
||||||
|
path: cu118_python_deps.tar
|
||||||
|
key: ${{ runner.os }}-build-cu118
|
||||||
|
|
||||||
|
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10.9'
|
||||||
|
|
||||||
|
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
|
||||||
|
uses: comfyanonymous/cuda-toolkit@test
|
||||||
|
id: cuda-toolkit
|
||||||
|
with:
|
||||||
|
cuda: '11.8.0'
|
||||||
|
# copied from xformers github
|
||||||
|
- name: Setup MSVC
|
||||||
|
uses: ilammy/msvc-dev-cmd@v1
|
||||||
|
- name: Configure Pagefile
|
||||||
|
# windows runners will OOM with many CUDA architectures
|
||||||
|
# we cheat here with a page file
|
||||||
|
uses: al-cheb/configure-pagefile-action@v1.3
|
||||||
|
with:
|
||||||
|
minimum-size: 2GB
|
||||||
|
# really unfortunate: https://github.com/ilammy/msvc-dev-cmd#name-conflicts-with-shell-bash
|
||||||
|
- name: Remove link.exe
|
||||||
|
shell: bash
|
||||||
|
run: rm /usr/bin/link
|
||||||
|
|
||||||
|
- if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||||
|
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||||
|
echo installed basic
|
||||||
|
git clone --recurse-submodules https://github.com/facebookresearch/xformers.git
|
||||||
|
cd xformers
|
||||||
|
python -m pip install --no-cache-dir wheel setuptools twine
|
||||||
|
echo building xformers
|
||||||
|
python setup.py bdist_wheel -d ../temp_wheel_dir/
|
||||||
|
cd ..
|
||||||
|
rm -rf xformers
|
||||||
|
ls -lah temp_wheel_dir
|
||||||
|
mv temp_wheel_dir cu118_python_deps
|
||||||
|
tar cf cu118_python_deps.tar cu118_python_deps
|
||||||
|
|
||||||
|
|
||||||
30
.github/workflows/windows_release_cu118_dependencies_2.yml
vendored
Normal file
30
.github/workflows/windows_release_cu118_dependencies_2.yml
vendored
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
name: "Windows Release cu118 dependencies 2"
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
# push:
|
||||||
|
# branches:
|
||||||
|
# - master
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build_dependencies:
|
||||||
|
runs-on: windows-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10.9'
|
||||||
|
|
||||||
|
- shell: bash
|
||||||
|
run: |
|
||||||
|
python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||||
|
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||||
|
echo installed basic
|
||||||
|
ls -lah temp_wheel_dir
|
||||||
|
mv temp_wheel_dir cu118_python_deps
|
||||||
|
tar cf cu118_python_deps.tar cu118_python_deps
|
||||||
|
|
||||||
|
- uses: actions/cache/save@v3
|
||||||
|
with:
|
||||||
|
path: cu118_python_deps.tar
|
||||||
|
key: ${{ runner.os }}-build-cu118
|
||||||
76
.github/workflows/windows_release_cu118_package.yml
vendored
Normal file
76
.github/workflows/windows_release_cu118_package.yml
vendored
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
name: "Windows Release cu118 packaging"
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
# push:
|
||||||
|
# branches:
|
||||||
|
# - master
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
package_comfyui:
|
||||||
|
permissions:
|
||||||
|
contents: "write"
|
||||||
|
packages: "write"
|
||||||
|
pull-requests: "read"
|
||||||
|
runs-on: windows-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/cache/restore@v3
|
||||||
|
id: cache
|
||||||
|
with:
|
||||||
|
path: cu118_python_deps.tar
|
||||||
|
key: ${{ runner.os }}-build-cu118
|
||||||
|
- shell: bash
|
||||||
|
run: |
|
||||||
|
mv cu118_python_deps.tar ../
|
||||||
|
cd ..
|
||||||
|
tar xf cu118_python_deps.tar
|
||||||
|
pwd
|
||||||
|
ls
|
||||||
|
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
- shell: bash
|
||||||
|
run: |
|
||||||
|
cd ..
|
||||||
|
cp -r ComfyUI ComfyUI_copy
|
||||||
|
curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip
|
||||||
|
unzip python_embeded.zip -d python_embeded
|
||||||
|
cd python_embeded
|
||||||
|
echo 'import site' >> ./python310._pth
|
||||||
|
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||||
|
./python.exe get-pip.py
|
||||||
|
./python.exe -s -m pip install ../cu118_python_deps/*
|
||||||
|
sed -i '1i../ComfyUI' ./python310._pth
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
|
||||||
|
mkdir ComfyUI_windows_portable
|
||||||
|
mv python_embeded ComfyUI_windows_portable
|
||||||
|
mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
|
||||||
|
|
||||||
|
cd ComfyUI_windows_portable
|
||||||
|
|
||||||
|
mkdir update
|
||||||
|
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||||
|
cp -r ComfyUI/.ci/update_windows_cu118/* ./update/
|
||||||
|
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||||
|
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||||
|
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z
|
||||||
|
|
||||||
|
cd ComfyUI_windows_portable
|
||||||
|
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||||
|
|
||||||
|
ls
|
||||||
|
|
||||||
|
- name: Upload binaries to release
|
||||||
|
uses: svenstaro/upload-release-action@v2
|
||||||
|
with:
|
||||||
|
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
file: new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z
|
||||||
|
tag: "latest"
|
||||||
|
overwrite: true
|
||||||
|
|
||||||
@ -1,72 +0,0 @@
|
|||||||
name: "Windows Release dependencies"
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
xformers:
|
|
||||||
description: 'xformers version'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
default: ""
|
|
||||||
extra_dependencies:
|
|
||||||
description: 'extra dependencies'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
default: ""
|
|
||||||
cu:
|
|
||||||
description: 'cuda version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "130"
|
|
||||||
|
|
||||||
python_minor:
|
|
||||||
description: 'python minor version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "13"
|
|
||||||
|
|
||||||
python_patch:
|
|
||||||
description: 'python patch version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "9"
|
|
||||||
# push:
|
|
||||||
# branches:
|
|
||||||
# - master
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build_dependencies:
|
|
||||||
runs-on: windows-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
|
||||||
|
|
||||||
- shell: bash
|
|
||||||
run: |
|
|
||||||
echo "@echo off
|
|
||||||
call update_comfyui.bat nopause
|
|
||||||
echo -
|
|
||||||
echo This will try to update pytorch and all python dependencies.
|
|
||||||
echo -
|
|
||||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
|
||||||
echo -
|
|
||||||
pause
|
|
||||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
|
||||||
pause" > update_comfyui_and_python_dependencies.bat
|
|
||||||
|
|
||||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
|
||||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
|
||||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
|
||||||
echo installed basic
|
|
||||||
ls -lah temp_wheel_dir
|
|
||||||
mv temp_wheel_dir cu${{ inputs.cu }}_python_deps
|
|
||||||
tar cf cu${{ inputs.cu }}_python_deps.tar cu${{ inputs.cu }}_python_deps
|
|
||||||
|
|
||||||
- uses: actions/cache/save@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
cu${{ inputs.cu }}_python_deps.tar
|
|
||||||
update_comfyui_and_python_dependencies.bat
|
|
||||||
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
|
|
||||||
@ -1,64 +0,0 @@
|
|||||||
name: "Windows Release dependencies Manual"
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
torch_dependencies:
|
|
||||||
description: 'torch dependencies'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
|
|
||||||
cache_tag:
|
|
||||||
description: 'Cached dependencies tag'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "cu128"
|
|
||||||
|
|
||||||
python_minor:
|
|
||||||
description: 'python minor version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "12"
|
|
||||||
|
|
||||||
python_patch:
|
|
||||||
description: 'python patch version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "10"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build_dependencies:
|
|
||||||
runs-on: windows-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
|
||||||
|
|
||||||
- shell: bash
|
|
||||||
run: |
|
|
||||||
echo "@echo off
|
|
||||||
call update_comfyui.bat nopause
|
|
||||||
echo -
|
|
||||||
echo This will try to update pytorch and all python dependencies.
|
|
||||||
echo -
|
|
||||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
|
||||||
echo -
|
|
||||||
pause
|
|
||||||
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
|
|
||||||
pause" > update_comfyui_and_python_dependencies.bat
|
|
||||||
|
|
||||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
|
||||||
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
|
||||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
|
||||||
echo installed basic
|
|
||||||
ls -lah temp_wheel_dir
|
|
||||||
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
|
|
||||||
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
|
|
||||||
|
|
||||||
- uses: actions/cache/save@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
${{ inputs.cache_tag }}_python_deps.tar
|
|
||||||
update_comfyui_and_python_dependencies.bat
|
|
||||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
|
||||||
@ -2,24 +2,6 @@ name: "Windows Release Nightly pytorch"
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
|
||||||
cu:
|
|
||||||
description: 'cuda version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "129"
|
|
||||||
|
|
||||||
python_minor:
|
|
||||||
description: 'python minor version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "13"
|
|
||||||
|
|
||||||
python_patch:
|
|
||||||
description: 'python patch version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "5"
|
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
@ -32,33 +14,28 @@ jobs:
|
|||||||
pull-requests: "read"
|
pull-requests: "read"
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 30
|
fetch-depth: 0
|
||||||
persist-credentials: false
|
- uses: actions/setup-python@v4
|
||||||
- uses: actions/setup-python@v5
|
|
||||||
with:
|
with:
|
||||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
python-version: '3.11.3'
|
||||||
- shell: bash
|
- shell: bash
|
||||||
run: |
|
run: |
|
||||||
cd ..
|
cd ..
|
||||||
cp -r ComfyUI ComfyUI_copy
|
cp -r ComfyUI ComfyUI_copy
|
||||||
curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
|
curl https://www.python.org/ftp/python/3.11.3/python-3.11.3-embed-amd64.zip -o python_embeded.zip
|
||||||
unzip python_embeded.zip -d python_embeded
|
unzip python_embeded.zip -d python_embeded
|
||||||
cd python_embeded
|
cd python_embeded
|
||||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
echo 'import site' >> ./python311._pth
|
||||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||||
./python.exe get-pip.py
|
./python.exe get-pip.py
|
||||||
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||||
ls ../temp_wheel_dir
|
ls ../temp_wheel_dir
|
||||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
sed -i '1i../ComfyUI' ./python311._pth
|
||||||
|
|
||||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
|
||||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
|
||||||
|
|
||||||
mkdir ComfyUI_windows_portable_nightly_pytorch
|
mkdir ComfyUI_windows_portable_nightly_pytorch
|
||||||
mv python_embeded ComfyUI_windows_portable_nightly_pytorch
|
mv python_embeded ComfyUI_windows_portable_nightly_pytorch
|
||||||
@ -68,15 +45,13 @@ jobs:
|
|||||||
|
|
||||||
mkdir update
|
mkdir update
|
||||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||||
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
|
||||||
|
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
|
||||||
|
|
||||||
echo "call update_comfyui.bat nopause
|
|
||||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
|
||||||
pause" > ./update/update_comfyui_and_python_dependencies.bat
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||||
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
|
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
|
||||||
|
|
||||||
cd ComfyUI_windows_portable_nightly_pytorch
|
cd ComfyUI_windows_portable_nightly_pytorch
|
||||||
|
|||||||
106
.github/workflows/windows_release_package.yml
vendored
106
.github/workflows/windows_release_package.yml
vendored
@ -1,106 +0,0 @@
|
|||||||
name: "Windows Release packaging"
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
cu:
|
|
||||||
description: 'cuda version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "129"
|
|
||||||
|
|
||||||
python_minor:
|
|
||||||
description: 'python minor version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "13"
|
|
||||||
|
|
||||||
python_patch:
|
|
||||||
description: 'python patch version'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "6"
|
|
||||||
# push:
|
|
||||||
# branches:
|
|
||||||
# - master
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
package_comfyui:
|
|
||||||
permissions:
|
|
||||||
contents: "write"
|
|
||||||
packages: "write"
|
|
||||||
pull-requests: "read"
|
|
||||||
runs-on: windows-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/cache/restore@v4
|
|
||||||
id: cache
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
cu${{ inputs.cu }}_python_deps.tar
|
|
||||||
update_comfyui_and_python_dependencies.bat
|
|
||||||
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
|
|
||||||
- shell: bash
|
|
||||||
run: |
|
|
||||||
mv cu${{ inputs.cu }}_python_deps.tar ../
|
|
||||||
mv update_comfyui_and_python_dependencies.bat ../
|
|
||||||
cd ..
|
|
||||||
tar xf cu${{ inputs.cu }}_python_deps.tar
|
|
||||||
pwd
|
|
||||||
ls
|
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 150
|
|
||||||
persist-credentials: false
|
|
||||||
- shell: bash
|
|
||||||
run: |
|
|
||||||
cd ..
|
|
||||||
cp -r ComfyUI ComfyUI_copy
|
|
||||||
curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
|
|
||||||
unzip python_embeded.zip -d python_embeded
|
|
||||||
cd python_embeded
|
|
||||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
|
||||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
|
||||||
./python.exe get-pip.py
|
|
||||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
|
||||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
|
||||||
|
|
||||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
|
||||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
|
||||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
|
||||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
|
||||||
|
|
||||||
mkdir ComfyUI_windows_portable
|
|
||||||
mv python_embeded ComfyUI_windows_portable
|
|
||||||
mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
|
|
||||||
|
|
||||||
cd ComfyUI_windows_portable
|
|
||||||
|
|
||||||
mkdir update
|
|
||||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
|
||||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
|
||||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
|
||||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
|
||||||
|
|
||||||
cd ComfyUI_windows_portable
|
|
||||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
|
||||||
|
|
||||||
python_embeded/python.exe -s ./update/update.py ComfyUI/
|
|
||||||
|
|
||||||
ls
|
|
||||||
|
|
||||||
- name: Upload binaries to release
|
|
||||||
uses: svenstaro/upload-release-action@v2
|
|
||||||
with:
|
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
file: new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
|
||||||
tag: "latest"
|
|
||||||
overwrite: true
|
|
||||||
|
|
||||||
27
.gitignore
vendored
27
.gitignore
vendored
@ -1,26 +1,11 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
/output/
|
output/
|
||||||
/input/
|
input/
|
||||||
!/input/example.png
|
!input/example.png
|
||||||
/models/
|
models/
|
||||||
/temp/
|
temp/
|
||||||
/custom_nodes/
|
custom_nodes/
|
||||||
!custom_nodes/example_node.py.example
|
!custom_nodes/example_node.py.example
|
||||||
extra_model_paths.yaml
|
extra_model_paths.yaml
|
||||||
/.vs
|
/.vs
|
||||||
.vscode/
|
|
||||||
.idea/
|
|
||||||
venv/
|
|
||||||
.venv/
|
|
||||||
/web/extensions/*
|
|
||||||
!/web/extensions/logging.js.example
|
|
||||||
!/web/extensions/core/
|
|
||||||
/tests-ui/data/object_info.json
|
|
||||||
/user/
|
|
||||||
*.log
|
|
||||||
web_custom_versions/
|
|
||||||
.DS_Store
|
|
||||||
openapi.yaml
|
|
||||||
filtered-openapi.yaml
|
|
||||||
uv.lock
|
|
||||||
|
|||||||
@ -1,2 +0,0 @@
|
|||||||
# Admins
|
|
||||||
* @comfyanonymous @kosinkadink @guill
|
|
||||||
@ -1,41 +0,0 @@
|
|||||||
# Contributing to ComfyUI
|
|
||||||
|
|
||||||
Welcome, and thank you for your interest in contributing to ComfyUI!
|
|
||||||
|
|
||||||
There are several ways in which you can contribute, beyond writing code. The goal of this document is to provide a high-level overview of how you can get involved.
|
|
||||||
|
|
||||||
## Asking Questions
|
|
||||||
|
|
||||||
Have a question? Instead of opening an issue, please ask on [Discord](https://comfy.org/discord) or [Matrix](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) channels. Our team and the community will help you.
|
|
||||||
|
|
||||||
## Providing Feedback
|
|
||||||
|
|
||||||
Your comments and feedback are welcome, and the development team is available via a handful of different channels.
|
|
||||||
|
|
||||||
See the `#bug-report`, `#feature-request` and `#feedback` channels on Discord.
|
|
||||||
|
|
||||||
## Reporting Issues
|
|
||||||
|
|
||||||
Have you identified a reproducible problem in ComfyUI? Do you have a feature request? We want to hear about it! Here's how you can report your issue as effectively as possible.
|
|
||||||
|
|
||||||
|
|
||||||
### Look For an Existing Issue
|
|
||||||
|
|
||||||
Before you create a new issue, please do a search in [open issues](https://github.com/comfyanonymous/ComfyUI/issues) to see if the issue or feature request has already been filed.
|
|
||||||
|
|
||||||
If you find your issue already exists, make relevant comments and add your [reaction](https://github.com/blog/2119-add-reactions-to-pull-requests-issues-and-comments). Use a reaction in place of a "+1" comment:
|
|
||||||
|
|
||||||
* 👍 - upvote
|
|
||||||
* 👎 - downvote
|
|
||||||
|
|
||||||
If you cannot find an existing issue that describes your bug or feature, create a new issue. We have an issue template in place to organize new issues.
|
|
||||||
|
|
||||||
|
|
||||||
### Creating Pull Requests
|
|
||||||
|
|
||||||
* Please refer to the article on [creating pull requests](https://github.com/comfyanonymous/ComfyUI/wiki/How-to-Contribute-Code) and contributing to this project.
|
|
||||||
|
|
||||||
|
|
||||||
## Thank You
|
|
||||||
|
|
||||||
Your contributions to open source, large or small, make great projects like this possible. Thank you for taking the time to contribute.
|
|
||||||
168
QUANTIZATION.md
168
QUANTIZATION.md
@ -1,168 +0,0 @@
|
|||||||
# The Comfy guide to Quantization
|
|
||||||
|
|
||||||
|
|
||||||
## How does quantization work?
|
|
||||||
|
|
||||||
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
|
|
||||||
|
|
||||||
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
|
|
||||||
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
|
|
||||||
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
|
|
||||||
|
|
||||||
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
|
|
||||||
|
|
||||||
```
|
|
||||||
absmax = max(abs(tensor))
|
|
||||||
scale = amax / max_dynamic_range_low_precision
|
|
||||||
|
|
||||||
# Quantization
|
|
||||||
tensor_q = (tensor / scale).to(low_precision_dtype)
|
|
||||||
|
|
||||||
# De-Quantization
|
|
||||||
tensor_dq = tensor_q.to(fp16) * scale
|
|
||||||
|
|
||||||
tensor_dq ~ tensor
|
|
||||||
```
|
|
||||||
|
|
||||||
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
|
|
||||||
|
|
||||||
|
|
||||||
## Quantization in Comfy
|
|
||||||
|
|
||||||
```
|
|
||||||
QuantizedTensor (torch.Tensor subclass)
|
|
||||||
↓ __torch_dispatch__
|
|
||||||
Two-Level Registry (generic + layout handlers)
|
|
||||||
↓
|
|
||||||
MixedPrecisionOps + Metadata Detection
|
|
||||||
```
|
|
||||||
|
|
||||||
### Representation
|
|
||||||
|
|
||||||
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
|
|
||||||
|
|
||||||
A `Layout` class defines how a specific quantization format behaves:
|
|
||||||
- Required parameters
|
|
||||||
- Quantize method
|
|
||||||
- De-Quantize method
|
|
||||||
|
|
||||||
```python
|
|
||||||
from comfy.quant_ops import QuantizedLayout
|
|
||||||
|
|
||||||
class MyLayout(QuantizedLayout):
|
|
||||||
@classmethod
|
|
||||||
def quantize(cls, tensor, **kwargs):
|
|
||||||
# Convert to quantized format
|
|
||||||
qdata = ...
|
|
||||||
params = {'scale': ..., 'orig_dtype': tensor.dtype}
|
|
||||||
return qdata, params
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
|
||||||
return qdata.to(orig_dtype) * scale
|
|
||||||
```
|
|
||||||
|
|
||||||
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
|
|
||||||
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
|
|
||||||
|
|
||||||
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
|
|
||||||
```python
|
|
||||||
from comfy.quant_ops import register_layout_op
|
|
||||||
|
|
||||||
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
|
|
||||||
def my_linear(func, args, kwargs):
|
|
||||||
# Extract tensors, call optimized kernel
|
|
||||||
...
|
|
||||||
```
|
|
||||||
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
|
|
||||||
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
|
|
||||||
|
|
||||||
|
|
||||||
### Mixed Precision
|
|
||||||
|
|
||||||
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
|
|
||||||
|
|
||||||
**Architecture:**
|
|
||||||
|
|
||||||
```python
|
|
||||||
class MixedPrecisionOps(disable_weight_init):
|
|
||||||
_layer_quant_config = {} # Maps layer names to quantization configs
|
|
||||||
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
|
|
||||||
```
|
|
||||||
|
|
||||||
**Key mechanism:**
|
|
||||||
|
|
||||||
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
|
|
||||||
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
|
|
||||||
- If the layer name **is** in `_layer_quant_config`:
|
|
||||||
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
|
|
||||||
- Load associated quantization parameters (scales, block_size, etc.)
|
|
||||||
|
|
||||||
**Why it's needed:**
|
|
||||||
|
|
||||||
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
|
|
||||||
|
|
||||||
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
|
|
||||||
|
|
||||||
|
|
||||||
## Checkpoint Format
|
|
||||||
|
|
||||||
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
|
|
||||||
|
|
||||||
The quantized checkpoint will contain the same layers as the original checkpoint but:
|
|
||||||
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
|
|
||||||
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
|
|
||||||
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
|
|
||||||
|
|
||||||
### Scaling Parameters details
|
|
||||||
We define 4 possible scaling parameters that should cover most recipes in the near-future:
|
|
||||||
- **weight_scale**: quantization scalers for the weights
|
|
||||||
- **weight_scale_2**: global scalers in the context of double scaling
|
|
||||||
- **pre_quant_scale**: scalers used for smoothing salient weights
|
|
||||||
- **input_scale**: quantization scalers for the activations
|
|
||||||
|
|
||||||
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|
|
||||||
|--------|---------------|--------------|----------------|-----------------|-------------|
|
|
||||||
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
|
|
||||||
|
|
||||||
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
|
|
||||||
|
|
||||||
### Quantization Metadata
|
|
||||||
|
|
||||||
The metadata stored alongside the checkpoint contains:
|
|
||||||
- **format_version**: String to define a version of the standard
|
|
||||||
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"_quantization_metadata": {
|
|
||||||
"format_version": "1.0",
|
|
||||||
"layers": {
|
|
||||||
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
|
||||||
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
|
||||||
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## Creating Quantized Checkpoints
|
|
||||||
|
|
||||||
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
|
|
||||||
|
|
||||||
### Weight Quantization
|
|
||||||
|
|
||||||
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
|
|
||||||
|
|
||||||
### Calibration (for Activation Quantization)
|
|
||||||
|
|
||||||
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
|
|
||||||
|
|
||||||
1. **Collect statistics**: Run inference on N representative samples
|
|
||||||
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
|
|
||||||
3. **Compute scales**: Derive `input_scale` from collected statistics
|
|
||||||
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
|
||||||
|
|
||||||
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
|
||||||
412
README.md
412
README.md
@ -1,274 +1,104 @@
|
|||||||
<div align="center">
|
ComfyUI
|
||||||
|
=======
|
||||||
|
A powerful and modular stable diffusion GUI and backend.
|
||||||
|
-----------
|
||||||
|

|
||||||
|
|
||||||
# ComfyUI
|
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||||
**The most powerful and modular visual AI engine and application.**
|
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||||
|
|
||||||
|
### [Installing ComfyUI](#installing)
|
||||||
[![Website][website-shield]][website-url]
|
|
||||||
[![Dynamic JSON Badge][discord-shield]][discord-url]
|
|
||||||
[![Twitter][twitter-shield]][twitter-url]
|
|
||||||
[![Matrix][matrix-shield]][matrix-url]
|
|
||||||
<br>
|
|
||||||
[![][github-release-shield]][github-release-link]
|
|
||||||
[![][github-release-date-shield]][github-release-link]
|
|
||||||
[![][github-downloads-shield]][github-downloads-link]
|
|
||||||
[![][github-downloads-latest-shield]][github-downloads-link]
|
|
||||||
|
|
||||||
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
|
|
||||||
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
|
||||||
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
|
|
||||||
[website-url]: https://www.comfy.org/
|
|
||||||
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
|
||||||
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
|
||||||
[discord-url]: https://www.comfy.org/discord
|
|
||||||
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
|
|
||||||
[twitter-url]: https://x.com/ComfyUI
|
|
||||||
|
|
||||||
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
|
|
||||||
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
|
||||||
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
|
|
||||||
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
|
|
||||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
|
||||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
|
||||||
|
|
||||||

|
|
||||||
</div>
|
|
||||||
|
|
||||||
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
|
|
||||||
|
|
||||||
## Get Started
|
|
||||||
|
|
||||||
#### [Desktop Application](https://www.comfy.org/download)
|
|
||||||
- The easiest way to get started.
|
|
||||||
- Available on Windows & macOS.
|
|
||||||
|
|
||||||
#### [Windows Portable Package](#installing)
|
|
||||||
- Get the latest commits and completely portable.
|
|
||||||
- Available on Windows.
|
|
||||||
|
|
||||||
#### [Manual Install](#manual-install-windows-linux)
|
|
||||||
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
|
|
||||||
|
|
||||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
|
||||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||||
- Image Models
|
- Fully supports SD1.x and SD2.x
|
||||||
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
|
|
||||||
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
|
||||||
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
|
||||||
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
|
|
||||||
- Pixart Alpha and Sigma
|
|
||||||
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
|
|
||||||
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
|
|
||||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
|
||||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
|
||||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
|
||||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
|
||||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
|
||||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
|
||||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
|
||||||
- Image Editing Models
|
|
||||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
|
||||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
|
||||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
|
||||||
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
|
|
||||||
- Video Models
|
|
||||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
|
||||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
|
||||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
|
||||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
|
||||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
|
||||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
|
||||||
- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
|
|
||||||
- Audio Models
|
|
||||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
|
||||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
|
||||||
- 3D Models
|
|
||||||
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
|
|
||||||
- Asynchronous Queue system
|
- Asynchronous Queue system
|
||||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||||
- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
|
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
|
||||||
- Works even if you don't have a GPU with: ```--cpu``` (slow)
|
- Works even if you don't have a GPU with: ```--cpu``` (slow)
|
||||||
- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
|
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
|
||||||
- Safe loading of ckpt, pt, pth, etc.. files.
|
|
||||||
- Embeddings/Textual inversion
|
- Embeddings/Textual inversion
|
||||||
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||||
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
|
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
|
||||||
- Loading full workflows (with seeds) from generated PNG, WebP and FLAC files.
|
- Loading full workflows (with seeds) from generated PNG files.
|
||||||
- Saving/Loading workflows as Json files.
|
- Saving/Loading workflows as Json files.
|
||||||
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
|
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
|
||||||
- [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
|
- [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
|
||||||
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
|
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
|
||||||
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
|
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
|
||||||
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
|
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
|
||||||
|
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
|
||||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
- Starts up very fast.
|
||||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
- Works fully offline: will never download anything.
|
||||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
|
||||||
- Works fully offline: core will never download anything unless you want to.
|
|
||||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
|
||||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||||
|
|
||||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||||
|
|
||||||
## Release Process
|
|
||||||
|
|
||||||
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
|
||||||
|
|
||||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
|
||||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
|
||||||
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
|
||||||
- Minor versions will be used for releases off the master branch.
|
|
||||||
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
|
||||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
|
||||||
- Serves as the foundation for the desktop release
|
|
||||||
|
|
||||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
|
||||||
- Builds a new release using the latest stable core version
|
|
||||||
|
|
||||||
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
|
|
||||||
- Weekly frontend updates are merged into the core repository
|
|
||||||
- Features are frozen for the upcoming core release
|
|
||||||
- Development continues for the next release cycle
|
|
||||||
|
|
||||||
## Shortcuts
|
## Shortcuts
|
||||||
|
|
||||||
| Keybind | Explanation |
|
| Keybind | Explanation |
|
||||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
| - | - |
|
||||||
| `Ctrl` + `Enter` | Queue up current graph for generation |
|
| Ctrl + Enter | Queue up current graph for generation |
|
||||||
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
|
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||||
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
|
| Ctrl + S | Save workflow |
|
||||||
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
|
| Ctrl + O | Load workflow |
|
||||||
| `Ctrl` + `S` | Save workflow |
|
| Ctrl + A | Select all nodes |
|
||||||
| `Ctrl` + `O` | Load workflow |
|
| Ctrl + M | Mute/unmute selected nodes |
|
||||||
| `Ctrl` + `A` | Select all nodes |
|
| Delete/Backspace | Delete selected nodes |
|
||||||
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
|
| Ctrl + Delete/Backspace | Delete the current graph |
|
||||||
| `Ctrl` + `M` | Mute/unmute selected nodes |
|
| Space | Move the canvas around when held and moving the cursor |
|
||||||
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
| Ctrl/Shift + Click | Add clicked node to selection |
|
||||||
| `Delete`/`Backspace` | Delete selected nodes |
|
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||||
| `Ctrl` + `Backspace` | Delete the current graph |
|
| Ctrl + C/Ctrl + Shift + V| Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||||
| `Space` | Move the canvas around when held and moving the cursor |
|
| Shift + Drag | Move multiple selected nodes at the same time |
|
||||||
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
|
| Ctrl + D | Load default graph |
|
||||||
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
| Q | Toggle visibility of the queue |
|
||||||
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
| H | Toggle visibility of history |
|
||||||
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
|
| R | Refresh graph |
|
||||||
| `Ctrl` + `D` | Load default graph |
|
| Double-Click LMB | Open node quick search palette |
|
||||||
| `Alt` + `+` | Canvas Zoom in |
|
|
||||||
| `Alt` + `-` | Canvas Zoom out |
|
|
||||||
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
|
|
||||||
| `P` | Pin/Unpin selected nodes |
|
|
||||||
| `Ctrl` + `G` | Group selected nodes |
|
|
||||||
| `Q` | Toggle visibility of the queue |
|
|
||||||
| `H` | Toggle visibility of history |
|
|
||||||
| `R` | Refresh graph |
|
|
||||||
| `F` | Show/Hide menu |
|
|
||||||
| `.` | Fit view to selection (Whole graph when nothing is selected) |
|
|
||||||
| Double-Click LMB | Open node quick search palette |
|
|
||||||
| `Shift` + Drag | Move multiple wires at once |
|
|
||||||
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
|
|
||||||
|
|
||||||
`Ctrl` can also be replaced with `Cmd` instead for macOS users
|
Ctrl can also be replaced with Cmd instead for MacOS users
|
||||||
|
|
||||||
# Installing
|
# Installing
|
||||||
|
|
||||||
## Windows Portable
|
## Windows
|
||||||
|
|
||||||
There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
|
There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
|
||||||
|
|
||||||
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
|
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z)
|
||||||
|
|
||||||
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
|
Just download, extract and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
|
||||||
|
|
||||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
|
||||||
|
|
||||||
Update your Nvidia drivers if it doesn't start.
|
|
||||||
|
|
||||||
#### Alternative Downloads:
|
|
||||||
|
|
||||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
|
||||||
|
|
||||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
|
||||||
|
|
||||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
|
||||||
|
|
||||||
#### How do I share models between another UI and ComfyUI?
|
#### How do I share models between another UI and ComfyUI?
|
||||||
|
|
||||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||||
|
|
||||||
|
## Colab Notebook
|
||||||
|
|
||||||
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
|
To run it on colab or paperspace you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb)
|
||||||
|
|
||||||
You can install and start ComfyUI using comfy-cli:
|
|
||||||
```bash
|
|
||||||
pip install comfy-cli
|
|
||||||
comfy install
|
|
||||||
```
|
|
||||||
|
|
||||||
## Manual Install (Windows, Linux)
|
## Manual Install (Windows, Linux)
|
||||||
|
|
||||||
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
|
||||||
|
|
||||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
|
||||||
|
|
||||||
### Instructions:
|
|
||||||
|
|
||||||
Git clone this repo.
|
Git clone this repo.
|
||||||
|
|
||||||
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||||
|
|
||||||
Put your VAE in: models/vae
|
Put your VAE in: models/vae
|
||||||
|
|
||||||
|
At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10.
|
||||||
|
|
||||||
### AMD GPUs (Linux)
|
### AMD GPUs (Linux only)
|
||||||
|
|
||||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
|
||||||
|
|
||||||
|
|
||||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
|
||||||
|
|
||||||
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
|
|
||||||
|
|
||||||
RDNA 3 (RX 7000 series):
|
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
|
|
||||||
|
|
||||||
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
|
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
|
|
||||||
|
|
||||||
RDNA 4 (RX 9000 series):
|
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
|
|
||||||
|
|
||||||
### Intel GPUs (Windows and Linux)
|
|
||||||
|
|
||||||
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
|
||||||
|
|
||||||
1. To install PyTorch xpu, use the following command:
|
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
|
|
||||||
|
|
||||||
This is the command to install the Pytorch xpu nightly which might have some performance improvements:
|
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
|
||||||
|
|
||||||
### NVIDIA
|
### NVIDIA
|
||||||
|
|
||||||
Nvidia users should install stable pytorch using this command:
|
Nvidia users should install torch and xformers using this command:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 xformers```
|
||||||
|
|
||||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
|
|
||||||
|
|
||||||
#### Troubleshooting
|
#### Troubleshooting
|
||||||
|
|
||||||
@ -288,86 +118,33 @@ After this you should have everything installed and can proceed to running Comfy
|
|||||||
|
|
||||||
### Others:
|
### Others:
|
||||||
|
|
||||||
#### Apple Mac silicon
|
[Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
|
||||||
|
|
||||||
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
|
Mac/MPS: There is basic support in the code but until someone makes some install instruction you are on your own.
|
||||||
|
|
||||||
1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
|
### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?
|
||||||
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
|
|
||||||
1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
|
|
||||||
1. Launch ComfyUI by running `python main.py`
|
|
||||||
|
|
||||||
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
|
You don't. If you have another UI installed and working with it's own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:
|
||||||
|
|
||||||
#### Ascend NPUs
|
```source path_to_other_sd_gui/venv/bin/activate```
|
||||||
|
|
||||||
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
|
or on Windows:
|
||||||
|
|
||||||
1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
|
With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"```
|
||||||
2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
|
|
||||||
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
|
|
||||||
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
|
|
||||||
|
|
||||||
#### Cambricon MLUs
|
With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"```
|
||||||
|
|
||||||
For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a step-by-step guide tailored to your platform and installation method:
|
|
||||||
|
|
||||||
1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html)
|
|
||||||
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
|
|
||||||
3. Launch ComfyUI by running `python main.py`
|
|
||||||
|
|
||||||
#### Iluvatar Corex
|
|
||||||
|
|
||||||
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
|
|
||||||
|
|
||||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
|
||||||
2. Launch ComfyUI by running `python main.py`
|
|
||||||
|
|
||||||
|
|
||||||
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
|
|
||||||
|
|
||||||
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
|
|
||||||
|
|
||||||
### Setup
|
|
||||||
|
|
||||||
1. Install the manager dependencies:
|
|
||||||
```bash
|
|
||||||
pip install -r manager_requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
|
|
||||||
```bash
|
|
||||||
python main.py --enable-manager
|
|
||||||
```
|
|
||||||
|
|
||||||
### Command Line Options
|
|
||||||
|
|
||||||
| Flag | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
|
||||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
|
||||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
|
||||||
|
|
||||||
|
And then you can use that terminal to run Comfyui without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.
|
||||||
|
|
||||||
# Running
|
# Running
|
||||||
|
|
||||||
```python main.py```
|
```python main.py```
|
||||||
|
|
||||||
### For AMD cards not officially supported by ROCm
|
### For AMD 6700, 6600 and maybe others
|
||||||
|
|
||||||
Try running it with this command if you have issues:
|
Try running it with this command if you have issues:
|
||||||
|
|
||||||
For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.0 python main.py```
|
```HSA_OVERRIDE_GFX_VERSION=10.3.0 python main.py```
|
||||||
|
|
||||||
For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
|
|
||||||
|
|
||||||
### AMD ROCm Tips
|
|
||||||
|
|
||||||
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
|
|
||||||
|
|
||||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
|
||||||
|
|
||||||
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
|
||||||
|
|
||||||
# Notes
|
# Notes
|
||||||
|
|
||||||
@ -381,78 +158,39 @@ You can use () to change emphasis of a word or phrase like: (good code:1.2) or (
|
|||||||
|
|
||||||
You can use {day|night}, for wildcard/dynamic prompts. With this syntax "{wild|card|test}" will be randomly replaced by either "wild", "card" or "test" by the frontend every time you queue the prompt. To use {} characters in your actual prompt escape them like: \\{ or \\}.
|
You can use {day|night}, for wildcard/dynamic prompts. With this syntax "{wild|card|test}" will be randomly replaced by either "wild", "card" or "test" by the frontend every time you queue the prompt. To use {} characters in your actual prompt escape them like: \\{ or \\}.
|
||||||
|
|
||||||
Dynamic prompts also support C-style comments, like `// comment` or `/* comment */`.
|
|
||||||
|
|
||||||
To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension):
|
To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension):
|
||||||
|
|
||||||
```embedding:embedding_filename.pt```
|
```embedding:embedding_filename.pt```
|
||||||
|
|
||||||
|
### Fedora
|
||||||
|
|
||||||
## How to show high-quality previews?
|
To get python 3.10 on fedora:
|
||||||
|
```dnf install python3.10```
|
||||||
|
|
||||||
Use ```--preview-method auto``` to enable previews.
|
Then you can:
|
||||||
|
|
||||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
|
```python3.10 -m ensurepip```
|
||||||
|
|
||||||
## How to use TLS/SSL?
|
This will let you use: pip3.10 to install all the dependencies.
|
||||||
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
|
||||||
|
|
||||||
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
|
## How to increase generation speed?
|
||||||
|
|
||||||
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
|
Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU.
|
||||||
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
|
|
||||||
|
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers this option does not do anything.
|
||||||
|
|
||||||
|
```--dont-upcast-attention```
|
||||||
|
|
||||||
## Support and dev channel
|
## Support and dev channel
|
||||||
|
|
||||||
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
|
|
||||||
|
|
||||||
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
|
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
|
||||||
|
|
||||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
|
||||||
|
|
||||||
## Frontend Development
|
|
||||||
|
|
||||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
|
||||||
|
|
||||||
### Reporting Issues and Requesting Features
|
|
||||||
|
|
||||||
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
|
|
||||||
|
|
||||||
### Using the Latest Frontend
|
|
||||||
|
|
||||||
The new frontend is now the default for ComfyUI. However, please note:
|
|
||||||
|
|
||||||
1. The frontend in the main ComfyUI repository is updated fortnightly.
|
|
||||||
2. Daily releases are available in the separate frontend repository.
|
|
||||||
|
|
||||||
To use the most up-to-date frontend version:
|
|
||||||
|
|
||||||
1. For the latest daily release, launch ComfyUI with this command line argument:
|
|
||||||
|
|
||||||
```
|
|
||||||
--front-end-version Comfy-Org/ComfyUI_frontend@latest
|
|
||||||
```
|
|
||||||
|
|
||||||
2. For a specific version, replace `latest` with the desired version number:
|
|
||||||
|
|
||||||
```
|
|
||||||
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
|
|
||||||
```
|
|
||||||
|
|
||||||
This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
|
||||||
|
|
||||||
### Accessing the Legacy Frontend
|
|
||||||
|
|
||||||
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
|
||||||
|
|
||||||
```
|
|
||||||
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
|
||||||
```
|
|
||||||
|
|
||||||
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
|
||||||
|
|
||||||
# QA
|
# QA
|
||||||
|
|
||||||
### Which GPU should I buy for this?
|
### Why did you make this?
|
||||||
|
|
||||||
[See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)
|
I wanted to learn how Stable Diffusion worked in detail. I also wanted something clean and powerful that would let me experiment with SD without restrictions.
|
||||||
|
|
||||||
|
### Who is this for?
|
||||||
|
|
||||||
|
This is for anyone that wants to make complex workflows with SD or that wants to learn more how SD works. The interface follows closely how SD works and the code should be much more simple to understand than other SD UIs.
|
||||||
|
|||||||
84
alembic.ini
84
alembic.ini
@ -1,84 +0,0 @@
|
|||||||
# A generic, single database configuration.
|
|
||||||
|
|
||||||
[alembic]
|
|
||||||
# path to migration scripts
|
|
||||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
|
||||||
script_location = alembic_db
|
|
||||||
|
|
||||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
|
||||||
# Uncomment the line below if you want the files to be prepended with date and time
|
|
||||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
|
||||||
# for all available tokens
|
|
||||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
|
||||||
|
|
||||||
# sys.path path, will be prepended to sys.path if present.
|
|
||||||
# defaults to the current working directory.
|
|
||||||
prepend_sys_path = .
|
|
||||||
|
|
||||||
# timezone to use when rendering the date within the migration file
|
|
||||||
# as well as the filename.
|
|
||||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
|
||||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
|
||||||
# string value is passed to ZoneInfo()
|
|
||||||
# leave blank for localtime
|
|
||||||
# timezone =
|
|
||||||
|
|
||||||
# max length of characters to apply to the "slug" field
|
|
||||||
# truncate_slug_length = 40
|
|
||||||
|
|
||||||
# set to 'true' to run the environment during
|
|
||||||
# the 'revision' command, regardless of autogenerate
|
|
||||||
# revision_environment = false
|
|
||||||
|
|
||||||
# set to 'true' to allow .pyc and .pyo files without
|
|
||||||
# a source .py file to be detected as revisions in the
|
|
||||||
# versions/ directory
|
|
||||||
# sourceless = false
|
|
||||||
|
|
||||||
# version location specification; This defaults
|
|
||||||
# to alembic_db/versions. When using multiple version
|
|
||||||
# directories, initial revisions must be specified with --version-path.
|
|
||||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
|
||||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
|
|
||||||
|
|
||||||
# version path separator; As mentioned above, this is the character used to split
|
|
||||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
|
||||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
|
||||||
# Valid values for version_path_separator are:
|
|
||||||
#
|
|
||||||
# version_path_separator = :
|
|
||||||
# version_path_separator = ;
|
|
||||||
# version_path_separator = space
|
|
||||||
# version_path_separator = newline
|
|
||||||
#
|
|
||||||
# Use os.pathsep. Default configuration used for new projects.
|
|
||||||
version_path_separator = os
|
|
||||||
|
|
||||||
# set to 'true' to search source files recursively
|
|
||||||
# in each "version_locations" directory
|
|
||||||
# new in Alembic version 1.10
|
|
||||||
# recursive_version_locations = false
|
|
||||||
|
|
||||||
# the output encoding used when revision files
|
|
||||||
# are written from script.py.mako
|
|
||||||
# output_encoding = utf-8
|
|
||||||
|
|
||||||
sqlalchemy.url = sqlite:///user/comfyui.db
|
|
||||||
|
|
||||||
|
|
||||||
[post_write_hooks]
|
|
||||||
# post_write_hooks defines scripts or Python functions that are run
|
|
||||||
# on newly generated revision scripts. See the documentation for further
|
|
||||||
# detail and examples
|
|
||||||
|
|
||||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
|
||||||
# hooks = black
|
|
||||||
# black.type = console_scripts
|
|
||||||
# black.entrypoint = black
|
|
||||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
|
||||||
|
|
||||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
|
||||||
# hooks = ruff
|
|
||||||
# ruff.type = exec
|
|
||||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
|
||||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
|
||||||
@ -1,4 +0,0 @@
|
|||||||
## Generate new revision
|
|
||||||
|
|
||||||
1. Update models in `/app/database/models.py`
|
|
||||||
2. Run `alembic revision --autogenerate -m "{your message}"`
|
|
||||||
@ -1,64 +0,0 @@
|
|||||||
from sqlalchemy import engine_from_config
|
|
||||||
from sqlalchemy import pool
|
|
||||||
|
|
||||||
from alembic import context
|
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
|
||||||
# access to the values within the .ini file in use.
|
|
||||||
config = context.config
|
|
||||||
|
|
||||||
|
|
||||||
from app.database.models import Base
|
|
||||||
target_metadata = Base.metadata
|
|
||||||
|
|
||||||
# other values from the config, defined by the needs of env.py,
|
|
||||||
# can be acquired:
|
|
||||||
# my_important_option = config.get_main_option("my_important_option")
|
|
||||||
# ... etc.
|
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_offline() -> None:
|
|
||||||
"""Run migrations in 'offline' mode.
|
|
||||||
This configures the context with just a URL
|
|
||||||
and not an Engine, though an Engine is acceptable
|
|
||||||
here as well. By skipping the Engine creation
|
|
||||||
we don't even need a DBAPI to be available.
|
|
||||||
Calls to context.execute() here emit the given string to the
|
|
||||||
script output.
|
|
||||||
"""
|
|
||||||
url = config.get_main_option("sqlalchemy.url")
|
|
||||||
context.configure(
|
|
||||||
url=url,
|
|
||||||
target_metadata=target_metadata,
|
|
||||||
literal_binds=True,
|
|
||||||
dialect_opts={"paramstyle": "named"},
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
|
||||||
context.run_migrations()
|
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_online() -> None:
|
|
||||||
"""Run migrations in 'online' mode.
|
|
||||||
In this scenario we need to create an Engine
|
|
||||||
and associate a connection with the context.
|
|
||||||
"""
|
|
||||||
connectable = engine_from_config(
|
|
||||||
config.get_section(config.config_ini_section, {}),
|
|
||||||
prefix="sqlalchemy.",
|
|
||||||
poolclass=pool.NullPool,
|
|
||||||
)
|
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
|
||||||
context.configure(
|
|
||||||
connection=connection, target_metadata=target_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
|
||||||
context.run_migrations()
|
|
||||||
|
|
||||||
|
|
||||||
if context.is_offline_mode():
|
|
||||||
run_migrations_offline()
|
|
||||||
else:
|
|
||||||
run_migrations_online()
|
|
||||||
@ -1,28 +0,0 @@
|
|||||||
"""${message}
|
|
||||||
|
|
||||||
Revision ID: ${up_revision}
|
|
||||||
Revises: ${down_revision | comma,n}
|
|
||||||
Create Date: ${create_date}
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
${imports if imports else ""}
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = ${repr(up_revision)}
|
|
||||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
|
||||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
${upgrades if upgrades else "pass"}
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
${downgrades if downgrades else "pass"}
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
# ComfyUI Internal Routes
|
|
||||||
|
|
||||||
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.
|
|
||||||
@ -1,78 +0,0 @@
|
|||||||
from aiohttp import web
|
|
||||||
from typing import Optional
|
|
||||||
from folder_paths import folder_names_and_paths, get_directory_by_type
|
|
||||||
from api_server.services.terminal_service import TerminalService
|
|
||||||
import app.logger
|
|
||||||
import os
|
|
||||||
|
|
||||||
class InternalRoutes:
|
|
||||||
'''
|
|
||||||
The top level web router for internal routes: /internal/*
|
|
||||||
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
|
||||||
Check README.md for more information.
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, prompt_server):
|
|
||||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
|
||||||
self._app: Optional[web.Application] = None
|
|
||||||
self.prompt_server = prompt_server
|
|
||||||
self.terminal_service = TerminalService(prompt_server)
|
|
||||||
|
|
||||||
def setup_routes(self):
|
|
||||||
@self.routes.get('/logs')
|
|
||||||
async def get_logs(request):
|
|
||||||
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
|
||||||
|
|
||||||
@self.routes.get('/logs/raw')
|
|
||||||
async def get_raw_logs(request):
|
|
||||||
self.terminal_service.update_size()
|
|
||||||
return web.json_response({
|
|
||||||
"entries": list(app.logger.get_logs()),
|
|
||||||
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.routes.patch('/logs/subscribe')
|
|
||||||
async def subscribe_logs(request):
|
|
||||||
json_data = await request.json()
|
|
||||||
client_id = json_data["clientId"]
|
|
||||||
enabled = json_data["enabled"]
|
|
||||||
if enabled:
|
|
||||||
self.terminal_service.subscribe(client_id)
|
|
||||||
else:
|
|
||||||
self.terminal_service.unsubscribe(client_id)
|
|
||||||
|
|
||||||
return web.Response(status=200)
|
|
||||||
|
|
||||||
|
|
||||||
@self.routes.get('/folder_paths')
|
|
||||||
async def get_folder_paths(request):
|
|
||||||
response = {}
|
|
||||||
for key in folder_names_and_paths:
|
|
||||||
response[key] = folder_names_and_paths[key][0]
|
|
||||||
return web.json_response(response)
|
|
||||||
|
|
||||||
@self.routes.get('/files/{directory_type}')
|
|
||||||
async def get_files(request: web.Request) -> web.Response:
|
|
||||||
directory_type = request.match_info['directory_type']
|
|
||||||
if directory_type not in ("output", "input", "temp"):
|
|
||||||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
|
||||||
|
|
||||||
directory = get_directory_by_type(directory_type)
|
|
||||||
|
|
||||||
def is_visible_file(entry: os.DirEntry) -> bool:
|
|
||||||
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
|
|
||||||
return entry.is_file() and not entry.name.startswith('.')
|
|
||||||
|
|
||||||
sorted_files = sorted(
|
|
||||||
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
|
||||||
key=lambda entry: -entry.stat().st_mtime
|
|
||||||
)
|
|
||||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
|
||||||
|
|
||||||
|
|
||||||
def get_app(self):
|
|
||||||
if self._app is None:
|
|
||||||
self._app = web.Application()
|
|
||||||
self.setup_routes()
|
|
||||||
self._app.add_routes(self.routes)
|
|
||||||
return self._app
|
|
||||||
@ -1,60 +0,0 @@
|
|||||||
from app.logger import on_flush
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
|
|
||||||
class TerminalService:
|
|
||||||
def __init__(self, server):
|
|
||||||
self.server = server
|
|
||||||
self.cols = None
|
|
||||||
self.rows = None
|
|
||||||
self.subscriptions = set()
|
|
||||||
on_flush(self.send_messages)
|
|
||||||
|
|
||||||
def get_terminal_size(self):
|
|
||||||
try:
|
|
||||||
size = os.get_terminal_size()
|
|
||||||
return (size.columns, size.lines)
|
|
||||||
except OSError:
|
|
||||||
try:
|
|
||||||
size = shutil.get_terminal_size()
|
|
||||||
return (size.columns, size.lines)
|
|
||||||
except OSError:
|
|
||||||
return (80, 24) # fallback to 80x24
|
|
||||||
|
|
||||||
def update_size(self):
|
|
||||||
columns, lines = self.get_terminal_size()
|
|
||||||
changed = False
|
|
||||||
|
|
||||||
if columns != self.cols:
|
|
||||||
self.cols = columns
|
|
||||||
changed = True
|
|
||||||
|
|
||||||
if lines != self.rows:
|
|
||||||
self.rows = lines
|
|
||||||
changed = True
|
|
||||||
|
|
||||||
if changed:
|
|
||||||
return {"cols": self.cols, "rows": self.rows}
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def subscribe(self, client_id):
|
|
||||||
self.subscriptions.add(client_id)
|
|
||||||
|
|
||||||
def unsubscribe(self, client_id):
|
|
||||||
self.subscriptions.discard(client_id)
|
|
||||||
|
|
||||||
def send_messages(self, entries):
|
|
||||||
if not len(entries) or not len(self.subscriptions):
|
|
||||||
return
|
|
||||||
|
|
||||||
new_size = self.update_size()
|
|
||||||
|
|
||||||
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
|
||||||
if client_id not in self.server.sockets:
|
|
||||||
# Automatically unsub if the socket has disconnected
|
|
||||||
self.unsubscribe(client_id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
|
|
||||||
@ -1,42 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import List, Union, TypedDict, Literal
|
|
||||||
from typing_extensions import TypeGuard
|
|
||||||
class FileInfo(TypedDict):
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
type: Literal["file"]
|
|
||||||
size: int
|
|
||||||
|
|
||||||
class DirectoryInfo(TypedDict):
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
type: Literal["directory"]
|
|
||||||
|
|
||||||
FileSystemItem = Union[FileInfo, DirectoryInfo]
|
|
||||||
|
|
||||||
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
|
|
||||||
return item["type"] == "file"
|
|
||||||
|
|
||||||
class FileSystemOperations:
|
|
||||||
@staticmethod
|
|
||||||
def walk_directory(directory: str) -> List[FileSystemItem]:
|
|
||||||
file_list: List[FileSystemItem] = []
|
|
||||||
for root, dirs, files in os.walk(directory):
|
|
||||||
for name in files:
|
|
||||||
file_path = os.path.join(root, name)
|
|
||||||
relative_path = os.path.relpath(file_path, directory)
|
|
||||||
file_list.append({
|
|
||||||
"name": name,
|
|
||||||
"path": relative_path,
|
|
||||||
"type": "file",
|
|
||||||
"size": os.path.getsize(file_path)
|
|
||||||
})
|
|
||||||
for name in dirs:
|
|
||||||
dir_path = os.path.join(root, name)
|
|
||||||
relative_path = os.path.relpath(dir_path, directory)
|
|
||||||
file_list.append({
|
|
||||||
"name": name,
|
|
||||||
"path": relative_path,
|
|
||||||
"type": "directory"
|
|
||||||
})
|
|
||||||
return file_list
|
|
||||||
@ -1,65 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
from aiohttp import web
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
class AppSettings():
|
|
||||||
def __init__(self, user_manager):
|
|
||||||
self.user_manager = user_manager
|
|
||||||
|
|
||||||
def get_settings(self, request):
|
|
||||||
try:
|
|
||||||
file = self.user_manager.get_request_user_filepath(
|
|
||||||
request,
|
|
||||||
"comfy.settings.json"
|
|
||||||
)
|
|
||||||
except KeyError as e:
|
|
||||||
logging.error("User settings not found.")
|
|
||||||
raise web.HTTPUnauthorized() from e
|
|
||||||
if os.path.isfile(file):
|
|
||||||
try:
|
|
||||||
with open(file) as f:
|
|
||||||
return json.load(f)
|
|
||||||
except:
|
|
||||||
logging.error(f"The user settings file is corrupted: {file}")
|
|
||||||
return {}
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def save_settings(self, request, settings):
|
|
||||||
file = self.user_manager.get_request_user_filepath(
|
|
||||||
request, "comfy.settings.json")
|
|
||||||
with open(file, "w") as f:
|
|
||||||
f.write(json.dumps(settings, indent=4))
|
|
||||||
|
|
||||||
def add_routes(self, routes):
|
|
||||||
@routes.get("/settings")
|
|
||||||
async def get_settings(request):
|
|
||||||
return web.json_response(self.get_settings(request))
|
|
||||||
|
|
||||||
@routes.get("/settings/{id}")
|
|
||||||
async def get_setting(request):
|
|
||||||
value = None
|
|
||||||
settings = self.get_settings(request)
|
|
||||||
setting_id = request.match_info.get("id", None)
|
|
||||||
if setting_id and setting_id in settings:
|
|
||||||
value = settings[setting_id]
|
|
||||||
return web.json_response(value)
|
|
||||||
|
|
||||||
@routes.post("/settings")
|
|
||||||
async def post_settings(request):
|
|
||||||
settings = self.get_settings(request)
|
|
||||||
new_settings = await request.json()
|
|
||||||
self.save_settings(request, {**settings, **new_settings})
|
|
||||||
return web.Response(status=200)
|
|
||||||
|
|
||||||
@routes.post("/settings/{id}")
|
|
||||||
async def post_setting(request):
|
|
||||||
setting_id = request.match_info.get("id", None)
|
|
||||||
if not setting_id:
|
|
||||||
return web.Response(status=400)
|
|
||||||
settings = self.get_settings(request)
|
|
||||||
settings[setting_id] = await request.json()
|
|
||||||
self.save_settings(request, settings)
|
|
||||||
return web.Response(status=200)
|
|
||||||
@ -1,145 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
import folder_paths
|
|
||||||
import glob
|
|
||||||
from aiohttp import web
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
from utils.json_util import merge_json_recursive
|
|
||||||
|
|
||||||
|
|
||||||
# Extra locale files to load into main.json
|
|
||||||
EXTRA_LOCALE_FILES = [
|
|
||||||
"nodeDefs.json",
|
|
||||||
"commands.json",
|
|
||||||
"settings.json",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def safe_load_json_file(file_path: str) -> dict:
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
return json.load(f)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logging.error(f"Error loading {file_path}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class CustomNodeManager:
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def build_translations(self):
|
|
||||||
"""Load all custom nodes translations during initialization. Translations are
|
|
||||||
expected to be loaded from `locales/` folder.
|
|
||||||
|
|
||||||
The folder structure is expected to be the following:
|
|
||||||
- custom_nodes/
|
|
||||||
- custom_node_1/
|
|
||||||
- locales/
|
|
||||||
- en/
|
|
||||||
- main.json
|
|
||||||
- commands.json
|
|
||||||
- settings.json
|
|
||||||
|
|
||||||
returned translations are expected to be in the following format:
|
|
||||||
{
|
|
||||||
"en": {
|
|
||||||
"nodeDefs": {...},
|
|
||||||
"commands": {...},
|
|
||||||
"settings": {...},
|
|
||||||
...{other main.json keys}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
translations = {}
|
|
||||||
|
|
||||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
|
||||||
# Sort glob results for deterministic ordering
|
|
||||||
for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
|
|
||||||
locales_dir = os.path.join(custom_node_dir, "locales")
|
|
||||||
if not os.path.exists(locales_dir):
|
|
||||||
continue
|
|
||||||
|
|
||||||
for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
|
|
||||||
lang_code = os.path.basename(os.path.dirname(lang_dir))
|
|
||||||
|
|
||||||
if lang_code not in translations:
|
|
||||||
translations[lang_code] = {}
|
|
||||||
|
|
||||||
# Load main.json
|
|
||||||
main_file = os.path.join(lang_dir, "main.json")
|
|
||||||
node_translations = safe_load_json_file(main_file)
|
|
||||||
|
|
||||||
# Load extra locale files
|
|
||||||
for extra_file in EXTRA_LOCALE_FILES:
|
|
||||||
extra_file_path = os.path.join(lang_dir, extra_file)
|
|
||||||
key = extra_file.split(".")[0]
|
|
||||||
json_data = safe_load_json_file(extra_file_path)
|
|
||||||
if json_data:
|
|
||||||
node_translations[key] = json_data
|
|
||||||
|
|
||||||
if node_translations:
|
|
||||||
translations[lang_code] = merge_json_recursive(
|
|
||||||
translations[lang_code], node_translations
|
|
||||||
)
|
|
||||||
|
|
||||||
return translations
|
|
||||||
|
|
||||||
def add_routes(self, routes, webapp, loadedModules):
|
|
||||||
|
|
||||||
example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
|
|
||||||
|
|
||||||
@routes.get("/workflow_templates")
|
|
||||||
async def get_workflow_templates(request):
|
|
||||||
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
|
|
||||||
|
|
||||||
files = []
|
|
||||||
|
|
||||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
|
||||||
for folder_name in example_workflow_folder_names:
|
|
||||||
pattern = os.path.join(folder, f"*/{folder_name}/*.json")
|
|
||||||
matched_files = glob.glob(pattern)
|
|
||||||
files.extend(matched_files)
|
|
||||||
|
|
||||||
workflow_templates_dict = (
|
|
||||||
{}
|
|
||||||
) # custom_nodes folder name -> example workflow names
|
|
||||||
for file in files:
|
|
||||||
custom_nodes_name = os.path.basename(
|
|
||||||
os.path.dirname(os.path.dirname(file))
|
|
||||||
)
|
|
||||||
workflow_name = os.path.splitext(os.path.basename(file))[0]
|
|
||||||
workflow_templates_dict.setdefault(custom_nodes_name, []).append(
|
|
||||||
workflow_name
|
|
||||||
)
|
|
||||||
return web.json_response(workflow_templates_dict)
|
|
||||||
|
|
||||||
# Serve workflow templates from custom nodes.
|
|
||||||
for module_name, module_dir in loadedModules:
|
|
||||||
for folder_name in example_workflow_folder_names:
|
|
||||||
workflows_dir = os.path.join(module_dir, folder_name)
|
|
||||||
|
|
||||||
if os.path.exists(workflows_dir):
|
|
||||||
if folder_name != "example_workflows":
|
|
||||||
logging.debug(
|
|
||||||
"Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
|
|
||||||
folder_name, module_name)
|
|
||||||
|
|
||||||
webapp.add_routes(
|
|
||||||
[
|
|
||||||
web.static(
|
|
||||||
"/api/workflow_templates/" + module_name, workflows_dir
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@routes.get("/i18n")
|
|
||||||
async def get_i18n(request):
|
|
||||||
"""Returns translations from all custom nodes' locales folders."""
|
|
||||||
return web.json_response(self.build_translations())
|
|
||||||
@ -1,112 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from app.logger import log_startup_warning
|
|
||||||
from utils.install_util import get_missing_requirements_message
|
|
||||||
from comfy.cli_args import args
|
|
||||||
|
|
||||||
_DB_AVAILABLE = False
|
|
||||||
Session = None
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from alembic import command
|
|
||||||
from alembic.config import Config
|
|
||||||
from alembic.runtime.migration import MigrationContext
|
|
||||||
from alembic.script import ScriptDirectory
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
_DB_AVAILABLE = True
|
|
||||||
except ImportError as e:
|
|
||||||
log_startup_warning(
|
|
||||||
f"""
|
|
||||||
------------------------------------------------------------------------
|
|
||||||
Error importing dependencies: {e}
|
|
||||||
{get_missing_requirements_message()}
|
|
||||||
This error is happening because ComfyUI now uses a local sqlite database.
|
|
||||||
------------------------------------------------------------------------
|
|
||||||
""".strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def dependencies_available():
|
|
||||||
"""
|
|
||||||
Temporary function to check if the dependencies are available
|
|
||||||
"""
|
|
||||||
return _DB_AVAILABLE
|
|
||||||
|
|
||||||
|
|
||||||
def can_create_session():
|
|
||||||
"""
|
|
||||||
Temporary function to check if the database is available to create a session
|
|
||||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
|
||||||
"""
|
|
||||||
return dependencies_available() and Session is not None
|
|
||||||
|
|
||||||
|
|
||||||
def get_alembic_config():
|
|
||||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
|
||||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
|
||||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
|
||||||
|
|
||||||
config = Config(config_path)
|
|
||||||
config.set_main_option("script_location", scripts_path)
|
|
||||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_path():
|
|
||||||
url = args.database_url
|
|
||||||
if url.startswith("sqlite:///"):
|
|
||||||
return url.split("///")[1]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
|
||||||
db_url = args.database_url
|
|
||||||
logging.debug(f"Database URL: {db_url}")
|
|
||||||
db_path = get_db_path()
|
|
||||||
db_exists = os.path.exists(db_path)
|
|
||||||
|
|
||||||
config = get_alembic_config()
|
|
||||||
|
|
||||||
# Check if we need to upgrade
|
|
||||||
engine = create_engine(db_url)
|
|
||||||
conn = engine.connect()
|
|
||||||
|
|
||||||
context = MigrationContext.configure(conn)
|
|
||||||
current_rev = context.get_current_revision()
|
|
||||||
|
|
||||||
script = ScriptDirectory.from_config(config)
|
|
||||||
target_rev = script.get_current_head()
|
|
||||||
|
|
||||||
if target_rev is None:
|
|
||||||
logging.warning("No target revision found.")
|
|
||||||
elif current_rev != target_rev:
|
|
||||||
# Backup the database pre upgrade
|
|
||||||
backup_path = db_path + ".bkp"
|
|
||||||
if db_exists:
|
|
||||||
shutil.copy(db_path, backup_path)
|
|
||||||
else:
|
|
||||||
backup_path = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
command.upgrade(config, target_rev)
|
|
||||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
|
||||||
except Exception as e:
|
|
||||||
if backup_path:
|
|
||||||
# Restore the database from backup if upgrade fails
|
|
||||||
shutil.copy(backup_path, db_path)
|
|
||||||
os.remove(backup_path)
|
|
||||||
logging.exception("Error upgrading database: ")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
global Session
|
|
||||||
Session = sessionmaker(bind=engine)
|
|
||||||
|
|
||||||
|
|
||||||
def create_session():
|
|
||||||
return Session()
|
|
||||||
@ -1,14 +0,0 @@
|
|||||||
from sqlalchemy.orm import declarative_base
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
def to_dict(obj):
|
|
||||||
fields = obj.__table__.columns.keys()
|
|
||||||
return {
|
|
||||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
|
||||||
for field in fields
|
|
||||||
if (val := getattr(obj, field))
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: Define models here
|
|
||||||
@ -1,457 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import zipfile
|
|
||||||
import importlib
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import cached_property
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, TypedDict, Optional
|
|
||||||
from aiohttp import web
|
|
||||||
from importlib.metadata import version
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from typing_extensions import NotRequired
|
|
||||||
|
|
||||||
from utils.install_util import get_missing_requirements_message, requirements_path
|
|
||||||
|
|
||||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
|
||||||
import app.logger
|
|
||||||
|
|
||||||
|
|
||||||
def frontend_install_warning_message():
|
|
||||||
return f"""
|
|
||||||
{get_missing_requirements_message()}
|
|
||||||
|
|
||||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
|
||||||
""".strip()
|
|
||||||
|
|
||||||
def parse_version(version: str) -> tuple[int, int, int]:
|
|
||||||
return tuple(map(int, version.split(".")))
|
|
||||||
|
|
||||||
def is_valid_version(version: str) -> bool:
|
|
||||||
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
|
|
||||||
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
|
|
||||||
return bool(re.match(pattern, version))
|
|
||||||
|
|
||||||
def get_installed_frontend_version():
|
|
||||||
"""Get the currently installed frontend package version."""
|
|
||||||
frontend_version_str = version("comfyui-frontend-package")
|
|
||||||
return frontend_version_str
|
|
||||||
|
|
||||||
|
|
||||||
def get_required_frontend_version():
|
|
||||||
"""Get the required frontend version from requirements.txt."""
|
|
||||||
try:
|
|
||||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if line.startswith("comfyui-frontend-package=="):
|
|
||||||
version_str = line.split("==")[-1]
|
|
||||||
if not is_valid_version(version_str):
|
|
||||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
|
||||||
return None
|
|
||||||
return version_str
|
|
||||||
logging.error("comfyui-frontend-package not found in requirements.txt")
|
|
||||||
return None
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.error("requirements.txt not found. Cannot determine required frontend version.")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error reading requirements.txt: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def check_frontend_version():
|
|
||||||
"""Check if the frontend version is up to date."""
|
|
||||||
|
|
||||||
try:
|
|
||||||
frontend_version_str = get_installed_frontend_version()
|
|
||||||
frontend_version = parse_version(frontend_version_str)
|
|
||||||
required_frontend_str = get_required_frontend_version()
|
|
||||||
required_frontend = parse_version(required_frontend_str)
|
|
||||||
if frontend_version < required_frontend:
|
|
||||||
app.logger.log_startup_warning(
|
|
||||||
f"""
|
|
||||||
________________________________________________________________________
|
|
||||||
WARNING WARNING WARNING WARNING WARNING
|
|
||||||
|
|
||||||
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
|
|
||||||
|
|
||||||
{frontend_install_warning_message()}
|
|
||||||
________________________________________________________________________
|
|
||||||
""".strip()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Failed to check frontend version: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
REQUEST_TIMEOUT = 10 # seconds
|
|
||||||
|
|
||||||
|
|
||||||
class Asset(TypedDict):
|
|
||||||
url: str
|
|
||||||
|
|
||||||
|
|
||||||
class Release(TypedDict):
|
|
||||||
id: int
|
|
||||||
tag_name: str
|
|
||||||
name: str
|
|
||||||
prerelease: bool
|
|
||||||
created_at: str
|
|
||||||
published_at: str
|
|
||||||
body: str
|
|
||||||
assets: NotRequired[list[Asset]]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FrontEndProvider:
|
|
||||||
owner: str
|
|
||||||
repo: str
|
|
||||||
|
|
||||||
@property
|
|
||||||
def folder_name(self) -> str:
|
|
||||||
return f"{self.owner}_{self.repo}"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def release_url(self) -> str:
|
|
||||||
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def all_releases(self) -> list[Release]:
|
|
||||||
releases = []
|
|
||||||
api_url = self.release_url
|
|
||||||
while api_url:
|
|
||||||
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
|
|
||||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
|
||||||
releases.extend(response.json())
|
|
||||||
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
|
|
||||||
if "next" in response.links:
|
|
||||||
api_url = response.links["next"]["url"]
|
|
||||||
else:
|
|
||||||
api_url = None
|
|
||||||
return releases
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def latest_release(self) -> Release:
|
|
||||||
latest_release_url = f"{self.release_url}/latest"
|
|
||||||
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
|
|
||||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def latest_prerelease(self) -> Release:
|
|
||||||
"""Get the latest pre-release version - even if it's older than the latest release"""
|
|
||||||
release = [release for release in self.all_releases if release["prerelease"]]
|
|
||||||
|
|
||||||
if not release:
|
|
||||||
raise ValueError("No pre-releases found")
|
|
||||||
|
|
||||||
# GitHub returns releases in reverse chronological order, so first is latest
|
|
||||||
return release[0]
|
|
||||||
|
|
||||||
def get_release(self, version: str) -> Release:
|
|
||||||
if version == "latest":
|
|
||||||
return self.latest_release
|
|
||||||
elif version == "prerelease":
|
|
||||||
return self.latest_prerelease
|
|
||||||
else:
|
|
||||||
for release in self.all_releases:
|
|
||||||
if release["tag_name"] in [version, f"v{version}"]:
|
|
||||||
return release
|
|
||||||
raise ValueError(f"Version {version} not found in releases")
|
|
||||||
|
|
||||||
|
|
||||||
def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
|
||||||
"""Download dist.zip from github release."""
|
|
||||||
asset_url = None
|
|
||||||
for asset in release.get("assets", []):
|
|
||||||
if asset["name"] == "dist.zip":
|
|
||||||
asset_url = asset["url"]
|
|
||||||
break
|
|
||||||
|
|
||||||
if not asset_url:
|
|
||||||
raise ValueError("dist.zip not found in the release assets")
|
|
||||||
|
|
||||||
# Use a temporary file to download the zip content
|
|
||||||
with tempfile.TemporaryFile() as tmp_file:
|
|
||||||
headers = {"Accept": "application/octet-stream"}
|
|
||||||
response = requests.get(
|
|
||||||
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
|
|
||||||
)
|
|
||||||
response.raise_for_status() # Ensure we got a successful response
|
|
||||||
|
|
||||||
# Write the content to the temporary file
|
|
||||||
tmp_file.write(response.content)
|
|
||||||
|
|
||||||
# Go back to the beginning of the temporary file
|
|
||||||
tmp_file.seek(0)
|
|
||||||
|
|
||||||
# Extract the zip file content to the destination path
|
|
||||||
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
|
|
||||||
zip_ref.extractall(destination_path)
|
|
||||||
|
|
||||||
|
|
||||||
class FrontendManager:
|
|
||||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_required_frontend_version(cls) -> str:
|
|
||||||
"""Get the required frontend package version."""
|
|
||||||
return get_required_frontend_version()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_installed_templates_version(cls) -> str:
|
|
||||||
"""Get the currently installed workflow templates package version."""
|
|
||||||
try:
|
|
||||||
templates_version_str = version("comfyui-workflow-templates")
|
|
||||||
return templates_version_str
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_required_templates_version(cls) -> str:
|
|
||||||
"""Get the required workflow templates version from requirements.txt."""
|
|
||||||
try:
|
|
||||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if line.startswith("comfyui-workflow-templates=="):
|
|
||||||
version_str = line.split("==")[-1]
|
|
||||||
if not is_valid_version(version_str):
|
|
||||||
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
|
|
||||||
return None
|
|
||||||
return version_str
|
|
||||||
logging.error("comfyui-workflow-templates not found in requirements.txt")
|
|
||||||
return None
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.error("requirements.txt not found. Cannot determine required templates version.")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error reading requirements.txt: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def default_frontend_path(cls) -> str:
|
|
||||||
try:
|
|
||||||
import comfyui_frontend_package
|
|
||||||
|
|
||||||
return str(importlib.resources.files(comfyui_frontend_package) / "static")
|
|
||||||
except ImportError:
|
|
||||||
logging.error(
|
|
||||||
f"""
|
|
||||||
********** ERROR ***********
|
|
||||||
|
|
||||||
comfyui-frontend-package is not installed.
|
|
||||||
|
|
||||||
{frontend_install_warning_message()}
|
|
||||||
|
|
||||||
********** ERROR ***********
|
|
||||||
""".strip()
|
|
||||||
)
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def template_asset_map(cls) -> Optional[Dict[str, str]]:
|
|
||||||
"""Return a mapping of template asset names to their absolute paths."""
|
|
||||||
try:
|
|
||||||
from comfyui_workflow_templates import (
|
|
||||||
get_asset_path,
|
|
||||||
iter_templates,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
logging.error(
|
|
||||||
f"""
|
|
||||||
********** ERROR ***********
|
|
||||||
|
|
||||||
comfyui-workflow-templates is not installed.
|
|
||||||
|
|
||||||
{frontend_install_warning_message()}
|
|
||||||
|
|
||||||
********** ERROR ***********
|
|
||||||
""".strip()
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
template_entries = list(iter_templates())
|
|
||||||
except Exception as exc:
|
|
||||||
logging.error(f"Failed to enumerate workflow templates: {exc}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
asset_map: Dict[str, str] = {}
|
|
||||||
try:
|
|
||||||
for entry in template_entries:
|
|
||||||
for asset in entry.assets:
|
|
||||||
asset_map[asset.filename] = get_asset_path(
|
|
||||||
entry.template_id, asset.filename
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logging.error(f"Failed to resolve template asset paths: {exc}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not asset_map:
|
|
||||||
logging.error("No workflow template assets found. Did the packages install correctly?")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return asset_map
|
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def legacy_templates_path(cls) -> Optional[str]:
|
|
||||||
"""Return the legacy templates directory shipped inside the meta package."""
|
|
||||||
try:
|
|
||||||
import comfyui_workflow_templates
|
|
||||||
|
|
||||||
return str(
|
|
||||||
importlib.resources.files(comfyui_workflow_templates) / "templates"
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
logging.error(
|
|
||||||
f"""
|
|
||||||
********** ERROR ***********
|
|
||||||
|
|
||||||
comfyui-workflow-templates is not installed.
|
|
||||||
|
|
||||||
{frontend_install_warning_message()}
|
|
||||||
|
|
||||||
********** ERROR ***********
|
|
||||||
""".strip()
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def embedded_docs_path(cls) -> str:
|
|
||||||
"""Get the path to embedded documentation"""
|
|
||||||
try:
|
|
||||||
import comfyui_embedded_docs
|
|
||||||
|
|
||||||
return str(
|
|
||||||
importlib.resources.files(comfyui_embedded_docs) / "docs"
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
logging.info("comfyui-embedded-docs package not found")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
value (str): The version string to parse.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[str, str]: A tuple containing provider name and version.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
argparse.ArgumentTypeError: If the version string is invalid.
|
|
||||||
"""
|
|
||||||
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
|
|
||||||
match_result = re.match(VERSION_PATTERN, value)
|
|
||||||
if match_result is None:
|
|
||||||
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
|
|
||||||
|
|
||||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def init_frontend_unsafe(
|
|
||||||
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Initializes the frontend for the specified version.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
version_string (str): The version string.
|
|
||||||
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The path to the initialized frontend.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If there is an error during the initialization process.
|
|
||||||
main error source might be request timeout or invalid URL.
|
|
||||||
"""
|
|
||||||
if version_string == DEFAULT_VERSION_STRING:
|
|
||||||
check_frontend_version()
|
|
||||||
return cls.default_frontend_path()
|
|
||||||
|
|
||||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
|
||||||
|
|
||||||
if version.startswith("v"):
|
|
||||||
expected_path = str(
|
|
||||||
Path(cls.CUSTOM_FRONTENDS_ROOT)
|
|
||||||
/ f"{repo_owner}_{repo_name}"
|
|
||||||
/ version.lstrip("v")
|
|
||||||
)
|
|
||||||
if os.path.exists(expected_path):
|
|
||||||
logging.info(
|
|
||||||
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
|
|
||||||
)
|
|
||||||
return expected_path
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
|
|
||||||
)
|
|
||||||
|
|
||||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
|
||||||
release = provider.get_release(version)
|
|
||||||
|
|
||||||
semantic_version = release["tag_name"].lstrip("v")
|
|
||||||
web_root = str(
|
|
||||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
|
||||||
)
|
|
||||||
if not os.path.exists(web_root):
|
|
||||||
try:
|
|
||||||
os.makedirs(web_root, exist_ok=True)
|
|
||||||
logging.info(
|
|
||||||
"Downloading frontend(%s) version(%s) to (%s)",
|
|
||||||
provider.folder_name,
|
|
||||||
semantic_version,
|
|
||||||
web_root,
|
|
||||||
)
|
|
||||||
logging.debug(release)
|
|
||||||
download_release_asset_zip(release, destination_path=web_root)
|
|
||||||
finally:
|
|
||||||
# Clean up the directory if it is empty, i.e. the download failed
|
|
||||||
if not os.listdir(web_root):
|
|
||||||
os.rmdir(web_root)
|
|
||||||
|
|
||||||
return web_root
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def init_frontend(cls, version_string: str) -> str:
|
|
||||||
"""
|
|
||||||
Initializes the frontend with the specified version string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
version_string (str): The version string to initialize the frontend with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The path of the initialized frontend.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return cls.init_frontend_unsafe(version_string)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("Failed to initialize frontend: %s", e)
|
|
||||||
logging.info("Falling back to the default frontend.")
|
|
||||||
check_frontend_version()
|
|
||||||
return cls.default_frontend_path()
|
|
||||||
@classmethod
|
|
||||||
def template_asset_handler(cls):
|
|
||||||
assets = cls.template_asset_map()
|
|
||||||
if not assets:
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def serve_template(request: web.Request) -> web.StreamResponse:
|
|
||||||
rel_path = request.match_info.get("path", "")
|
|
||||||
target = assets.get(rel_path)
|
|
||||||
if target is None:
|
|
||||||
raise web.HTTPNotFound()
|
|
||||||
return web.FileResponse(target)
|
|
||||||
|
|
||||||
return serve_template
|
|
||||||
@ -1,98 +0,0 @@
|
|||||||
from collections import deque
|
|
||||||
from datetime import datetime
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
|
|
||||||
logs = None
|
|
||||||
stdout_interceptor = None
|
|
||||||
stderr_interceptor = None
|
|
||||||
|
|
||||||
|
|
||||||
class LogInterceptor(io.TextIOWrapper):
|
|
||||||
def __init__(self, stream, *args, **kwargs):
|
|
||||||
buffer = stream.buffer
|
|
||||||
encoding = stream.encoding
|
|
||||||
super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
self._flush_callbacks = []
|
|
||||||
self._logs_since_flush = []
|
|
||||||
|
|
||||||
def write(self, data):
|
|
||||||
entry = {"t": datetime.now().isoformat(), "m": data}
|
|
||||||
with self._lock:
|
|
||||||
self._logs_since_flush.append(entry)
|
|
||||||
|
|
||||||
# Simple handling for cr to overwrite the last output if it isnt a full line
|
|
||||||
# else logs just get full of progress messages
|
|
||||||
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
|
|
||||||
logs.pop()
|
|
||||||
logs.append(entry)
|
|
||||||
super().write(data)
|
|
||||||
|
|
||||||
def flush(self):
|
|
||||||
super().flush()
|
|
||||||
for cb in self._flush_callbacks:
|
|
||||||
cb(self._logs_since_flush)
|
|
||||||
self._logs_since_flush = []
|
|
||||||
|
|
||||||
def on_flush(self, callback):
|
|
||||||
self._flush_callbacks.append(callback)
|
|
||||||
|
|
||||||
|
|
||||||
def get_logs():
|
|
||||||
return logs
|
|
||||||
|
|
||||||
|
|
||||||
def on_flush(callback):
|
|
||||||
if stdout_interceptor is not None:
|
|
||||||
stdout_interceptor.on_flush(callback)
|
|
||||||
if stderr_interceptor is not None:
|
|
||||||
stderr_interceptor.on_flush(callback)
|
|
||||||
|
|
||||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
|
|
||||||
global logs
|
|
||||||
if logs:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Override output streams and log to buffer
|
|
||||||
logs = deque(maxlen=capacity)
|
|
||||||
|
|
||||||
global stdout_interceptor
|
|
||||||
global stderr_interceptor
|
|
||||||
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
|
|
||||||
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
|
|
||||||
|
|
||||||
# Setup default global logger
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(log_level)
|
|
||||||
|
|
||||||
stream_handler = logging.StreamHandler()
|
|
||||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
|
||||||
|
|
||||||
if use_stdout:
|
|
||||||
# Only errors and critical to stderr
|
|
||||||
stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
|
|
||||||
|
|
||||||
# Lesser to stdout
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
stdout_handler.setFormatter(logging.Formatter("%(message)s"))
|
|
||||||
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
|
|
||||||
STARTUP_WARNINGS = []
|
|
||||||
|
|
||||||
|
|
||||||
def log_startup_warning(msg):
|
|
||||||
logging.warning(msg)
|
|
||||||
STARTUP_WARNINGS.append(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def print_startup_warnings():
|
|
||||||
for s in STARTUP_WARNINGS:
|
|
||||||
logging.warning(s)
|
|
||||||
STARTUP_WARNINGS.clear()
|
|
||||||
@ -1,195 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
import folder_paths
|
|
||||||
import glob
|
|
||||||
import comfy.utils
|
|
||||||
from aiohttp import web
|
|
||||||
from PIL import Image
|
|
||||||
from io import BytesIO
|
|
||||||
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
|
||||||
|
|
||||||
|
|
||||||
class ModelFileManager:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
|
||||||
|
|
||||||
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
|
|
||||||
return self.cache.get(key, default)
|
|
||||||
|
|
||||||
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
|
|
||||||
self.cache[key] = value
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
self.cache.clear()
|
|
||||||
|
|
||||||
def add_routes(self, routes):
|
|
||||||
# NOTE: This is an experiment to replace `/models`
|
|
||||||
@routes.get("/experiment/models")
|
|
||||||
async def get_model_folders(request):
|
|
||||||
model_types = list(folder_paths.folder_names_and_paths.keys())
|
|
||||||
folder_black_list = ["configs", "custom_nodes"]
|
|
||||||
output_folders: list[dict] = []
|
|
||||||
for folder in model_types:
|
|
||||||
if folder in folder_black_list:
|
|
||||||
continue
|
|
||||||
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
|
|
||||||
return web.json_response(output_folders)
|
|
||||||
|
|
||||||
# NOTE: This is an experiment to replace `/models/{folder}`
|
|
||||||
@routes.get("/experiment/models/{folder}")
|
|
||||||
async def get_all_models(request):
|
|
||||||
folder = request.match_info.get("folder", None)
|
|
||||||
if not folder in folder_paths.folder_names_and_paths:
|
|
||||||
return web.Response(status=404)
|
|
||||||
files = self.get_model_file_list(folder)
|
|
||||||
return web.json_response(files)
|
|
||||||
|
|
||||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
|
||||||
async def get_model_preview(request):
|
|
||||||
folder_name = request.match_info.get("folder", None)
|
|
||||||
path_index = int(request.match_info.get("path_index", None))
|
|
||||||
filename = request.match_info.get("filename", None)
|
|
||||||
|
|
||||||
if not folder_name in folder_paths.folder_names_and_paths:
|
|
||||||
return web.Response(status=404)
|
|
||||||
|
|
||||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
|
||||||
folder = folders[0][path_index]
|
|
||||||
full_filename = os.path.join(folder, filename)
|
|
||||||
|
|
||||||
previews = self.get_model_previews(full_filename)
|
|
||||||
default_preview = previews[0] if len(previews) > 0 else None
|
|
||||||
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
|
||||||
return web.Response(status=404)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with Image.open(default_preview) as img:
|
|
||||||
img_bytes = BytesIO()
|
|
||||||
img.save(img_bytes, format="WEBP")
|
|
||||||
img_bytes.seek(0)
|
|
||||||
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
|
|
||||||
except:
|
|
||||||
return web.Response(status=404)
|
|
||||||
|
|
||||||
def get_model_file_list(self, folder_name: str):
|
|
||||||
folder_name = map_legacy(folder_name)
|
|
||||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
|
||||||
output_list: list[dict] = []
|
|
||||||
|
|
||||||
for index, folder in enumerate(folders[0]):
|
|
||||||
if not os.path.isdir(folder):
|
|
||||||
continue
|
|
||||||
out = self.cache_model_file_list_(folder)
|
|
||||||
if out is None:
|
|
||||||
out = self.recursive_search_models_(folder, index)
|
|
||||||
self.set_cache(folder, out)
|
|
||||||
output_list.extend(out[0])
|
|
||||||
|
|
||||||
return output_list
|
|
||||||
|
|
||||||
def cache_model_file_list_(self, folder: str):
|
|
||||||
model_file_list_cache = self.get_cache(folder)
|
|
||||||
|
|
||||||
if model_file_list_cache is None:
|
|
||||||
return None
|
|
||||||
if not os.path.isdir(folder):
|
|
||||||
return None
|
|
||||||
if os.path.getmtime(folder) != model_file_list_cache[1]:
|
|
||||||
return None
|
|
||||||
for x in model_file_list_cache[1]:
|
|
||||||
time_modified = model_file_list_cache[1][x]
|
|
||||||
folder = x
|
|
||||||
if os.path.getmtime(folder) != time_modified:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return model_file_list_cache
|
|
||||||
|
|
||||||
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
|
|
||||||
if not os.path.isdir(directory):
|
|
||||||
return [], {}, time.perf_counter()
|
|
||||||
|
|
||||||
excluded_dir_names = [".git"]
|
|
||||||
# TODO use settings
|
|
||||||
include_hidden_files = False
|
|
||||||
|
|
||||||
result: list[str] = []
|
|
||||||
dirs: dict[str, float] = {}
|
|
||||||
|
|
||||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
|
||||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
|
||||||
if not include_hidden_files:
|
|
||||||
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
|
|
||||||
filenames = [f for f in filenames if not f.startswith(".")]
|
|
||||||
|
|
||||||
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
|
|
||||||
|
|
||||||
for file_name in filenames:
|
|
||||||
try:
|
|
||||||
full_path = os.path.join(dirpath, file_name)
|
|
||||||
relative_path = os.path.relpath(full_path, directory)
|
|
||||||
|
|
||||||
# Get file metadata
|
|
||||||
file_info = {
|
|
||||||
"name": relative_path,
|
|
||||||
"pathIndex": pathIndex,
|
|
||||||
"modified": os.path.getmtime(full_path), # Add modification time
|
|
||||||
"created": os.path.getctime(full_path), # Add creation time
|
|
||||||
"size": os.path.getsize(full_path) # Add file size
|
|
||||||
}
|
|
||||||
result.append(file_info)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
for d in subdirs:
|
|
||||||
path: str = os.path.join(dirpath, d)
|
|
||||||
try:
|
|
||||||
dirs[path] = os.path.getmtime(path)
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
return result, dirs, time.perf_counter()
|
|
||||||
|
|
||||||
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
|
||||||
dirname = os.path.dirname(filepath)
|
|
||||||
|
|
||||||
if not os.path.exists(dirname):
|
|
||||||
return []
|
|
||||||
|
|
||||||
basename = os.path.splitext(filepath)[0]
|
|
||||||
match_files = glob.glob(f"{basename}.*", recursive=False)
|
|
||||||
image_files = filter_files_content_types(match_files, "image")
|
|
||||||
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
|
|
||||||
safetensors_metadata = {}
|
|
||||||
|
|
||||||
result: list[str | BytesIO] = []
|
|
||||||
|
|
||||||
for filename in image_files:
|
|
||||||
_basename = os.path.splitext(filename)[0]
|
|
||||||
if _basename == basename:
|
|
||||||
result.append(filename)
|
|
||||||
if _basename == f"{basename}.preview":
|
|
||||||
result.append(filename)
|
|
||||||
|
|
||||||
if safetensors_file:
|
|
||||||
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
|
||||||
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
|
|
||||||
if header:
|
|
||||||
safetensors_metadata = json.loads(header)
|
|
||||||
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
|
|
||||||
if safetensors_images:
|
|
||||||
safetensors_images = json.loads(safetensors_images)
|
|
||||||
for image in safetensors_images:
|
|
||||||
result.append(BytesIO(base64.b64decode(image)))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
self.clear_cache()
|
|
||||||
@ -1,112 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TypedDict
|
|
||||||
import os
|
|
||||||
import folder_paths
|
|
||||||
import glob
|
|
||||||
from aiohttp import web
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
|
|
||||||
class Source:
|
|
||||||
custom_node = "custom_node"
|
|
||||||
|
|
||||||
class SubgraphEntry(TypedDict):
|
|
||||||
source: str
|
|
||||||
"""
|
|
||||||
Source of subgraph - custom_nodes vs templates.
|
|
||||||
"""
|
|
||||||
path: str
|
|
||||||
"""
|
|
||||||
Relative path of the subgraph file.
|
|
||||||
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
|
|
||||||
"""
|
|
||||||
name: str
|
|
||||||
"""
|
|
||||||
Name of subgraph file.
|
|
||||||
"""
|
|
||||||
info: CustomNodeSubgraphEntryInfo
|
|
||||||
"""
|
|
||||||
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
|
||||||
"""
|
|
||||||
data: str
|
|
||||||
|
|
||||||
class CustomNodeSubgraphEntryInfo(TypedDict):
|
|
||||||
node_pack: str
|
|
||||||
"""Node pack name."""
|
|
||||||
|
|
||||||
class SubgraphManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
|
||||||
|
|
||||||
async def load_entry_data(self, entry: SubgraphEntry):
|
|
||||||
with open(entry['path'], 'r') as f:
|
|
||||||
entry['data'] = f.read()
|
|
||||||
return entry
|
|
||||||
|
|
||||||
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
|
|
||||||
if entry is None:
|
|
||||||
return None
|
|
||||||
entry = entry.copy()
|
|
||||||
entry.pop('path', None)
|
|
||||||
if remove_data:
|
|
||||||
entry.pop('data', None)
|
|
||||||
return entry
|
|
||||||
|
|
||||||
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
|
|
||||||
entries = entries.copy()
|
|
||||||
for key in list(entries.keys()):
|
|
||||||
entries[key] = await self.sanitize_entry(entries[key], remove_data)
|
|
||||||
return entries
|
|
||||||
|
|
||||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
|
||||||
# if not forced to reload and cached, return cache
|
|
||||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
|
||||||
return self.cached_custom_node_subgraphs
|
|
||||||
# Load subgraphs from custom nodes
|
|
||||||
subfolder = "subgraphs"
|
|
||||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
|
||||||
|
|
||||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
|
||||||
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
|
||||||
matched_files = glob.glob(pattern)
|
|
||||||
for file in matched_files:
|
|
||||||
# replace backslashes with forward slashes
|
|
||||||
file = file.replace('\\', '/')
|
|
||||||
info: CustomNodeSubgraphEntryInfo = {
|
|
||||||
"node_pack": "custom_nodes." + file.split('/')[-3]
|
|
||||||
}
|
|
||||||
source = Source.custom_node
|
|
||||||
# hash source + path to make sure id will be as unique as possible, but
|
|
||||||
# reproducible across backend reloads
|
|
||||||
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
|
||||||
entry: SubgraphEntry = {
|
|
||||||
"source": Source.custom_node,
|
|
||||||
"name": os.path.splitext(os.path.basename(file))[0],
|
|
||||||
"path": file,
|
|
||||||
"info": info,
|
|
||||||
}
|
|
||||||
subgraphs_dict[id] = entry
|
|
||||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
|
||||||
return subgraphs_dict
|
|
||||||
|
|
||||||
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
|
||||||
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
|
||||||
entry: SubgraphEntry = subgraphs.get(id, None)
|
|
||||||
if entry is not None and entry.get('data', None) is None:
|
|
||||||
await self.load_entry_data(entry)
|
|
||||||
return entry
|
|
||||||
|
|
||||||
def add_routes(self, routes, loadedModules):
|
|
||||||
@routes.get("/global_subgraphs")
|
|
||||||
async def get_global_subgraphs(request):
|
|
||||||
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
|
||||||
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
|
||||||
# that's the reasoning for the current implementation
|
|
||||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
|
||||||
|
|
||||||
@routes.get("/global_subgraphs/{id}")
|
|
||||||
async def get_global_subgraph(request):
|
|
||||||
id = request.match_info.get("id", None)
|
|
||||||
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
|
||||||
return web.json_response(await self.sanitize_entry(subgraph))
|
|
||||||
@ -1,456 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
import glob
|
|
||||||
import shutil
|
|
||||||
import logging
|
|
||||||
from aiohttp import web
|
|
||||||
from urllib import parse
|
|
||||||
from comfy.cli_args import args
|
|
||||||
import folder_paths
|
|
||||||
from .app_settings import AppSettings
|
|
||||||
from typing import TypedDict
|
|
||||||
|
|
||||||
default_user = "default"
|
|
||||||
|
|
||||||
|
|
||||||
class FileInfo(TypedDict):
|
|
||||||
path: str
|
|
||||||
size: int
|
|
||||||
modified: int
|
|
||||||
created: int
|
|
||||||
|
|
||||||
|
|
||||||
def get_file_info(path: str, relative_to: str) -> FileInfo:
|
|
||||||
return {
|
|
||||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
|
||||||
"size": os.path.getsize(path),
|
|
||||||
"modified": os.path.getmtime(path),
|
|
||||||
"created": os.path.getctime(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager():
|
|
||||||
def __init__(self):
|
|
||||||
user_directory = folder_paths.get_user_directory()
|
|
||||||
|
|
||||||
self.settings = AppSettings(self)
|
|
||||||
if not os.path.exists(user_directory):
|
|
||||||
os.makedirs(user_directory, exist_ok=True)
|
|
||||||
if not args.multi_user:
|
|
||||||
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
|
||||||
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
|
||||||
|
|
||||||
if args.multi_user:
|
|
||||||
if os.path.isfile(self.get_users_file()):
|
|
||||||
with open(self.get_users_file()) as f:
|
|
||||||
self.users = json.load(f)
|
|
||||||
else:
|
|
||||||
self.users = {}
|
|
||||||
else:
|
|
||||||
self.users = {"default": "default"}
|
|
||||||
|
|
||||||
def get_users_file(self):
|
|
||||||
return os.path.join(folder_paths.get_user_directory(), "users.json")
|
|
||||||
|
|
||||||
def get_request_user_id(self, request):
|
|
||||||
user = "default"
|
|
||||||
if args.multi_user and "comfy-user" in request.headers:
|
|
||||||
user = request.headers["comfy-user"]
|
|
||||||
# Block System Users (use same error message to prevent probing)
|
|
||||||
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
|
||||||
raise KeyError("Unknown user: " + user)
|
|
||||||
|
|
||||||
if user not in self.users:
|
|
||||||
raise KeyError("Unknown user: " + user)
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
|
||||||
if type == "userdata":
|
|
||||||
root_dir = folder_paths.get_user_directory()
|
|
||||||
else:
|
|
||||||
raise KeyError("Unknown filepath type:" + type)
|
|
||||||
|
|
||||||
user = self.get_request_user_id(request)
|
|
||||||
user_root = folder_paths.get_public_user_directory(user)
|
|
||||||
if user_root is None:
|
|
||||||
return None
|
|
||||||
path = user_root
|
|
||||||
|
|
||||||
# prevent leaving /{type}
|
|
||||||
if os.path.commonpath((root_dir, user_root)) != root_dir:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if file is not None:
|
|
||||||
# Check if filename is url encoded
|
|
||||||
if "%" in file:
|
|
||||||
file = parse.unquote(file)
|
|
||||||
|
|
||||||
# prevent leaving /{type}/{user}
|
|
||||||
path = os.path.abspath(os.path.join(user_root, file))
|
|
||||||
if os.path.commonpath((user_root, path)) != user_root:
|
|
||||||
return None
|
|
||||||
|
|
||||||
parent = os.path.split(path)[0]
|
|
||||||
|
|
||||||
if create_dir and not os.path.exists(parent):
|
|
||||||
os.makedirs(parent, exist_ok=True)
|
|
||||||
|
|
||||||
return path
|
|
||||||
|
|
||||||
def add_user(self, name):
|
|
||||||
name = name.strip()
|
|
||||||
if not name:
|
|
||||||
raise ValueError("username not provided")
|
|
||||||
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
|
||||||
raise ValueError("System User prefix not allowed")
|
|
||||||
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
|
|
||||||
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
|
||||||
raise ValueError("System User prefix not allowed")
|
|
||||||
user_id = user_id + "_" + str(uuid.uuid4())
|
|
||||||
|
|
||||||
self.users[user_id] = name
|
|
||||||
|
|
||||||
with open(self.get_users_file(), "w") as f:
|
|
||||||
json.dump(self.users, f)
|
|
||||||
|
|
||||||
return user_id
|
|
||||||
|
|
||||||
def add_routes(self, routes):
|
|
||||||
self.settings.add_routes(routes)
|
|
||||||
|
|
||||||
@routes.get("/users")
|
|
||||||
async def get_users(request):
|
|
||||||
if args.multi_user:
|
|
||||||
return web.json_response({"storage": "server", "users": self.users})
|
|
||||||
else:
|
|
||||||
user_dir = self.get_request_user_filepath(request, None, create_dir=False)
|
|
||||||
return web.json_response({
|
|
||||||
"storage": "server",
|
|
||||||
"migrated": os.path.exists(user_dir)
|
|
||||||
})
|
|
||||||
|
|
||||||
@routes.post("/users")
|
|
||||||
async def post_users(request):
|
|
||||||
body = await request.json()
|
|
||||||
username = body["username"]
|
|
||||||
if username in self.users.values():
|
|
||||||
return web.json_response({"error": "Duplicate username."}, status=400)
|
|
||||||
|
|
||||||
try:
|
|
||||||
user_id = self.add_user(username)
|
|
||||||
except ValueError as e:
|
|
||||||
return web.json_response({"error": str(e)}, status=400)
|
|
||||||
return web.json_response(user_id)
|
|
||||||
|
|
||||||
@routes.get("/userdata")
|
|
||||||
async def listuserdata(request):
|
|
||||||
"""
|
|
||||||
List user data files in a specified directory.
|
|
||||||
|
|
||||||
This endpoint allows listing files in a user's data directory, with options for recursion,
|
|
||||||
full file information, and path splitting.
|
|
||||||
|
|
||||||
Query Parameters:
|
|
||||||
- dir (required): The directory to list files from.
|
|
||||||
- recurse (optional): If "true", recursively list files in subdirectories.
|
|
||||||
- full_info (optional): If "true", return detailed file information (path, size, modified time).
|
|
||||||
- split (optional): If "true", split file paths into components (only applies when full_info is false).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- 400: If 'dir' parameter is missing.
|
|
||||||
- 403: If the requested path is not allowed.
|
|
||||||
- 404: If the requested directory does not exist.
|
|
||||||
- 200: JSON response with the list of files or file information.
|
|
||||||
|
|
||||||
The response format depends on the query parameters:
|
|
||||||
- Default: List of relative file paths.
|
|
||||||
- full_info=true: List of dictionaries with file details.
|
|
||||||
- split=true (and full_info=false): List of lists, each containing path components.
|
|
||||||
"""
|
|
||||||
directory = request.rel_url.query.get('dir', '')
|
|
||||||
if not directory:
|
|
||||||
return web.Response(status=400, text="Directory not provided")
|
|
||||||
|
|
||||||
path = self.get_request_user_filepath(request, directory)
|
|
||||||
if not path:
|
|
||||||
return web.Response(status=403, text="Invalid directory")
|
|
||||||
|
|
||||||
if not os.path.exists(path):
|
|
||||||
return web.Response(status=404, text="Directory not found")
|
|
||||||
|
|
||||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
|
||||||
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
|
||||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
|
||||||
|
|
||||||
# Use different patterns based on whether we're recursing or not
|
|
||||||
if recurse:
|
|
||||||
pattern = os.path.join(glob.escape(path), '**', '*')
|
|
||||||
else:
|
|
||||||
pattern = os.path.join(glob.escape(path), '*')
|
|
||||||
|
|
||||||
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
|
|
||||||
if full_info:
|
|
||||||
return get_file_info(full_path, path)
|
|
||||||
|
|
||||||
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
|
|
||||||
if split_path:
|
|
||||||
return [rel_path] + rel_path.split('/')
|
|
||||||
|
|
||||||
return rel_path
|
|
||||||
|
|
||||||
results = [
|
|
||||||
process_full_path(full_path)
|
|
||||||
for full_path in glob.glob(pattern, recursive=recurse)
|
|
||||||
if os.path.isfile(full_path)
|
|
||||||
]
|
|
||||||
|
|
||||||
return web.json_response(results)
|
|
||||||
|
|
||||||
@routes.get("/v2/userdata")
|
|
||||||
async def list_userdata_v2(request):
|
|
||||||
"""
|
|
||||||
List files and directories in a user's data directory.
|
|
||||||
|
|
||||||
This endpoint provides a structured listing of contents within a specified
|
|
||||||
subdirectory of the user's data storage.
|
|
||||||
|
|
||||||
Query Parameters:
|
|
||||||
- path (optional): The relative path within the user's data directory
|
|
||||||
to list. Defaults to the root ('').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
|
|
||||||
- 404: If the requested path does not exist.
|
|
||||||
- 403: If the user is invalid.
|
|
||||||
- 500: If there is an error reading the directory contents.
|
|
||||||
- 200: JSON response containing a list of file and directory objects.
|
|
||||||
Each object includes:
|
|
||||||
- name: The name of the file or directory.
|
|
||||||
- type: 'file' or 'directory'.
|
|
||||||
- path: The relative path from the user's data root.
|
|
||||||
- size (for files): The size in bytes.
|
|
||||||
- modified (for files): The last modified timestamp (Unix epoch).
|
|
||||||
"""
|
|
||||||
requested_rel_path = request.rel_url.query.get('path', '')
|
|
||||||
|
|
||||||
# URL-decode the path parameter
|
|
||||||
try:
|
|
||||||
requested_rel_path = parse.unquote(requested_rel_path)
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
|
|
||||||
return web.Response(status=400, text="Invalid characters in path parameter")
|
|
||||||
|
|
||||||
|
|
||||||
# Check user validity and get the absolute path for the requested directory
|
|
||||||
try:
|
|
||||||
base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
|
|
||||||
|
|
||||||
if requested_rel_path:
|
|
||||||
target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
|
|
||||||
else:
|
|
||||||
target_abs_path = base_user_path
|
|
||||||
|
|
||||||
except KeyError as e:
|
|
||||||
# Invalid user detected by get_request_user_id inside get_request_user_filepath
|
|
||||||
logging.warning(f"Access denied for user: {e}")
|
|
||||||
return web.Response(status=403, text="Invalid user specified in request")
|
|
||||||
|
|
||||||
|
|
||||||
if not target_abs_path:
|
|
||||||
# Path traversal or other issue detected by get_request_user_filepath
|
|
||||||
return web.Response(status=400, text="Invalid path requested")
|
|
||||||
|
|
||||||
# Handle cases where the user directory or target path doesn't exist
|
|
||||||
if not os.path.exists(target_abs_path):
|
|
||||||
# Check if it's the base user directory that's missing (new user case)
|
|
||||||
if target_abs_path == base_user_path:
|
|
||||||
# It's okay if the base user directory doesn't exist yet, return empty list
|
|
||||||
return web.json_response([])
|
|
||||||
else:
|
|
||||||
# A specific subdirectory was requested but doesn't exist
|
|
||||||
return web.Response(status=404, text="Requested path not found")
|
|
||||||
|
|
||||||
if not os.path.isdir(target_abs_path):
|
|
||||||
return web.Response(status=400, text="Requested path is not a directory")
|
|
||||||
|
|
||||||
results = []
|
|
||||||
try:
|
|
||||||
for root, dirs, files in os.walk(target_abs_path, topdown=True):
|
|
||||||
# Process directories
|
|
||||||
for dir_name in dirs:
|
|
||||||
dir_path = os.path.join(root, dir_name)
|
|
||||||
rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
|
|
||||||
results.append({
|
|
||||||
"name": dir_name,
|
|
||||||
"path": rel_path,
|
|
||||||
"type": "directory"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Process files
|
|
||||||
for file_name in files:
|
|
||||||
file_path = os.path.join(root, file_name)
|
|
||||||
rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
|
|
||||||
entry_info = {
|
|
||||||
"name": file_name,
|
|
||||||
"path": rel_path,
|
|
||||||
"type": "file"
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
|
|
||||||
entry_info["size"] = stats.st_size
|
|
||||||
entry_info["modified"] = stats.st_mtime
|
|
||||||
except OSError as stat_error:
|
|
||||||
logging.warning(f"Could not stat file {file_path}: {stat_error}")
|
|
||||||
pass # Include file with available info
|
|
||||||
results.append(entry_info)
|
|
||||||
except OSError as e:
|
|
||||||
logging.error(f"Error listing directory {target_abs_path}: {e}")
|
|
||||||
return web.Response(status=500, text="Error reading directory contents")
|
|
||||||
|
|
||||||
# Sort results alphabetically, directories first then files
|
|
||||||
results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
|
|
||||||
|
|
||||||
return web.json_response(results)
|
|
||||||
|
|
||||||
def get_user_data_path(request, check_exists = False, param = "file"):
|
|
||||||
file = request.match_info.get(param, None)
|
|
||||||
if not file:
|
|
||||||
return web.Response(status=400)
|
|
||||||
|
|
||||||
path = self.get_request_user_filepath(request, file)
|
|
||||||
if not path:
|
|
||||||
return web.Response(status=403)
|
|
||||||
|
|
||||||
if check_exists and not os.path.exists(path):
|
|
||||||
return web.Response(status=404)
|
|
||||||
|
|
||||||
return path
|
|
||||||
|
|
||||||
@routes.get("/userdata/{file}")
|
|
||||||
async def getuserdata(request):
|
|
||||||
path = get_user_data_path(request, check_exists=True)
|
|
||||||
if not isinstance(path, str):
|
|
||||||
return path
|
|
||||||
|
|
||||||
return web.FileResponse(path)
|
|
||||||
|
|
||||||
@routes.post("/userdata/{file}")
|
|
||||||
async def post_userdata(request):
|
|
||||||
"""
|
|
||||||
Upload or update a user data file.
|
|
||||||
|
|
||||||
This endpoint handles file uploads to a user's data directory, with options for
|
|
||||||
controlling overwrite behavior and response format.
|
|
||||||
|
|
||||||
Query Parameters:
|
|
||||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
|
||||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
|
||||||
If "false", returns only the relative file path.
|
|
||||||
|
|
||||||
Path Parameters:
|
|
||||||
- file: The target file path (URL encoded if necessary).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- 400: If 'file' parameter is missing.
|
|
||||||
- 403: If the requested path is not allowed.
|
|
||||||
- 409: If overwrite=false and the file already exists.
|
|
||||||
- 200: JSON response with either:
|
|
||||||
- Full file information (if full_info=true)
|
|
||||||
- Relative file path (if full_info=false)
|
|
||||||
|
|
||||||
The request body should contain the raw file content to be written.
|
|
||||||
"""
|
|
||||||
path = get_user_data_path(request)
|
|
||||||
if not isinstance(path, str):
|
|
||||||
return path
|
|
||||||
|
|
||||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
|
||||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
|
||||||
|
|
||||||
if not overwrite and os.path.exists(path):
|
|
||||||
return web.Response(status=409, text="File already exists")
|
|
||||||
|
|
||||||
try:
|
|
||||||
body = await request.read()
|
|
||||||
|
|
||||||
with open(path, "wb") as f:
|
|
||||||
f.write(body)
|
|
||||||
except OSError as e:
|
|
||||||
logging.warning(f"Error saving file '{path}': {e}")
|
|
||||||
return web.Response(
|
|
||||||
status=400,
|
|
||||||
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
|
|
||||||
)
|
|
||||||
|
|
||||||
user_path = self.get_request_user_filepath(request, None)
|
|
||||||
if full_info:
|
|
||||||
resp = get_file_info(path, user_path)
|
|
||||||
else:
|
|
||||||
resp = os.path.relpath(path, user_path)
|
|
||||||
|
|
||||||
return web.json_response(resp)
|
|
||||||
|
|
||||||
@routes.delete("/userdata/{file}")
|
|
||||||
async def delete_userdata(request):
|
|
||||||
path = get_user_data_path(request, check_exists=True)
|
|
||||||
if not isinstance(path, str):
|
|
||||||
return path
|
|
||||||
|
|
||||||
os.remove(path)
|
|
||||||
|
|
||||||
return web.Response(status=204)
|
|
||||||
|
|
||||||
@routes.post("/userdata/{file}/move/{dest}")
|
|
||||||
async def move_userdata(request):
|
|
||||||
"""
|
|
||||||
Move or rename a user data file.
|
|
||||||
|
|
||||||
This endpoint handles moving or renaming files within a user's data directory, with options for
|
|
||||||
controlling overwrite behavior and response format.
|
|
||||||
|
|
||||||
Path Parameters:
|
|
||||||
- file: The source file path (URL encoded if necessary)
|
|
||||||
- dest: The destination file path (URL encoded if necessary)
|
|
||||||
|
|
||||||
Query Parameters:
|
|
||||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
|
||||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
|
||||||
If "false", returns only the relative file path.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- 400: If either 'file' or 'dest' parameter is missing
|
|
||||||
- 403: If either requested path is not allowed
|
|
||||||
- 404: If the source file does not exist
|
|
||||||
- 409: If overwrite=false and the destination file already exists
|
|
||||||
- 200: JSON response with either:
|
|
||||||
- Full file information (if full_info=true)
|
|
||||||
- Relative file path (if full_info=false)
|
|
||||||
"""
|
|
||||||
source = get_user_data_path(request, check_exists=True)
|
|
||||||
if not isinstance(source, str):
|
|
||||||
return source
|
|
||||||
|
|
||||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
|
||||||
if not isinstance(dest, str):
|
|
||||||
return dest
|
|
||||||
|
|
||||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
|
||||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
|
||||||
|
|
||||||
if not overwrite and os.path.exists(dest):
|
|
||||||
return web.Response(status=409, text="File already exists")
|
|
||||||
|
|
||||||
logging.info(f"moving '{source}' -> '{dest}'")
|
|
||||||
shutil.move(source, dest)
|
|
||||||
|
|
||||||
user_path = self.get_request_user_filepath(request, None)
|
|
||||||
if full_info:
|
|
||||||
resp = get_file_info(dest, user_path)
|
|
||||||
else:
|
|
||||||
resp = os.path.relpath(dest, user_path)
|
|
||||||
|
|
||||||
return web.json_response(resp)
|
|
||||||
@ -1,91 +0,0 @@
|
|||||||
from .wav2vec2 import Wav2Vec2Model
|
|
||||||
from .whisper import WhisperLargeV3
|
|
||||||
import comfy.model_management
|
|
||||||
import comfy.ops
|
|
||||||
import comfy.utils
|
|
||||||
import logging
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
|
|
||||||
class AudioEncoderModel():
|
|
||||||
def __init__(self, config):
|
|
||||||
self.load_device = comfy.model_management.text_encoder_device()
|
|
||||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
|
||||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
|
||||||
model_type = config.pop("model_type")
|
|
||||||
model_config = dict(config)
|
|
||||||
model_config.update({
|
|
||||||
"dtype": self.dtype,
|
|
||||||
"device": offload_device,
|
|
||||||
"operations": comfy.ops.manual_cast
|
|
||||||
})
|
|
||||||
|
|
||||||
if model_type == "wav2vec2":
|
|
||||||
self.model = Wav2Vec2Model(**model_config)
|
|
||||||
elif model_type == "whisper3":
|
|
||||||
self.model = WhisperLargeV3(**model_config)
|
|
||||||
self.model.eval()
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
|
||||||
self.model_sample_rate = 16000
|
|
||||||
|
|
||||||
def load_sd(self, sd):
|
|
||||||
return self.model.load_state_dict(sd, strict=False)
|
|
||||||
|
|
||||||
def get_sd(self):
|
|
||||||
return self.model.state_dict()
|
|
||||||
|
|
||||||
def encode_audio(self, audio, sample_rate):
|
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
|
||||||
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
|
|
||||||
out, all_layers = self.model(audio.to(self.load_device))
|
|
||||||
outputs = {}
|
|
||||||
outputs["encoded_audio"] = out
|
|
||||||
outputs["encoded_audio_all_layers"] = all_layers
|
|
||||||
outputs["audio_samples"] = audio.shape[2]
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
def load_audio_encoder_from_sd(sd, prefix=""):
|
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
|
||||||
if "encoder.layer_norm.bias" in sd: #wav2vec2
|
|
||||||
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
|
|
||||||
if embed_dim == 1024:# large
|
|
||||||
config = {
|
|
||||||
"model_type": "wav2vec2",
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"num_heads": 16,
|
|
||||||
"num_layers": 24,
|
|
||||||
"conv_norm": True,
|
|
||||||
"conv_bias": True,
|
|
||||||
"do_normalize": True,
|
|
||||||
"do_stable_layer_norm": True
|
|
||||||
}
|
|
||||||
elif embed_dim == 768: # base
|
|
||||||
config = {
|
|
||||||
"model_type": "wav2vec2",
|
|
||||||
"embed_dim": 768,
|
|
||||||
"num_heads": 12,
|
|
||||||
"num_layers": 12,
|
|
||||||
"conv_norm": False,
|
|
||||||
"conv_bias": False,
|
|
||||||
"do_normalize": False, # chinese-wav2vec2-base has this False
|
|
||||||
"do_stable_layer_norm": False
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
|
|
||||||
elif "model.encoder.embed_positions.weight" in sd:
|
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
|
|
||||||
config = {
|
|
||||||
"model_type": "whisper3",
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise RuntimeError("ERROR: audio encoder not supported.")
|
|
||||||
|
|
||||||
audio_encoder = AudioEncoderModel(config)
|
|
||||||
m, u = audio_encoder.load_sd(sd)
|
|
||||||
if len(m) > 0:
|
|
||||||
logging.warning("missing audio encoder: {}".format(m))
|
|
||||||
if len(u) > 0:
|
|
||||||
logging.warning("unexpected audio encoder: {}".format(u))
|
|
||||||
|
|
||||||
return audio_encoder
|
|
||||||
@ -1,252 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNormConv(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
|
||||||
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv(x)
|
|
||||||
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
|
|
||||||
|
|
||||||
class LayerGroupNormConv(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
|
||||||
self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv(x)
|
|
||||||
return torch.nn.functional.gelu(self.layer_norm(x))
|
|
||||||
|
|
||||||
class ConvNoNorm(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv(x)
|
|
||||||
return torch.nn.functional.gelu(x)
|
|
||||||
|
|
||||||
|
|
||||||
class ConvFeatureEncoder(nn.Module):
|
|
||||||
def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
if conv_norm:
|
|
||||||
self.conv_layers = nn.ModuleList([
|
|
||||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
|
||||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
self.conv_layers = nn.ModuleList([
|
|
||||||
LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
|
||||||
])
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.unsqueeze(1)
|
|
||||||
|
|
||||||
for conv in self.conv_layers:
|
|
||||||
x = conv(x)
|
|
||||||
|
|
||||||
return x.transpose(1, 2)
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureProjection(nn.Module):
|
|
||||||
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
|
|
||||||
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.layer_norm(x)
|
|
||||||
x = self.projection(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class PositionalConvEmbedding(nn.Module):
|
|
||||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv1d(
|
|
||||||
embed_dim,
|
|
||||||
embed_dim,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
padding=kernel_size // 2,
|
|
||||||
groups=groups,
|
|
||||||
)
|
|
||||||
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
|
||||||
self.activation = nn.GELU()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.transpose(1, 2)
|
|
||||||
x = self.conv(x)[:, :, :-1]
|
|
||||||
x = self.activation(x)
|
|
||||||
x = x.transpose(1, 2)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embed_dim=768,
|
|
||||||
num_heads=12,
|
|
||||||
num_layers=12,
|
|
||||||
mlp_ratio=4.0,
|
|
||||||
do_stable_layer_norm=True,
|
|
||||||
dtype=None, device=None, operations=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
|
|
||||||
self.layers = nn.ModuleList([
|
|
||||||
TransformerEncoderLayer(
|
|
||||||
embed_dim=embed_dim,
|
|
||||||
num_heads=num_heads,
|
|
||||||
mlp_ratio=mlp_ratio,
|
|
||||||
do_stable_layer_norm=do_stable_layer_norm,
|
|
||||||
device=device, dtype=dtype, operations=operations
|
|
||||||
)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
])
|
|
||||||
|
|
||||||
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
|
|
||||||
self.do_stable_layer_norm = do_stable_layer_norm
|
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
|
||||||
x = x + self.pos_conv_embed(x)
|
|
||||||
all_x = ()
|
|
||||||
if not self.do_stable_layer_norm:
|
|
||||||
x = self.layer_norm(x)
|
|
||||||
for layer in self.layers:
|
|
||||||
all_x += (x,)
|
|
||||||
x = layer(x, mask)
|
|
||||||
if self.do_stable_layer_norm:
|
|
||||||
x = self.layer_norm(x)
|
|
||||||
all_x += (x,)
|
|
||||||
return x, all_x
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
|
||||||
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = embed_dim // num_heads
|
|
||||||
|
|
||||||
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
|
||||||
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
|
||||||
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
|
||||||
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
|
||||||
assert (mask is None) # TODO?
|
|
||||||
q = self.q_proj(x)
|
|
||||||
k = self.k_proj(x)
|
|
||||||
v = self.v_proj(x)
|
|
||||||
|
|
||||||
out = optimized_attention_masked(q, k, v, self.num_heads)
|
|
||||||
return self.out_proj(out)
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
|
||||||
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
|
|
||||||
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.intermediate_dense(x)
|
|
||||||
x = torch.nn.functional.gelu(x)
|
|
||||||
x = self.output_dense(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embed_dim=768,
|
|
||||||
num_heads=12,
|
|
||||||
mlp_ratio=4.0,
|
|
||||||
do_stable_layer_norm=True,
|
|
||||||
dtype=None, device=None, operations=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
|
|
||||||
|
|
||||||
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
|
||||||
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
|
||||||
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
|
||||||
self.do_stable_layer_norm = do_stable_layer_norm
|
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
|
||||||
residual = x
|
|
||||||
if self.do_stable_layer_norm:
|
|
||||||
x = self.layer_norm(x)
|
|
||||||
x = self.attention(x, mask=mask)
|
|
||||||
x = residual + x
|
|
||||||
if not self.do_stable_layer_norm:
|
|
||||||
x = self.layer_norm(x)
|
|
||||||
return self.final_layer_norm(x + self.feed_forward(x))
|
|
||||||
else:
|
|
||||||
return x + self.feed_forward(self.final_layer_norm(x))
|
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2Model(nn.Module):
|
|
||||||
"""Complete Wav2Vec 2.0 model."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embed_dim=1024,
|
|
||||||
final_dim=256,
|
|
||||||
num_heads=16,
|
|
||||||
num_layers=24,
|
|
||||||
conv_norm=True,
|
|
||||||
conv_bias=True,
|
|
||||||
do_normalize=True,
|
|
||||||
do_stable_layer_norm=True,
|
|
||||||
dtype=None, device=None, operations=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
conv_dim = 512
|
|
||||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
|
|
||||||
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
|
|
||||||
|
|
||||||
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
|
|
||||||
self.do_normalize = do_normalize
|
|
||||||
|
|
||||||
self.encoder = TransformerEncoder(
|
|
||||||
embed_dim=embed_dim,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_layers=num_layers,
|
|
||||||
do_stable_layer_norm=do_stable_layer_norm,
|
|
||||||
device=device, dtype=dtype, operations=operations
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, mask_time_indices=None, return_dict=False):
|
|
||||||
x = torch.mean(x, dim=1)
|
|
||||||
|
|
||||||
if self.do_normalize:
|
|
||||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
|
||||||
|
|
||||||
features = self.feature_extractor(x)
|
|
||||||
features = self.feature_projection(features)
|
|
||||||
batch_size, seq_len, _ = features.shape
|
|
||||||
|
|
||||||
x, all_x = self.encoder(features)
|
|
||||||
return x, all_x
|
|
||||||
@ -1,186 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchaudio
|
|
||||||
from typing import Optional
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
|
||||||
import comfy.ops
|
|
||||||
|
|
||||||
class WhisperFeatureExtractor(nn.Module):
|
|
||||||
def __init__(self, n_mels=128, device=None):
|
|
||||||
super().__init__()
|
|
||||||
self.sample_rate = 16000
|
|
||||||
self.n_fft = 400
|
|
||||||
self.hop_length = 160
|
|
||||||
self.n_mels = n_mels
|
|
||||||
self.chunk_length = 30
|
|
||||||
self.n_samples = 480000
|
|
||||||
|
|
||||||
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
|
||||||
sample_rate=self.sample_rate,
|
|
||||||
n_fft=self.n_fft,
|
|
||||||
hop_length=self.hop_length,
|
|
||||||
n_mels=self.n_mels,
|
|
||||||
f_min=0,
|
|
||||||
f_max=8000,
|
|
||||||
norm="slaney",
|
|
||||||
mel_scale="slaney",
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
def __call__(self, audio):
|
|
||||||
audio = torch.mean(audio, dim=1)
|
|
||||||
batch_size = audio.shape[0]
|
|
||||||
processed_audio = []
|
|
||||||
|
|
||||||
for i in range(batch_size):
|
|
||||||
aud = audio[i]
|
|
||||||
if aud.shape[0] > self.n_samples:
|
|
||||||
aud = aud[:self.n_samples]
|
|
||||||
elif aud.shape[0] < self.n_samples:
|
|
||||||
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
|
|
||||||
processed_audio.append(aud)
|
|
||||||
|
|
||||||
audio = torch.stack(processed_audio)
|
|
||||||
|
|
||||||
mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
|
|
||||||
|
|
||||||
log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
|
||||||
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
|
|
||||||
log_mel_spec = (log_mel_spec + 4.0) / 4.0
|
|
||||||
|
|
||||||
return log_mel_spec
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
|
||||||
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
assert d_model % n_heads == 0
|
|
||||||
|
|
||||||
self.d_model = d_model
|
|
||||||
self.n_heads = n_heads
|
|
||||||
self.d_k = d_model // n_heads
|
|
||||||
|
|
||||||
self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
|
||||||
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
|
|
||||||
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
|
||||||
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
batch_size, seq_len, _ = query.shape
|
|
||||||
|
|
||||||
q = self.q_proj(query)
|
|
||||||
k = self.k_proj(key)
|
|
||||||
v = self.v_proj(value)
|
|
||||||
|
|
||||||
attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
|
|
||||||
attn_output = self.out_proj(attn_output)
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderLayer(nn.Module):
|
|
||||||
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
|
|
||||||
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
|
|
||||||
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
|
|
||||||
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
residual = x
|
|
||||||
x = self.self_attn_layer_norm(x)
|
|
||||||
x = self.self_attn(x, x, x, attention_mask)
|
|
||||||
x = residual + x
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
x = self.final_layer_norm(x)
|
|
||||||
x = self.fc1(x)
|
|
||||||
x = F.gelu(x)
|
|
||||||
x = self.fc2(x)
|
|
||||||
x = residual + x
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class AudioEncoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
n_mels: int = 128,
|
|
||||||
n_ctx: int = 1500,
|
|
||||||
n_state: int = 1280,
|
|
||||||
n_head: int = 20,
|
|
||||||
n_layer: int = 32,
|
|
||||||
dtype=None,
|
|
||||||
device=None,
|
|
||||||
operations=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
|
|
||||||
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.layers = nn.ModuleList([
|
|
||||||
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
|
|
||||||
for _ in range(n_layer)
|
|
||||||
])
|
|
||||||
|
|
||||||
self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
x = F.gelu(self.conv1(x))
|
|
||||||
x = F.gelu(self.conv2(x))
|
|
||||||
|
|
||||||
x = x.transpose(1, 2)
|
|
||||||
|
|
||||||
x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
|
|
||||||
|
|
||||||
all_x = ()
|
|
||||||
for layer in self.layers:
|
|
||||||
all_x += (x,)
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
x = self.layer_norm(x)
|
|
||||||
all_x += (x,)
|
|
||||||
return x, all_x
|
|
||||||
|
|
||||||
|
|
||||||
class WhisperLargeV3(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
n_mels: int = 128,
|
|
||||||
n_audio_ctx: int = 1500,
|
|
||||||
n_audio_state: int = 1280,
|
|
||||||
n_audio_head: int = 20,
|
|
||||||
n_audio_layer: int = 32,
|
|
||||||
dtype=None,
|
|
||||||
device=None,
|
|
||||||
operations=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
|
|
||||||
|
|
||||||
self.encoder = AudioEncoder(
|
|
||||||
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
|
|
||||||
dtype=dtype, device=device, operations=operations
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, audio):
|
|
||||||
mel = self.feature_extractor(audio)
|
|
||||||
x, all_x = self.encoder(mel)
|
|
||||||
return x, all_x
|
|
||||||
@ -1,13 +0,0 @@
|
|||||||
import pickle
|
|
||||||
|
|
||||||
load = pickle.load
|
|
||||||
|
|
||||||
class Empty:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Unpickler(pickle.Unpickler):
|
|
||||||
def find_class(self, module, name):
|
|
||||||
#TODO: safe unpickle
|
|
||||||
if module.startswith("pytorch_lightning"):
|
|
||||||
return Empty
|
|
||||||
return super().find_class(module, name)
|
|
||||||
@ -2,56 +2,21 @@
|
|||||||
#and modified
|
#and modified
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch as th
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ..ldm.modules.diffusionmodules.util import (
|
from ..ldm.modules.diffusionmodules.util import (
|
||||||
|
conv_nd,
|
||||||
|
linear,
|
||||||
|
zero_module,
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..ldm.modules.attention import SpatialTransformer
|
from ..ldm.modules.attention import SpatialTransformer
|
||||||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||||
from ..ldm.util import exists
|
from ..ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
from .control_types import UNION_CONTROLNET_TYPES
|
from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||||
from collections import OrderedDict
|
|
||||||
import comfy.ops
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
|
||||||
|
|
||||||
class OptimizedAttention(nn.Module):
|
|
||||||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.heads = nhead
|
|
||||||
self.c = c
|
|
||||||
|
|
||||||
self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
|
|
||||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.in_proj(x)
|
|
||||||
q, k, v = x.split(self.c, dim=2)
|
|
||||||
out = optimized_attention(q, k, v, self.heads)
|
|
||||||
return self.out_proj(out)
|
|
||||||
|
|
||||||
class QuickGELU(nn.Module):
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
return x * torch.sigmoid(1.702 * x)
|
|
||||||
|
|
||||||
class ResBlockUnionControlnet(nn.Module):
|
|
||||||
def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
|
|
||||||
self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
|
|
||||||
self.mlp = nn.Sequential(
|
|
||||||
OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
|
|
||||||
("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
|
|
||||||
self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def attention(self, x: torch.Tensor):
|
|
||||||
return self.attn(x)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
x = x + self.attention(self.ln_1(x))
|
|
||||||
x = x + self.mlp(self.ln_2(x))
|
|
||||||
return x
|
|
||||||
|
|
||||||
class ControlledUnetModel(UNetModel):
|
class ControlledUnetModel(UNetModel):
|
||||||
#implemented in the ldm unet
|
#implemented in the ldm unet
|
||||||
@ -65,13 +30,13 @@ class ControlNet(nn.Module):
|
|||||||
model_channels,
|
model_channels,
|
||||||
hint_channels,
|
hint_channels,
|
||||||
num_res_blocks,
|
num_res_blocks,
|
||||||
|
attention_resolutions,
|
||||||
dropout=0,
|
dropout=0,
|
||||||
channel_mult=(1, 2, 4, 8),
|
channel_mult=(1, 2, 4, 8),
|
||||||
conv_resample=True,
|
conv_resample=True,
|
||||||
dims=2,
|
dims=2,
|
||||||
num_classes=None,
|
|
||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
dtype=torch.float32,
|
use_fp16=False,
|
||||||
num_heads=-1,
|
num_heads=-1,
|
||||||
num_head_channels=-1,
|
num_head_channels=-1,
|
||||||
num_heads_upsample=-1,
|
num_heads_upsample=-1,
|
||||||
@ -87,17 +52,8 @@ class ControlNet(nn.Module):
|
|||||||
num_attention_blocks=None,
|
num_attention_blocks=None,
|
||||||
disable_middle_self_attn=False,
|
disable_middle_self_attn=False,
|
||||||
use_linear_in_transformer=False,
|
use_linear_in_transformer=False,
|
||||||
adm_in_channels=None,
|
|
||||||
transformer_depth_middle=None,
|
|
||||||
transformer_depth_output=None,
|
|
||||||
attn_precision=None,
|
|
||||||
union_controlnet_num_control_type=None,
|
|
||||||
device=None,
|
|
||||||
operations=comfy.ops.disable_weight_init,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
|
||||||
if use_spatial_transformer:
|
if use_spatial_transformer:
|
||||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||||
|
|
||||||
@ -120,7 +76,6 @@ class ControlNet(nn.Module):
|
|||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.model_channels = model_channels
|
self.model_channels = model_channels
|
||||||
|
|
||||||
if isinstance(num_res_blocks, int):
|
if isinstance(num_res_blocks, int):
|
||||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||||
else:
|
else:
|
||||||
@ -128,22 +83,23 @@ class ControlNet(nn.Module):
|
|||||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||||
self.num_res_blocks = num_res_blocks
|
self.num_res_blocks = num_res_blocks
|
||||||
|
|
||||||
if disable_self_attentions is not None:
|
if disable_self_attentions is not None:
|
||||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||||
assert len(disable_self_attentions) == len(channel_mult)
|
assert len(disable_self_attentions) == len(channel_mult)
|
||||||
if num_attention_blocks is not None:
|
if num_attention_blocks is not None:
|
||||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||||
|
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||||
|
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||||
|
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||||
|
f"attention will still not be set.")
|
||||||
|
|
||||||
transformer_depth = transformer_depth[:]
|
self.attention_resolutions = attention_resolutions
|
||||||
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.channel_mult = channel_mult
|
self.channel_mult = channel_mult
|
||||||
self.conv_resample = conv_resample
|
self.conv_resample = conv_resample
|
||||||
self.num_classes = num_classes
|
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
self.dtype = dtype
|
self.dtype = th.float16 if use_fp16 else th.float32
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.num_head_channels = num_head_channels
|
self.num_head_channels = num_head_channels
|
||||||
self.num_heads_upsample = num_heads_upsample
|
self.num_heads_upsample = num_heads_upsample
|
||||||
@ -151,53 +107,36 @@ class ControlNet(nn.Module):
|
|||||||
|
|
||||||
time_embed_dim = model_channels * 4
|
time_embed_dim = model_channels * 4
|
||||||
self.time_embed = nn.Sequential(
|
self.time_embed = nn.Sequential(
|
||||||
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
linear(model_channels, time_embed_dim),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
linear(time_embed_dim, time_embed_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
|
||||||
if isinstance(self.num_classes, int):
|
|
||||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
|
||||||
elif self.num_classes == "continuous":
|
|
||||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
|
||||||
elif self.num_classes == "sequential":
|
|
||||||
assert adm_in_channels is not None
|
|
||||||
self.label_emb = nn.Sequential(
|
|
||||||
nn.Sequential(
|
|
||||||
operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError()
|
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
TimestepEmbedSequential(
|
TimestepEmbedSequential(
|
||||||
operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
|
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
||||||
|
|
||||||
self.input_hint_block = TimestepEmbedSequential(
|
self.input_hint_block = TimestepEmbedSequential(
|
||||||
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
|
conv_nd(dims, hint_channels, 16, 3, padding=1),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
|
conv_nd(dims, 16, 16, 3, padding=1),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
|
conv_nd(dims, 32, 32, 3, padding=1),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
|
conv_nd(dims, 96, 96, 3, padding=1),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
||||||
)
|
)
|
||||||
|
|
||||||
self._feature_size = model_channels
|
self._feature_size = model_channels
|
||||||
@ -215,14 +154,10 @@ class ControlNet(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype,
|
|
||||||
device=device,
|
|
||||||
operations=operations,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = mult * model_channels
|
ch = mult * model_channels
|
||||||
num_transformers = transformer_depth.pop(0)
|
if ds in attention_resolutions:
|
||||||
if num_transformers > 0:
|
|
||||||
if num_head_channels == -1:
|
if num_head_channels == -1:
|
||||||
dim_head = ch // num_heads
|
dim_head = ch // num_heads
|
||||||
else:
|
else:
|
||||||
@ -238,14 +173,20 @@ class ControlNet(nn.Module):
|
|||||||
|
|
||||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||||
layers.append(
|
layers.append(
|
||||||
SpatialTransformer(
|
AttentionBlock(
|
||||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
ch,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_head_channels=dim_head,
|
||||||
|
use_new_attention_order=use_new_attention_order,
|
||||||
|
) if not use_spatial_transformer else SpatialTransformer(
|
||||||
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
self.zero_convs.append(self.make_zero_conv(ch))
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
input_block_chans.append(ch)
|
input_block_chans.append(ch)
|
||||||
if level != len(channel_mult) - 1:
|
if level != len(channel_mult) - 1:
|
||||||
@ -261,19 +202,16 @@ class ControlNet(nn.Module):
|
|||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
down=True,
|
down=True,
|
||||||
dtype=self.dtype,
|
|
||||||
device=device,
|
|
||||||
operations=operations
|
|
||||||
)
|
)
|
||||||
if resblock_updown
|
if resblock_updown
|
||||||
else Downsample(
|
else Downsample(
|
||||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
|
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
ch = out_ch
|
ch = out_ch
|
||||||
input_block_chans.append(ch)
|
input_block_chans.append(ch)
|
||||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
self.zero_convs.append(self.make_zero_conv(ch))
|
||||||
ds *= 2
|
ds *= 2
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
@ -285,7 +223,7 @@ class ControlNet(nn.Module):
|
|||||||
if legacy:
|
if legacy:
|
||||||
#num_heads = 1
|
#num_heads = 1
|
||||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
mid_block = [
|
self.middle_block = TimestepEmbedSequential(
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
time_embed_dim,
|
time_embed_dim,
|
||||||
@ -293,15 +231,17 @@ class ControlNet(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype,
|
),
|
||||||
device=device,
|
AttentionBlock(
|
||||||
operations=operations
|
ch,
|
||||||
)]
|
use_checkpoint=use_checkpoint,
|
||||||
if transformer_depth_middle >= 0:
|
num_heads=num_heads,
|
||||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
num_head_channels=dim_head,
|
||||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
use_new_attention_order=use_new_attention_order,
|
||||||
|
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||||
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint
|
||||||
),
|
),
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
@ -310,114 +250,23 @@ class ControlNet(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype,
|
),
|
||||||
device=device,
|
)
|
||||||
operations=operations
|
self.middle_block_out = self.make_zero_conv(ch)
|
||||||
)]
|
|
||||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
|
||||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
if union_controlnet_num_control_type is not None:
|
def make_zero_conv(self, channels):
|
||||||
self.num_control_type = union_controlnet_num_control_type
|
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||||
num_trans_channel = 320
|
|
||||||
num_trans_head = 8
|
|
||||||
num_trans_layer = 1
|
|
||||||
num_proj_channel = 320
|
|
||||||
# task_scale_factor = num_trans_channel ** 0.5
|
|
||||||
self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
|
|
||||||
|
|
||||||
self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
|
def forward(self, x, hint, timesteps, context, **kwargs):
|
||||||
self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||||
#-----------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
control_add_embed_dim = 256
|
|
||||||
class ControlAddEmbedding(nn.Module):
|
|
||||||
def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.num_control_type = num_control_type
|
|
||||||
self.in_dim = in_dim
|
|
||||||
self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
|
|
||||||
self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
|
|
||||||
def forward(self, control_type, dtype, device):
|
|
||||||
c_type = torch.zeros((self.num_control_type,), device=device)
|
|
||||||
c_type[control_type] = 1.0
|
|
||||||
c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
|
|
||||||
return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
|
|
||||||
|
|
||||||
self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
|
|
||||||
else:
|
|
||||||
self.task_embedding = None
|
|
||||||
self.control_add_embedding = None
|
|
||||||
|
|
||||||
def union_controlnet_merge(self, hint, control_type, emb, context):
|
|
||||||
# Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
|
|
||||||
inputs = []
|
|
||||||
condition_list = []
|
|
||||||
|
|
||||||
for idx in range(min(1, len(control_type))):
|
|
||||||
controlnet_cond = self.input_hint_block(hint[idx], emb, context)
|
|
||||||
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
|
|
||||||
if idx < len(control_type):
|
|
||||||
feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
|
|
||||||
|
|
||||||
inputs.append(feat_seq.unsqueeze(1))
|
|
||||||
condition_list.append(controlnet_cond)
|
|
||||||
|
|
||||||
x = torch.cat(inputs, dim=1)
|
|
||||||
x = self.transformer_layes(x)
|
|
||||||
controlnet_cond_fuser = None
|
|
||||||
for idx in range(len(control_type)):
|
|
||||||
alpha = self.spatial_ch_projs(x[:, idx])
|
|
||||||
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
o = condition_list[idx] + alpha
|
|
||||||
if controlnet_cond_fuser is None:
|
|
||||||
controlnet_cond_fuser = o
|
|
||||||
else:
|
|
||||||
controlnet_cond_fuser += o
|
|
||||||
return controlnet_cond_fuser
|
|
||||||
|
|
||||||
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
|
|
||||||
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
|
||||||
|
|
||||||
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
guided_hint = None
|
guided_hint = self.input_hint_block(hint, emb, context)
|
||||||
if self.control_add_embedding is not None: #Union Controlnet
|
|
||||||
control_type = kwargs.get("control_type", [])
|
|
||||||
|
|
||||||
if any([c >= self.num_control_type for c in control_type]):
|
outs = []
|
||||||
max_type = max(control_type)
|
|
||||||
max_type_name = {
|
|
||||||
v: k for k, v in UNION_CONTROLNET_TYPES.items()
|
|
||||||
}[max_type]
|
|
||||||
raise ValueError(
|
|
||||||
f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
|
|
||||||
f"({self.num_control_type}) supported.\n" +
|
|
||||||
"Please consider using the ProMax ControlNet Union model.\n" +
|
|
||||||
"https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
|
|
||||||
)
|
|
||||||
|
|
||||||
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
|
h = x.type(self.dtype)
|
||||||
if len(control_type) > 0:
|
|
||||||
if len(hint.shape) < 5:
|
|
||||||
hint = hint.unsqueeze(dim=0)
|
|
||||||
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
|
|
||||||
|
|
||||||
if guided_hint is None:
|
|
||||||
guided_hint = self.input_hint_block(hint, emb, context)
|
|
||||||
|
|
||||||
out_output = []
|
|
||||||
out_middle = []
|
|
||||||
|
|
||||||
if self.num_classes is not None:
|
|
||||||
if y is None:
|
|
||||||
raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
|
|
||||||
emb = emb + self.label_emb(y)
|
|
||||||
|
|
||||||
h = x
|
|
||||||
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
||||||
if guided_hint is not None:
|
if guided_hint is not None:
|
||||||
h = module(h, emb, context)
|
h = module(h, emb, context)
|
||||||
@ -425,10 +274,10 @@ class ControlNet(nn.Module):
|
|||||||
guided_hint = None
|
guided_hint = None
|
||||||
else:
|
else:
|
||||||
h = module(h, emb, context)
|
h = module(h, emb, context)
|
||||||
out_output.append(zero_conv(h, emb, context))
|
outs.append(zero_conv(h, emb, context))
|
||||||
|
|
||||||
h = self.middle_block(h, emb, context)
|
h = self.middle_block(h, emb, context)
|
||||||
out_middle.append(self.middle_block_out(h, emb, context))
|
outs.append(self.middle_block_out(h, emb, context))
|
||||||
|
|
||||||
return {"middle": out_middle, "output": out_output}
|
return outs
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +0,0 @@
|
|||||||
UNION_CONTROLNET_TYPES = {
|
|
||||||
"openpose": 0,
|
|
||||||
"depth": 1,
|
|
||||||
"hed/pidi/scribble/ted": 2,
|
|
||||||
"canny/lineart/anime_lineart/mlsd": 3,
|
|
||||||
"normal": 4,
|
|
||||||
"segment": 5,
|
|
||||||
"tile": 6,
|
|
||||||
"repaint": 7,
|
|
||||||
}
|
|
||||||
@ -1,120 +0,0 @@
|
|||||||
import math
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetEmbedder(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
img_size: int,
|
|
||||||
patch_size: int,
|
|
||||||
in_chans: int,
|
|
||||||
attention_head_dim: int,
|
|
||||||
num_attention_heads: int,
|
|
||||||
adm_in_channels: int,
|
|
||||||
num_layers: int,
|
|
||||||
main_model_double: int,
|
|
||||||
double_y_emb: bool,
|
|
||||||
device: torch.device,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
pos_embed_max_size: Optional[int] = None,
|
|
||||||
operations = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.main_model_double = main_model_double
|
|
||||||
self.dtype = dtype
|
|
||||||
self.hidden_size = num_attention_heads * attention_head_dim
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.x_embedder = PatchEmbed(
|
|
||||||
img_size=img_size,
|
|
||||||
patch_size=patch_size,
|
|
||||||
in_chans=in_chans,
|
|
||||||
embed_dim=self.hidden_size,
|
|
||||||
strict_img_size=pos_embed_max_size is None,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
operations=operations,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
|
||||||
|
|
||||||
self.double_y_emb = double_y_emb
|
|
||||||
if self.double_y_emb:
|
|
||||||
self.orig_y_embedder = VectorEmbedder(
|
|
||||||
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
|
||||||
)
|
|
||||||
self.y_embedder = VectorEmbedder(
|
|
||||||
self.hidden_size, self.hidden_size, dtype, device, operations=operations
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.y_embedder = VectorEmbedder(
|
|
||||||
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
|
||||||
)
|
|
||||||
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
|
||||||
DismantledBlock(
|
|
||||||
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
|
|
||||||
dtype=dtype, device=device, operations=operations
|
|
||||||
)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
)
|
|
||||||
|
|
||||||
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
|
|
||||||
# TODO double check this logic when 8b
|
|
||||||
self.use_y_embedder = True
|
|
||||||
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
|
||||||
for _ in range(len(self.transformer_blocks)):
|
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
|
|
||||||
self.pos_embed_input = PatchEmbed(
|
|
||||||
img_size=img_size,
|
|
||||||
patch_size=patch_size,
|
|
||||||
in_chans=in_chans,
|
|
||||||
embed_dim=self.hidden_size,
|
|
||||||
strict_img_size=False,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
operations=operations,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
timesteps: torch.Tensor,
|
|
||||||
y: Optional[torch.Tensor] = None,
|
|
||||||
context: Optional[torch.Tensor] = None,
|
|
||||||
hint = None,
|
|
||||||
) -> Tuple[Tensor, List[Tensor]]:
|
|
||||||
x_shape = list(x.shape)
|
|
||||||
x = self.x_embedder(x)
|
|
||||||
if not self.double_y_emb:
|
|
||||||
h = (x_shape[-2] + 1) // self.patch_size
|
|
||||||
w = (x_shape[-1] + 1) // self.patch_size
|
|
||||||
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
|
|
||||||
c = self.t_embedder(timesteps, dtype=x.dtype)
|
|
||||||
if y is not None and self.y_embedder is not None:
|
|
||||||
if self.double_y_emb:
|
|
||||||
y = self.orig_y_embedder(y)
|
|
||||||
y = self.y_embedder(y)
|
|
||||||
c = c + y
|
|
||||||
|
|
||||||
x = x + self.pos_embed_input(hint)
|
|
||||||
|
|
||||||
block_out = ()
|
|
||||||
|
|
||||||
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
|
|
||||||
for i in range(len(self.transformer_blocks)):
|
|
||||||
out = self.transformer_blocks[i](x, c)
|
|
||||||
if not self.double_y_emb:
|
|
||||||
x = out
|
|
||||||
block_out += (self.controlnet_blocks[i](out),) * repeat
|
|
||||||
|
|
||||||
return {"output": block_out}
|
|
||||||
@ -1,81 +0,0 @@
|
|||||||
import torch
|
|
||||||
from typing import Optional
|
|
||||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
|
||||||
|
|
||||||
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_blocks = None,
|
|
||||||
control_latent_channels = None,
|
|
||||||
dtype = None,
|
|
||||||
device = None,
|
|
||||||
operations = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
|
|
||||||
# controlnet_blocks
|
|
||||||
self.controlnet_blocks = torch.nn.ModuleList([])
|
|
||||||
for _ in range(len(self.joint_blocks)):
|
|
||||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
|
||||||
|
|
||||||
if control_latent_channels is None:
|
|
||||||
control_latent_channels = self.in_channels
|
|
||||||
|
|
||||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
|
||||||
None,
|
|
||||||
self.patch_size,
|
|
||||||
control_latent_channels,
|
|
||||||
self.hidden_size,
|
|
||||||
bias=True,
|
|
||||||
strict_img_size=False,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
operations=operations
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
timesteps: torch.Tensor,
|
|
||||||
y: Optional[torch.Tensor] = None,
|
|
||||||
context: Optional[torch.Tensor] = None,
|
|
||||||
hint = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
#weird sd3 controlnet specific stuff
|
|
||||||
y = torch.zeros_like(y)
|
|
||||||
|
|
||||||
if self.context_processor is not None:
|
|
||||||
context = self.context_processor(context)
|
|
||||||
|
|
||||||
hw = x.shape[-2:]
|
|
||||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
|
|
||||||
x += self.pos_embed_input(hint)
|
|
||||||
|
|
||||||
c = self.t_embedder(timesteps, dtype=x.dtype)
|
|
||||||
if y is not None and self.y_embedder is not None:
|
|
||||||
y = self.y_embedder(y)
|
|
||||||
c = c + y
|
|
||||||
|
|
||||||
if context is not None:
|
|
||||||
context = self.context_embedder(context)
|
|
||||||
|
|
||||||
output = []
|
|
||||||
|
|
||||||
blocks = len(self.joint_blocks)
|
|
||||||
for i in range(blocks):
|
|
||||||
context, x = self.joint_blocks[i](
|
|
||||||
context,
|
|
||||||
x,
|
|
||||||
c=c,
|
|
||||||
use_checkpoint=self.use_checkpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
out = self.controlnet_blocks[i](x)
|
|
||||||
count = self.depth // blocks
|
|
||||||
if i == blocks - 1:
|
|
||||||
count -= 1
|
|
||||||
for j in range(count):
|
|
||||||
output.append(out)
|
|
||||||
|
|
||||||
return {"output": output}
|
|
||||||
@ -1,258 +1,36 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import enum
|
|
||||||
import os
|
|
||||||
import comfy.options
|
|
||||||
|
|
||||||
|
|
||||||
class EnumAction(argparse.Action):
|
|
||||||
"""
|
|
||||||
Argparse action for handling Enums
|
|
||||||
"""
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
# Pop off the type value
|
|
||||||
enum_type = kwargs.pop("type", None)
|
|
||||||
|
|
||||||
# Ensure an Enum subclass is provided
|
|
||||||
if enum_type is None:
|
|
||||||
raise ValueError("type must be assigned an Enum when using EnumAction")
|
|
||||||
if not issubclass(enum_type, enum.Enum):
|
|
||||||
raise TypeError("type must be an Enum when using EnumAction")
|
|
||||||
|
|
||||||
# Generate choices from the Enum
|
|
||||||
choices = tuple(e.value for e in enum_type)
|
|
||||||
kwargs.setdefault("choices", choices)
|
|
||||||
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
|
|
||||||
|
|
||||||
super(EnumAction, self).__init__(**kwargs)
|
|
||||||
|
|
||||||
self._enum = enum_type
|
|
||||||
|
|
||||||
def __call__(self, parser, namespace, values, option_string=None):
|
|
||||||
# Convert value back into an Enum
|
|
||||||
value = self._enum(values)
|
|
||||||
setattr(namespace, self.dest, value)
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
|
||||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
|
||||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||||
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
|
||||||
|
|
||||||
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
|
|
||||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
|
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||||
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
|
|
||||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
|
||||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||||
cm_group = parser.add_mutually_exclusive_group()
|
|
||||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
|
||||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
|
||||||
|
|
||||||
|
|
||||||
fp_group = parser.add_mutually_exclusive_group()
|
|
||||||
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
|
||||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
|
||||||
|
|
||||||
fpunet_group = parser.add_mutually_exclusive_group()
|
|
||||||
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
|
|
||||||
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
|
|
||||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
|
|
||||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
|
|
||||||
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
|
||||||
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
|
||||||
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
|
|
||||||
|
|
||||||
fpvae_group = parser.add_mutually_exclusive_group()
|
|
||||||
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
|
||||||
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
|
||||||
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
|
|
||||||
|
|
||||||
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
|
|
||||||
|
|
||||||
fpte_group = parser.add_mutually_exclusive_group()
|
|
||||||
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
|
|
||||||
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
|
|
||||||
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
|
|
||||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
|
||||||
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
|
||||||
|
|
||||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
|
||||||
|
|
||||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||||
|
|
||||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
|
||||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
|
|
||||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
|
||||||
|
|
||||||
class LatentPreviewMethod(enum.Enum):
|
|
||||||
NoPreviews = "none"
|
|
||||||
Auto = "auto"
|
|
||||||
Latent2RGB = "latent2rgb"
|
|
||||||
TAESD = "taesd"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_string(cls, value: str):
|
|
||||||
for member in cls:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
return None
|
|
||||||
|
|
||||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
|
||||||
|
|
||||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
|
||||||
|
|
||||||
cache_group = parser.add_mutually_exclusive_group()
|
|
||||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
|
||||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
|
||||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
|
||||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
|
||||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
|
||||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||||
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
|
||||||
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
|
|
||||||
|
|
||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||||
|
|
||||||
upcast = parser.add_mutually_exclusive_group()
|
|
||||||
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
|
|
||||||
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
|
||||||
manager_group = parser.add_mutually_exclusive_group()
|
|
||||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
|
||||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
|
||||||
|
|
||||||
|
|
||||||
vram_group = parser.add_mutually_exclusive_group()
|
vram_group = parser.add_mutually_exclusive_group()
|
||||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
|
||||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||||
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
||||||
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
||||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||||
|
|
||||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
|
||||||
|
|
||||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
|
||||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
|
||||||
|
|
||||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
|
||||||
|
|
||||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
|
||||||
|
|
||||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
|
||||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
|
||||||
|
|
||||||
class PerformanceFeature(enum.Enum):
|
|
||||||
Fp16Accumulation = "fp16_accumulation"
|
|
||||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
|
||||||
CublasOps = "cublas_ops"
|
|
||||||
AutoTune = "autotune"
|
|
||||||
|
|
||||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
|
||||||
|
|
||||||
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
|
||||||
|
|
||||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
|
||||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||||
|
|
||||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
args = parser.parse_args()
|
||||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
|
||||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
|
||||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
|
|
||||||
|
|
||||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
|
||||||
|
|
||||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
|
||||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
|
||||||
|
|
||||||
|
|
||||||
# The default built-in provider hosted under web/
|
|
||||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--front-end-version",
|
|
||||||
type=str,
|
|
||||||
default=DEFAULT_VERSION_STRING,
|
|
||||||
help="""
|
|
||||||
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
|
|
||||||
download available frontend implementations from GitHub releases.
|
|
||||||
|
|
||||||
The version string should be in the format of:
|
|
||||||
[repoOwner]/[repoName]@[version]
|
|
||||||
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_valid_directory(path: str) -> str:
|
|
||||||
"""Validate if the given path is a directory, and check permissions."""
|
|
||||||
if not os.path.exists(path):
|
|
||||||
raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
|
|
||||||
if not os.path.isdir(path):
|
|
||||||
raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
|
|
||||||
if not os.access(path, os.R_OK):
|
|
||||||
raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
|
|
||||||
return path
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--front-end-root",
|
|
||||||
type=is_valid_directory,
|
|
||||||
default=None,
|
|
||||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
|
|
||||||
|
|
||||||
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--comfy-api-base",
|
|
||||||
type=str,
|
|
||||||
default="https://api.comfy.org",
|
|
||||||
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
|
||||||
)
|
|
||||||
|
|
||||||
database_default_path = os.path.abspath(
|
|
||||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
|
||||||
)
|
|
||||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
|
||||||
args = parser.parse_args()
|
|
||||||
else:
|
|
||||||
args = parser.parse_args([])
|
|
||||||
|
|
||||||
if args.windows_standalone_build:
|
if args.windows_standalone_build:
|
||||||
args.auto_launch = True
|
args.auto_launch = True
|
||||||
|
|
||||||
if args.disable_auto_launch:
|
|
||||||
args.auto_launch = False
|
|
||||||
|
|
||||||
if args.force_fp16:
|
|
||||||
args.fp16_unet = True
|
|
||||||
|
|
||||||
|
|
||||||
# '--fast' is not provided, use an empty set
|
|
||||||
if args.fast is None:
|
|
||||||
args.fast = set()
|
|
||||||
# '--fast' is provided with an empty list, enable all optimizations
|
|
||||||
elif args.fast == []:
|
|
||||||
args.fast = set(PerformanceFeature)
|
|
||||||
# '--fast' is provided with a list of performance features, use that list
|
|
||||||
else:
|
|
||||||
args.fast = set(args.fast)
|
|
||||||
|
|||||||
@ -1,23 +0,0 @@
|
|||||||
{
|
|
||||||
"architectures": [
|
|
||||||
"CLIPTextModel"
|
|
||||||
],
|
|
||||||
"attention_dropout": 0.0,
|
|
||||||
"bos_token_id": 0,
|
|
||||||
"dropout": 0.0,
|
|
||||||
"eos_token_id": 49407,
|
|
||||||
"hidden_act": "gelu",
|
|
||||||
"hidden_size": 1280,
|
|
||||||
"initializer_factor": 1.0,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"intermediate_size": 5120,
|
|
||||||
"layer_norm_eps": 1e-05,
|
|
||||||
"max_position_embeddings": 77,
|
|
||||||
"model_type": "clip_text_model",
|
|
||||||
"num_attention_heads": 20,
|
|
||||||
"num_hidden_layers": 32,
|
|
||||||
"pad_token_id": 1,
|
|
||||||
"projection_dim": 1280,
|
|
||||||
"torch_dtype": "float32",
|
|
||||||
"vocab_size": 49408
|
|
||||||
}
|
|
||||||
@ -1,254 +0,0 @@
|
|||||||
import torch
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
|
||||||
import comfy.ops
|
|
||||||
|
|
||||||
class CLIPAttention(torch.nn.Module):
|
|
||||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.heads = heads
|
|
||||||
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
|
||||||
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
|
||||||
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, x, mask=None, optimized_attention=None):
|
|
||||||
q = self.q_proj(x)
|
|
||||||
k = self.k_proj(x)
|
|
||||||
v = self.v_proj(x)
|
|
||||||
|
|
||||||
out = optimized_attention(q, k, v, self.heads, mask)
|
|
||||||
return self.out_proj(out)
|
|
||||||
|
|
||||||
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
|
||||||
"gelu": torch.nn.functional.gelu,
|
|
||||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
|
||||||
}
|
|
||||||
|
|
||||||
class CLIPMLP(torch.nn.Module):
|
|
||||||
def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
|
|
||||||
self.activation = ACTIVATIONS[activation]
|
|
||||||
self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.fc1(x)
|
|
||||||
x = self.activation(x)
|
|
||||||
x = self.fc2(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
class CLIPLayer(torch.nn.Module):
|
|
||||||
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
|
||||||
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
|
|
||||||
self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
|
||||||
self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
|
|
||||||
|
|
||||||
def forward(self, x, mask=None, optimized_attention=None):
|
|
||||||
x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
|
|
||||||
x += self.mlp(self.layer_norm2(x))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPEncoder(torch.nn.Module):
|
|
||||||
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
|
|
||||||
|
|
||||||
def forward(self, x, mask=None, intermediate_output=None):
|
|
||||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
|
||||||
|
|
||||||
all_intermediate = None
|
|
||||||
if intermediate_output is not None:
|
|
||||||
if intermediate_output == "all":
|
|
||||||
all_intermediate = []
|
|
||||||
intermediate_output = None
|
|
||||||
elif intermediate_output < 0:
|
|
||||||
intermediate_output = len(self.layers) + intermediate_output
|
|
||||||
|
|
||||||
intermediate = None
|
|
||||||
for i, l in enumerate(self.layers):
|
|
||||||
x = l(x, mask, optimized_attention)
|
|
||||||
if i == intermediate_output:
|
|
||||||
intermediate = x.clone()
|
|
||||||
if all_intermediate is not None:
|
|
||||||
all_intermediate.append(x.unsqueeze(1).clone())
|
|
||||||
|
|
||||||
if all_intermediate is not None:
|
|
||||||
intermediate = torch.cat(all_intermediate, dim=1)
|
|
||||||
|
|
||||||
return x, intermediate
|
|
||||||
|
|
||||||
class CLIPEmbeddings(torch.nn.Module):
|
|
||||||
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
|
|
||||||
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, input_tokens, dtype=torch.float32):
|
|
||||||
return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextModel_(torch.nn.Module):
|
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
|
||||||
num_layers = config_dict["num_hidden_layers"]
|
|
||||||
embed_dim = config_dict["hidden_size"]
|
|
||||||
heads = config_dict["num_attention_heads"]
|
|
||||||
intermediate_size = config_dict["intermediate_size"]
|
|
||||||
intermediate_activation = config_dict["hidden_act"]
|
|
||||||
num_positions = config_dict["max_position_embeddings"]
|
|
||||||
self.eos_token_id = config_dict["eos_token_id"]
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
|
|
||||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
|
||||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
|
|
||||||
if embeds is not None:
|
|
||||||
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
|
||||||
else:
|
|
||||||
x = self.embeddings(input_tokens, dtype=dtype)
|
|
||||||
|
|
||||||
mask = None
|
|
||||||
if attention_mask is not None:
|
|
||||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
|
||||||
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
|
||||||
|
|
||||||
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
|
|
||||||
|
|
||||||
if mask is not None:
|
|
||||||
mask += causal_mask
|
|
||||||
else:
|
|
||||||
mask = causal_mask
|
|
||||||
|
|
||||||
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
|
|
||||||
x = self.final_layer_norm(x)
|
|
||||||
if i is not None and final_layer_norm_intermediate:
|
|
||||||
i = self.final_layer_norm(i)
|
|
||||||
|
|
||||||
if num_tokens is not None:
|
|
||||||
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
|
|
||||||
else:
|
|
||||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
|
|
||||||
return x, i, pooled_output
|
|
||||||
|
|
||||||
class CLIPTextModel(torch.nn.Module):
|
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
self.num_layers = config_dict["num_hidden_layers"]
|
|
||||||
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
|
||||||
embed_dim = config_dict["hidden_size"]
|
|
||||||
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.dtype = dtype
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.text_model.embeddings.token_embedding
|
|
||||||
|
|
||||||
def set_input_embeddings(self, embeddings):
|
|
||||||
self.text_model.embeddings.token_embedding = embeddings
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
x = self.text_model(*args, **kwargs)
|
|
||||||
out = self.text_projection(x[2])
|
|
||||||
return (x[0], x[1], out, x[2])
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
|
||||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
num_patches = (image_size // patch_size) ** 2
|
|
||||||
if model_type == "siglip_vision_model":
|
|
||||||
self.class_embedding = None
|
|
||||||
patch_bias = True
|
|
||||||
else:
|
|
||||||
num_patches = num_patches + 1
|
|
||||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
|
||||||
patch_bias = False
|
|
||||||
|
|
||||||
self.patch_embedding = operations.Conv2d(
|
|
||||||
in_channels=num_channels,
|
|
||||||
out_channels=embed_dim,
|
|
||||||
kernel_size=patch_size,
|
|
||||||
stride=patch_size,
|
|
||||||
bias=patch_bias,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
|
||||||
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
|
||||||
if self.class_embedding is not None:
|
|
||||||
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
|
|
||||||
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPVision(torch.nn.Module):
|
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
num_layers = config_dict["num_hidden_layers"]
|
|
||||||
embed_dim = config_dict["hidden_size"]
|
|
||||||
heads = config_dict["num_attention_heads"]
|
|
||||||
intermediate_size = config_dict["intermediate_size"]
|
|
||||||
intermediate_activation = config_dict["hidden_act"]
|
|
||||||
model_type = config_dict["model_type"]
|
|
||||||
|
|
||||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
|
||||||
if model_type == "siglip_vision_model":
|
|
||||||
self.pre_layrnorm = lambda a: a
|
|
||||||
self.output_layernorm = True
|
|
||||||
else:
|
|
||||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
|
||||||
self.output_layernorm = False
|
|
||||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
|
||||||
self.post_layernorm = operations.LayerNorm(embed_dim)
|
|
||||||
|
|
||||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
|
||||||
x = self.embeddings(pixel_values)
|
|
||||||
x = self.pre_layrnorm(x)
|
|
||||||
#TODO: attention_mask?
|
|
||||||
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
|
||||||
if self.output_layernorm:
|
|
||||||
x = self.post_layernorm(x)
|
|
||||||
pooled_output = x
|
|
||||||
else:
|
|
||||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
|
||||||
return x, i, pooled_output
|
|
||||||
|
|
||||||
class LlavaProjector(torch.nn.Module):
|
|
||||||
def __init__(self, in_dim, out_dim, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
|
|
||||||
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
|
|
||||||
|
|
||||||
class CLIPVisionModelProjection(torch.nn.Module):
|
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
|
||||||
if "projection_dim" in config_dict:
|
|
||||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
|
||||||
else:
|
|
||||||
self.visual_projection = lambda a: a
|
|
||||||
|
|
||||||
if "llava3" == config_dict.get("projector_type", None):
|
|
||||||
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
|
|
||||||
else:
|
|
||||||
self.multi_modal_projector = None
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
x = self.vision_model(*args, **kwargs)
|
|
||||||
out = self.visual_projection(x[2])
|
|
||||||
projected = None
|
|
||||||
if self.multi_modal_projector is not None:
|
|
||||||
projected = self.multi_modal_projector(x[1])
|
|
||||||
|
|
||||||
return (x[0], x[1], out, projected)
|
|
||||||
@ -1,164 +1,64 @@
|
|||||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
|
||||||
|
from .utils import load_torch_file, transformers_convert
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import comfy.ops
|
|
||||||
import comfy.model_patcher
|
|
||||||
import comfy.model_management
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.clip_model
|
|
||||||
import comfy.image_encoders.dino2
|
|
||||||
|
|
||||||
class Output:
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return getattr(self, key)
|
|
||||||
def __setitem__(self, key, item):
|
|
||||||
setattr(self, key, item)
|
|
||||||
|
|
||||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
|
||||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
|
||||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
|
||||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
|
||||||
image = image.movedim(-1, 1)
|
|
||||||
if not (image.shape[2] == size and image.shape[3] == size):
|
|
||||||
if crop:
|
|
||||||
scale = (size / min(image.shape[2], image.shape[3]))
|
|
||||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
|
||||||
else:
|
|
||||||
scale_size = (size, size)
|
|
||||||
|
|
||||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
|
||||||
h = (image.shape[2] - size)//2
|
|
||||||
w = (image.shape[3] - size)//2
|
|
||||||
image = image[:,:,h:h+size,w:w+size]
|
|
||||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
|
||||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
|
||||||
|
|
||||||
IMAGE_ENCODERS = {
|
|
||||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
|
||||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
|
||||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
|
||||||
}
|
|
||||||
|
|
||||||
class ClipVisionModel():
|
class ClipVisionModel():
|
||||||
def __init__(self, json_config):
|
def __init__(self, json_config):
|
||||||
with open(json_config) as f:
|
config = CLIPVisionConfig.from_json_file(json_config)
|
||||||
config = json.load(f)
|
self.model = CLIPVisionModelWithProjection(config)
|
||||||
|
self.processor = CLIPImageProcessor(crop_size=224,
|
||||||
self.image_size = config.get("image_size", 224)
|
do_center_crop=True,
|
||||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
do_convert_rgb=True,
|
||||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
do_normalize=True,
|
||||||
model_type = config.get("model_type", "clip_vision_model")
|
do_resize=True,
|
||||||
model_class = IMAGE_ENCODERS.get(model_type)
|
image_mean=[ 0.48145466,0.4578275,0.40821073],
|
||||||
if model_type == "siglip_vision_model":
|
image_std=[0.26862954,0.26130258,0.27577711],
|
||||||
self.return_all_hidden_states = True
|
resample=3, #bicubic
|
||||||
else:
|
size=224)
|
||||||
self.return_all_hidden_states = False
|
|
||||||
|
|
||||||
self.load_device = comfy.model_management.text_encoder_device()
|
|
||||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
|
||||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
|
||||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.model.load_state_dict(sd, strict=False)
|
self.model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
def get_sd(self):
|
def encode_image(self, image):
|
||||||
return self.model.state_dict()
|
img = torch.clip((255. * image[0]), 0, 255).round().int()
|
||||||
|
inputs = self.processor(images=[img], return_tensors="pt")
|
||||||
def encode_image(self, image, crop=True):
|
outputs = self.model(**inputs)
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
|
||||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
|
||||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
|
||||||
|
|
||||||
outputs = Output()
|
|
||||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
|
||||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
|
||||||
if self.return_all_hidden_states:
|
|
||||||
all_hs = out[1].to(comfy.model_management.intermediate_device())
|
|
||||||
outputs["penultimate_hidden_states"] = all_hs[:, -2]
|
|
||||||
outputs["all_hidden_states"] = all_hs
|
|
||||||
else:
|
|
||||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
|
||||||
|
|
||||||
outputs["mm_projected"] = out[3]
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def convert_to_transformers(sd, prefix):
|
def convert_to_transformers(sd):
|
||||||
sd_k = sd.keys()
|
sd_k = sd.keys()
|
||||||
if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
|
if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k:
|
||||||
keys_to_replace = {
|
keys_to_replace = {
|
||||||
"{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
|
"embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding",
|
||||||
"{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
|
"embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight",
|
||||||
"{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
|
"embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight",
|
||||||
"{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
|
"embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias",
|
||||||
"{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
|
"embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight",
|
||||||
"{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
|
"embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias",
|
||||||
"{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
|
"embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight",
|
||||||
}
|
}
|
||||||
|
|
||||||
for x in keys_to_replace:
|
for x in keys_to_replace:
|
||||||
if x in sd_k:
|
if x in sd_k:
|
||||||
sd[keys_to_replace[x]] = sd.pop(x)
|
sd[keys_to_replace[x]] = sd.pop(x)
|
||||||
|
|
||||||
if "{}proj".format(prefix) in sd_k:
|
if "embedder.model.visual.proj" in sd_k:
|
||||||
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
|
sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1)
|
||||||
|
|
||||||
sd = transformers_convert(sd, prefix, "vision_model.", 48)
|
sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32)
|
||||||
else:
|
|
||||||
replace_prefix = {prefix: ""}
|
|
||||||
sd = state_dict_prefix_replace(sd, replace_prefix)
|
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
def load_clipvision_from_sd(sd):
|
||||||
if convert_keys:
|
sd = convert_to_transformers(sd)
|
||||||
sd = convert_to_transformers(sd, prefix)
|
if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||||
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
|
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
|
|
||||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
|
||||||
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
|
|
||||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
|
||||||
if embed_shape == 729:
|
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
|
||||||
elif embed_shape == 1024:
|
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
|
||||||
elif embed_shape == 577:
|
|
||||||
if "multi_modal_projector.linear_1.bias" in sd:
|
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
|
|
||||||
else:
|
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
|
||||||
else:
|
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
|
||||||
|
|
||||||
# Dinov2
|
|
||||||
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
|
|
||||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
|
||||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
|
||||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
|
||||||
else:
|
else:
|
||||||
return None
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||||
|
|
||||||
clip = ClipVisionModel(json_config)
|
clip = ClipVisionModel(json_config)
|
||||||
m, u = clip.load_sd(sd)
|
clip.load_sd(sd)
|
||||||
if len(m) > 0:
|
|
||||||
logging.warning("missing clip vision: {}".format(m))
|
|
||||||
u = set(u)
|
|
||||||
keys = list(sd.keys())
|
|
||||||
for k in keys:
|
|
||||||
if k not in u:
|
|
||||||
sd.pop(k)
|
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load(ckpt_path):
|
def load(ckpt_path):
|
||||||
sd = load_torch_file(ckpt_path)
|
sd = load_torch_file(ckpt_path)
|
||||||
if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
|
return load_clipvision_from_sd(sd)
|
||||||
return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
|
|
||||||
else:
|
|
||||||
return load_clipvision_from_sd(sd)
|
|
||||||
|
|||||||
@ -1,18 +0,0 @@
|
|||||||
{
|
|
||||||
"attention_dropout": 0.0,
|
|
||||||
"dropout": 0.0,
|
|
||||||
"hidden_act": "gelu",
|
|
||||||
"hidden_size": 1664,
|
|
||||||
"image_size": 224,
|
|
||||||
"initializer_factor": 1.0,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"intermediate_size": 8192,
|
|
||||||
"layer_norm_eps": 1e-05,
|
|
||||||
"model_type": "clip_vision_model",
|
|
||||||
"num_attention_heads": 16,
|
|
||||||
"num_channels": 3,
|
|
||||||
"num_hidden_layers": 48,
|
|
||||||
"patch_size": 14,
|
|
||||||
"projection_dim": 1280,
|
|
||||||
"torch_dtype": "float32"
|
|
||||||
}
|
|
||||||
@ -1,18 +0,0 @@
|
|||||||
{
|
|
||||||
"attention_dropout": 0.0,
|
|
||||||
"dropout": 0.0,
|
|
||||||
"hidden_act": "quick_gelu",
|
|
||||||
"hidden_size": 1024,
|
|
||||||
"image_size": 336,
|
|
||||||
"initializer_factor": 1.0,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"intermediate_size": 4096,
|
|
||||||
"layer_norm_eps": 1e-5,
|
|
||||||
"model_type": "clip_vision_model",
|
|
||||||
"num_attention_heads": 16,
|
|
||||||
"num_channels": 3,
|
|
||||||
"num_hidden_layers": 24,
|
|
||||||
"patch_size": 14,
|
|
||||||
"projection_dim": 768,
|
|
||||||
"torch_dtype": "float32"
|
|
||||||
}
|
|
||||||
@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"attention_dropout": 0.0,
|
|
||||||
"dropout": 0.0,
|
|
||||||
"hidden_act": "quick_gelu",
|
|
||||||
"hidden_size": 1024,
|
|
||||||
"image_size": 336,
|
|
||||||
"initializer_factor": 1.0,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"intermediate_size": 4096,
|
|
||||||
"layer_norm_eps": 1e-5,
|
|
||||||
"model_type": "clip_vision_model",
|
|
||||||
"num_attention_heads": 16,
|
|
||||||
"num_channels": 3,
|
|
||||||
"num_hidden_layers": 24,
|
|
||||||
"patch_size": 14,
|
|
||||||
"projection_dim": 768,
|
|
||||||
"projector_type": "llava3",
|
|
||||||
"torch_dtype": "float32"
|
|
||||||
}
|
|
||||||
@ -1,13 +0,0 @@
|
|||||||
{
|
|
||||||
"num_channels": 3,
|
|
||||||
"hidden_act": "gelu_pytorch_tanh",
|
|
||||||
"hidden_size": 1152,
|
|
||||||
"image_size": 384,
|
|
||||||
"intermediate_size": 4304,
|
|
||||||
"model_type": "siglip_vision_model",
|
|
||||||
"num_attention_heads": 16,
|
|
||||||
"num_hidden_layers": 27,
|
|
||||||
"patch_size": 14,
|
|
||||||
"image_mean": [0.5, 0.5, 0.5],
|
|
||||||
"image_std": [0.5, 0.5, 0.5]
|
|
||||||
}
|
|
||||||
@ -1,13 +0,0 @@
|
|||||||
{
|
|
||||||
"num_channels": 3,
|
|
||||||
"hidden_act": "gelu_pytorch_tanh",
|
|
||||||
"hidden_size": 1152,
|
|
||||||
"image_size": 512,
|
|
||||||
"intermediate_size": 4304,
|
|
||||||
"model_type": "siglip_vision_model",
|
|
||||||
"num_attention_heads": 16,
|
|
||||||
"num_hidden_layers": 27,
|
|
||||||
"patch_size": 16,
|
|
||||||
"image_mean": [0.5, 0.5, 0.5],
|
|
||||||
"image_std": [0.5, 0.5, 0.5]
|
|
||||||
}
|
|
||||||
@ -1,43 +0,0 @@
|
|||||||
# Comfy Typing
|
|
||||||
## Type hinting for ComfyUI Node development
|
|
||||||
|
|
||||||
This module provides type hinting and concrete convenience types for node developers.
|
|
||||||
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin
|
|
||||||
|
|
||||||
class ExampleNode(ComfyNodeABC):
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s) -> InputTypeDict:
|
|
||||||
return {"required": {}}
|
|
||||||
```
|
|
||||||
|
|
||||||
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
|
|
||||||
|
|
||||||
# Types
|
|
||||||
A few primary types are documented below. More complete information is available via the docstrings on each type.
|
|
||||||
|
|
||||||
## `IO`
|
|
||||||
|
|
||||||
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
|
|
||||||
|
|
||||||
- `ANY`: `"*"`
|
|
||||||
- `NUMBER`: `"FLOAT,INT"`
|
|
||||||
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
|
|
||||||
|
|
||||||
## `ComfyNodeABC`
|
|
||||||
|
|
||||||
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
|
|
||||||
|
|
||||||
### Type hinting for `INPUT_TYPES`
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### `INPUT_TYPES` return dict
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### Options for individual inputs
|
|
||||||
|
|
||||||

|
|
||||||
@ -1,46 +0,0 @@
|
|||||||
import torch
|
|
||||||
from typing import Callable, Protocol, TypedDict, Optional, List
|
|
||||||
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
|
|
||||||
|
|
||||||
|
|
||||||
class UnetApplyFunction(Protocol):
|
|
||||||
"""Function signature protocol on comfy.model_base.BaseModel.apply_model"""
|
|
||||||
|
|
||||||
def __call__(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UnetApplyConds(TypedDict):
|
|
||||||
"""Optional conditions for unet apply function."""
|
|
||||||
|
|
||||||
c_concat: Optional[torch.Tensor]
|
|
||||||
c_crossattn: Optional[torch.Tensor]
|
|
||||||
control: Optional[torch.Tensor]
|
|
||||||
transformer_options: Optional[dict]
|
|
||||||
|
|
||||||
|
|
||||||
class UnetParams(TypedDict):
|
|
||||||
# Tensor of shape [B, C, H, W]
|
|
||||||
input: torch.Tensor
|
|
||||||
# Tensor of shape [B]
|
|
||||||
timestep: torch.Tensor
|
|
||||||
c: UnetApplyConds
|
|
||||||
# List of [0, 1], [0], [1], ...
|
|
||||||
# 0 means conditional, 1 means conditional unconditional
|
|
||||||
cond_or_uncond: List[int]
|
|
||||||
|
|
||||||
|
|
||||||
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"UnetWrapperFunction",
|
|
||||||
UnetApplyConds.__name__,
|
|
||||||
UnetParams.__name__,
|
|
||||||
UnetApplyFunction.__name__,
|
|
||||||
IO.__name__,
|
|
||||||
InputTypeDict.__name__,
|
|
||||||
ComfyNodeABC.__name__,
|
|
||||||
CheckLazyMixin.__name__,
|
|
||||||
FileLocator.__name__,
|
|
||||||
]
|
|
||||||
@ -1,28 +0,0 @@
|
|||||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
|
||||||
from inspect import cleandoc
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleNode(ComfyNodeABC):
|
|
||||||
"""An example node that just adds 1 to an input integer.
|
|
||||||
|
|
||||||
* Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
|
|
||||||
* This node is intended as an example for developers only.
|
|
||||||
"""
|
|
||||||
|
|
||||||
DESCRIPTION = cleandoc(__doc__)
|
|
||||||
CATEGORY = "examples"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s) -> InputTypeDict:
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"input_int": (IO.INT, {"defaultInput": True}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.INT,)
|
|
||||||
RETURN_NAMES = ("input_plus_one",)
|
|
||||||
FUNCTION = "execute"
|
|
||||||
|
|
||||||
def execute(self, input_int: int):
|
|
||||||
return (input_int + 1,)
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 19 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 16 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 19 KiB |
@ -1,350 +0,0 @@
|
|||||||
"""Comfy-specific type hinting"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
from typing import Literal, TypedDict, Optional
|
|
||||||
from typing_extensions import NotRequired
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
class StrEnum(str, Enum):
|
|
||||||
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
|
|
||||||
class IO(StrEnum):
|
|
||||||
"""Node input/output data types.
|
|
||||||
|
|
||||||
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
STRING = "STRING"
|
|
||||||
IMAGE = "IMAGE"
|
|
||||||
MASK = "MASK"
|
|
||||||
LATENT = "LATENT"
|
|
||||||
BOOLEAN = "BOOLEAN"
|
|
||||||
INT = "INT"
|
|
||||||
FLOAT = "FLOAT"
|
|
||||||
COMBO = "COMBO"
|
|
||||||
CONDITIONING = "CONDITIONING"
|
|
||||||
SAMPLER = "SAMPLER"
|
|
||||||
SIGMAS = "SIGMAS"
|
|
||||||
GUIDER = "GUIDER"
|
|
||||||
NOISE = "NOISE"
|
|
||||||
CLIP = "CLIP"
|
|
||||||
CONTROL_NET = "CONTROL_NET"
|
|
||||||
VAE = "VAE"
|
|
||||||
MODEL = "MODEL"
|
|
||||||
LORA_MODEL = "LORA_MODEL"
|
|
||||||
LOSS_MAP = "LOSS_MAP"
|
|
||||||
CLIP_VISION = "CLIP_VISION"
|
|
||||||
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
|
|
||||||
STYLE_MODEL = "STYLE_MODEL"
|
|
||||||
GLIGEN = "GLIGEN"
|
|
||||||
UPSCALE_MODEL = "UPSCALE_MODEL"
|
|
||||||
AUDIO = "AUDIO"
|
|
||||||
WEBCAM = "WEBCAM"
|
|
||||||
POINT = "POINT"
|
|
||||||
FACE_ANALYSIS = "FACE_ANALYSIS"
|
|
||||||
BBOX = "BBOX"
|
|
||||||
SEGS = "SEGS"
|
|
||||||
VIDEO = "VIDEO"
|
|
||||||
|
|
||||||
ANY = "*"
|
|
||||||
"""Always matches any type, but at a price.
|
|
||||||
|
|
||||||
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
|
|
||||||
"""
|
|
||||||
NUMBER = "FLOAT,INT"
|
|
||||||
"""A float or an int - could be either"""
|
|
||||||
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
|
|
||||||
"""Could be any of: string, float, int, or bool"""
|
|
||||||
|
|
||||||
def __ne__(self, value: object) -> bool:
|
|
||||||
if self == "*" or value == "*":
|
|
||||||
return False
|
|
||||||
if not isinstance(value, str):
|
|
||||||
return True
|
|
||||||
a = frozenset(self.split(","))
|
|
||||||
b = frozenset(value.split(","))
|
|
||||||
return not (b.issubset(a) or a.issubset(b))
|
|
||||||
|
|
||||||
|
|
||||||
class RemoteInputOptions(TypedDict):
|
|
||||||
route: str
|
|
||||||
"""The route to the remote source."""
|
|
||||||
refresh_button: bool
|
|
||||||
"""Specifies whether to show a refresh button in the UI below the widget."""
|
|
||||||
control_after_refresh: Literal["first", "last"]
|
|
||||||
"""Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on."""
|
|
||||||
timeout: int
|
|
||||||
"""The maximum amount of time to wait for a response from the remote source in milliseconds."""
|
|
||||||
max_retries: int
|
|
||||||
"""The maximum number of retries before aborting the request."""
|
|
||||||
refresh: int
|
|
||||||
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
|
|
||||||
|
|
||||||
|
|
||||||
class MultiSelectOptions(TypedDict):
|
|
||||||
placeholder: NotRequired[str]
|
|
||||||
"""The placeholder text to display in the multi-select widget when no items are selected."""
|
|
||||||
chip: NotRequired[bool]
|
|
||||||
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
|
|
||||||
|
|
||||||
|
|
||||||
class InputTypeOptions(TypedDict):
|
|
||||||
"""Provides type hinting for the return type of the INPUT_TYPES node function.
|
|
||||||
|
|
||||||
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
|
|
||||||
"""
|
|
||||||
|
|
||||||
default: NotRequired[bool | str | float | int | list | tuple]
|
|
||||||
"""The default value of the widget"""
|
|
||||||
defaultInput: NotRequired[bool]
|
|
||||||
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
|
|
||||||
- defaultInput on required inputs should be dropped.
|
|
||||||
- defaultInput on optional inputs should be replaced with forceInput.
|
|
||||||
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
|
|
||||||
"""
|
|
||||||
forceInput: NotRequired[bool]
|
|
||||||
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
|
|
||||||
lazy: NotRequired[bool]
|
|
||||||
"""Declares that this input uses lazy evaluation"""
|
|
||||||
rawLink: NotRequired[bool]
|
|
||||||
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
|
|
||||||
tooltip: NotRequired[str]
|
|
||||||
"""Tooltip for the input (or widget), shown on pointer hover"""
|
|
||||||
socketless: NotRequired[bool]
|
|
||||||
"""All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
|
|
||||||
Available from frontend v1.17.5
|
|
||||||
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
|
|
||||||
"""
|
|
||||||
widgetType: NotRequired[str]
|
|
||||||
"""Specifies a type to be used for widget initialization if different from the input type.
|
|
||||||
Available from frontend v1.18.0
|
|
||||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
|
|
||||||
# class InputTypeNumber(InputTypeOptions):
|
|
||||||
# default: float | int
|
|
||||||
min: NotRequired[float]
|
|
||||||
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
|
|
||||||
max: NotRequired[float]
|
|
||||||
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
|
|
||||||
step: NotRequired[float]
|
|
||||||
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
|
|
||||||
round: NotRequired[float]
|
|
||||||
"""Floats are rounded by this value (``FLOAT``)"""
|
|
||||||
# class InputTypeBoolean(InputTypeOptions):
|
|
||||||
# default: bool
|
|
||||||
label_on: NotRequired[str]
|
|
||||||
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
|
|
||||||
label_off: NotRequired[str]
|
|
||||||
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
|
|
||||||
# class InputTypeString(InputTypeOptions):
|
|
||||||
# default: str
|
|
||||||
multiline: NotRequired[bool]
|
|
||||||
"""Use a multiline text box (``STRING``)"""
|
|
||||||
placeholder: NotRequired[str]
|
|
||||||
"""Placeholder text to display in the UI when empty (``STRING``)"""
|
|
||||||
# Deprecated:
|
|
||||||
# defaultVal: str
|
|
||||||
dynamicPrompts: NotRequired[bool]
|
|
||||||
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
|
|
||||||
# class InputTypeCombo(InputTypeOptions):
|
|
||||||
image_upload: NotRequired[bool]
|
|
||||||
"""Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
|
|
||||||
image_folder: NotRequired[Literal["input", "output", "temp"]]
|
|
||||||
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
|
|
||||||
"""
|
|
||||||
remote: NotRequired[RemoteInputOptions]
|
|
||||||
"""Specifies the configuration for a remote input.
|
|
||||||
Available after ComfyUI frontend v1.9.7
|
|
||||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
|
|
||||||
control_after_generate: NotRequired[bool]
|
|
||||||
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
|
|
||||||
options: NotRequired[list[str | int | float]]
|
|
||||||
"""COMBO type only. Specifies the selectable options for the combo widget.
|
|
||||||
Prefer:
|
|
||||||
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
|
|
||||||
Over:
|
|
||||||
[["Option 1", "Option 2", "Option 3"]]
|
|
||||||
"""
|
|
||||||
multi_select: NotRequired[MultiSelectOptions]
|
|
||||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
|
||||||
Available after ComfyUI frontend v1.13.4
|
|
||||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
|
||||||
|
|
||||||
|
|
||||||
class HiddenInputTypeDict(TypedDict):
|
|
||||||
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
|
|
||||||
|
|
||||||
node_id: NotRequired[Literal["UNIQUE_ID"]]
|
|
||||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
|
||||||
unique_id: NotRequired[Literal["UNIQUE_ID"]]
|
|
||||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
|
||||||
prompt: NotRequired[Literal["PROMPT"]]
|
|
||||||
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
|
|
||||||
extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
|
|
||||||
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
|
|
||||||
dynprompt: NotRequired[Literal["DYNPROMPT"]]
|
|
||||||
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
|
|
||||||
|
|
||||||
|
|
||||||
class InputTypeDict(TypedDict):
|
|
||||||
"""Provides type hinting for node INPUT_TYPES.
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
|
|
||||||
"""
|
|
||||||
|
|
||||||
required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
|
|
||||||
"""Describes all inputs that must be connected for the node to execute."""
|
|
||||||
optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
|
|
||||||
"""Describes inputs which do not need to be connected."""
|
|
||||||
hidden: NotRequired[HiddenInputTypeDict]
|
|
||||||
"""Offers advanced functionality and server-client communication.
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class ComfyNodeABC(ABC):
|
|
||||||
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview
|
|
||||||
"""
|
|
||||||
|
|
||||||
DESCRIPTION: str
|
|
||||||
"""Node description, shown as a tooltip when hovering over the node.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
# Explicitly define the description
|
|
||||||
DESCRIPTION = "Example description here."
|
|
||||||
|
|
||||||
# Use the docstring of the node class.
|
|
||||||
DESCRIPTION = cleandoc(__doc__)
|
|
||||||
"""
|
|
||||||
CATEGORY: str
|
|
||||||
"""The category of the node, as per the "Add Node" menu.
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#category
|
|
||||||
"""
|
|
||||||
EXPERIMENTAL: bool
|
|
||||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
|
||||||
DEPRECATED: bool
|
|
||||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
|
||||||
API_NODE: Optional[bool]
|
|
||||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@abstractmethod
|
|
||||||
def INPUT_TYPES(s) -> InputTypeDict:
|
|
||||||
"""Defines node inputs.
|
|
||||||
|
|
||||||
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
|
|
||||||
* The ``optional`` key can be added to describe inputs which do not need to be connected.
|
|
||||||
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#input-types
|
|
||||||
"""
|
|
||||||
return {"required": {}}
|
|
||||||
|
|
||||||
OUTPUT_NODE: bool
|
|
||||||
"""Flags this node as an output node, causing any inputs it requires to be executed.
|
|
||||||
|
|
||||||
If a node is not connected to any output nodes, that node will not be executed. Usage::
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
From the docs:
|
|
||||||
|
|
||||||
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
|
|
||||||
"""
|
|
||||||
INPUT_IS_LIST: bool
|
|
||||||
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
|
|
||||||
|
|
||||||
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
|
|
||||||
|
|
||||||
From the docs:
|
|
||||||
|
|
||||||
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
|
||||||
"""
|
|
||||||
OUTPUT_IS_LIST: tuple[bool, ...]
|
|
||||||
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
|
|
||||||
|
|
||||||
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
|
|
||||||
|
|
||||||
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
|
|
||||||
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
|
|
||||||
|
|
||||||
From the docs:
|
|
||||||
|
|
||||||
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
|
|
||||||
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
|
|
||||||
specifying which outputs which should be so treated.
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
|
||||||
"""
|
|
||||||
|
|
||||||
RETURN_TYPES: tuple[IO, ...]
|
|
||||||
"""A tuple representing the outputs of this node.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
|
|
||||||
"""
|
|
||||||
RETURN_NAMES: tuple[str, ...]
|
|
||||||
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
|
|
||||||
"""
|
|
||||||
OUTPUT_TOOLTIPS: tuple[str, ...]
|
|
||||||
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
|
|
||||||
FUNCTION: str
|
|
||||||
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#function
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class CheckLazyMixin:
|
|
||||||
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
|
|
||||||
|
|
||||||
def check_lazy_status(self, **kwargs) -> list[str]:
|
|
||||||
"""Returns a list of input names that should be evaluated.
|
|
||||||
|
|
||||||
This basic mixin impl. requires all inputs.
|
|
||||||
|
|
||||||
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
|
|
||||||
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
|
|
||||||
|
|
||||||
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
|
|
||||||
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
|
|
||||||
|
|
||||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
|
|
||||||
"""
|
|
||||||
|
|
||||||
need = [name for name in kwargs if kwargs[name] is None]
|
|
||||||
return need
|
|
||||||
|
|
||||||
|
|
||||||
class FileLocator(TypedDict):
|
|
||||||
"""Provides type hinting for the file location"""
|
|
||||||
|
|
||||||
filename: str
|
|
||||||
"""The filename of the file."""
|
|
||||||
subfolder: str
|
|
||||||
"""The subfolder of the file."""
|
|
||||||
type: Literal["input", "output", "temp"]
|
|
||||||
"""The root folder of the file."""
|
|
||||||
137
comfy/conds.py
137
comfy/conds.py
@ -1,137 +0,0 @@
|
|||||||
import torch
|
|
||||||
import math
|
|
||||||
import comfy.utils
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
class CONDRegular:
|
|
||||||
def __init__(self, cond):
|
|
||||||
self.cond = cond
|
|
||||||
|
|
||||||
def _copy_with(self, cond):
|
|
||||||
return self.__class__(cond)
|
|
||||||
|
|
||||||
def process_cond(self, batch_size, **kwargs):
|
|
||||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
|
|
||||||
|
|
||||||
def can_concat(self, other):
|
|
||||||
if self.cond.shape != other.cond.shape:
|
|
||||||
return False
|
|
||||||
if self.cond.device != other.cond.device:
|
|
||||||
logging.warning("WARNING: conds not on same device, skipping concat.")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def concat(self, others):
|
|
||||||
conds = [self.cond]
|
|
||||||
for x in others:
|
|
||||||
conds.append(x.cond)
|
|
||||||
return torch.cat(conds)
|
|
||||||
|
|
||||||
def size(self):
|
|
||||||
return list(self.cond.size())
|
|
||||||
|
|
||||||
|
|
||||||
class CONDNoiseShape(CONDRegular):
|
|
||||||
def process_cond(self, batch_size, area, **kwargs):
|
|
||||||
data = self.cond
|
|
||||||
if area is not None:
|
|
||||||
dims = len(area) // 2
|
|
||||||
for i in range(dims):
|
|
||||||
data = data.narrow(i + 2, area[i + dims], area[i])
|
|
||||||
|
|
||||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
|
|
||||||
|
|
||||||
|
|
||||||
class CONDCrossAttn(CONDRegular):
|
|
||||||
def can_concat(self, other):
|
|
||||||
s1 = self.cond.shape
|
|
||||||
s2 = other.cond.shape
|
|
||||||
if s1 != s2:
|
|
||||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
|
||||||
return False
|
|
||||||
|
|
||||||
mult_min = math.lcm(s1[1], s2[1])
|
|
||||||
diff = mult_min // min(s1[1], s2[1])
|
|
||||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
|
||||||
return False
|
|
||||||
if self.cond.device != other.cond.device:
|
|
||||||
logging.warning("WARNING: conds not on same device: skipping concat.")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def concat(self, others):
|
|
||||||
conds = [self.cond]
|
|
||||||
crossattn_max_len = self.cond.shape[1]
|
|
||||||
for x in others:
|
|
||||||
c = x.cond
|
|
||||||
crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
|
|
||||||
conds.append(c)
|
|
||||||
|
|
||||||
out = []
|
|
||||||
for c in conds:
|
|
||||||
if c.shape[1] < crossattn_max_len:
|
|
||||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
|
||||||
out.append(c)
|
|
||||||
return torch.cat(out)
|
|
||||||
|
|
||||||
|
|
||||||
class CONDConstant(CONDRegular):
|
|
||||||
def __init__(self, cond):
|
|
||||||
self.cond = cond
|
|
||||||
|
|
||||||
def process_cond(self, batch_size, **kwargs):
|
|
||||||
return self._copy_with(self.cond)
|
|
||||||
|
|
||||||
def can_concat(self, other):
|
|
||||||
if self.cond != other.cond:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def concat(self, others):
|
|
||||||
return self.cond
|
|
||||||
|
|
||||||
def size(self):
|
|
||||||
return [1]
|
|
||||||
|
|
||||||
|
|
||||||
class CONDList(CONDRegular):
|
|
||||||
def __init__(self, cond):
|
|
||||||
self.cond = cond
|
|
||||||
|
|
||||||
def process_cond(self, batch_size, **kwargs):
|
|
||||||
out = []
|
|
||||||
for c in self.cond:
|
|
||||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
|
|
||||||
|
|
||||||
return self._copy_with(out)
|
|
||||||
|
|
||||||
def can_concat(self, other):
|
|
||||||
if len(self.cond) != len(other.cond):
|
|
||||||
return False
|
|
||||||
for i in range(len(self.cond)):
|
|
||||||
if self.cond[i].shape != other.cond[i].shape:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def concat(self, others):
|
|
||||||
out = []
|
|
||||||
for i in range(len(self.cond)):
|
|
||||||
o = [self.cond[i]]
|
|
||||||
for x in others:
|
|
||||||
o.append(x.cond[i])
|
|
||||||
out.append(torch.cat(o))
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def size(self): # hackish implementation to make the mem estimation work
|
|
||||||
o = 0
|
|
||||||
c = 1
|
|
||||||
for c in self.cond:
|
|
||||||
size = c.size()
|
|
||||||
o += math.prod(size)
|
|
||||||
if len(size) > 1:
|
|
||||||
c = size[1]
|
|
||||||
|
|
||||||
return [1, c, o // c]
|
|
||||||
@ -1,629 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from typing import TYPE_CHECKING, Callable
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import collections
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
import logging
|
|
||||||
import comfy.model_management
|
|
||||||
import comfy.patcher_extension
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from comfy.model_base import BaseModel
|
|
||||||
from comfy.model_patcher import ModelPatcher
|
|
||||||
from comfy.controlnet import ControlBase
|
|
||||||
|
|
||||||
|
|
||||||
class ContextWindowABC(ABC):
|
|
||||||
def __init__(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Get torch.Tensor applicable to current window.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Not implemented.")
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Not implemented.")
|
|
||||||
|
|
||||||
class ContextHandlerABC(ABC):
|
|
||||||
def __init__(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
|
||||||
raise NotImplementedError("Not implemented.")
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
|
|
||||||
raise NotImplementedError("Not implemented.")
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
|
||||||
raise NotImplementedError("Not implemented.")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class IndexListContextWindow(ContextWindowABC):
|
|
||||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
|
||||||
self.index_list = index_list
|
|
||||||
self.context_length = len(index_list)
|
|
||||||
self.dim = dim
|
|
||||||
self.total_frames = total_frames
|
|
||||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
|
||||||
|
|
||||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
|
||||||
if dim is None:
|
|
||||||
dim = self.dim
|
|
||||||
if dim == 0 and full.shape[dim] == 1:
|
|
||||||
return full
|
|
||||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
|
||||||
window = full[idx]
|
|
||||||
if retain_index_list:
|
|
||||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
|
||||||
window[idx] = full[idx]
|
|
||||||
return window.to(device)
|
|
||||||
|
|
||||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
|
||||||
if dim is None:
|
|
||||||
dim = self.dim
|
|
||||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
|
||||||
full[idx] += to_add
|
|
||||||
return full
|
|
||||||
|
|
||||||
def get_region_index(self, num_regions: int) -> int:
|
|
||||||
region_idx = int(self.center_ratio * num_regions)
|
|
||||||
return min(max(region_idx, 0), num_regions - 1)
|
|
||||||
|
|
||||||
|
|
||||||
class IndexListCallbacks:
|
|
||||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
|
||||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
|
||||||
EXECUTE_START = "execute_start"
|
|
||||||
EXECUTE_CLEANUP = "execute_cleanup"
|
|
||||||
RESIZE_COND_ITEM = "resize_cond_item"
|
|
||||||
|
|
||||||
def init_callbacks(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ContextSchedule:
|
|
||||||
name: str
|
|
||||||
func: Callable
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ContextFuseMethod:
|
|
||||||
name: str
|
|
||||||
func: Callable
|
|
||||||
|
|
||||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
|
||||||
class IndexListContextHandler(ContextHandlerABC):
|
|
||||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
|
||||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
|
||||||
self.context_schedule = context_schedule
|
|
||||||
self.fuse_method = fuse_method
|
|
||||||
self.context_length = context_length
|
|
||||||
self.context_overlap = context_overlap
|
|
||||||
self.context_stride = context_stride
|
|
||||||
self.closed_loop = closed_loop
|
|
||||||
self.dim = dim
|
|
||||||
self._step = 0
|
|
||||||
self.freenoise = freenoise
|
|
||||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
|
||||||
self.split_conds_to_windows = split_conds_to_windows
|
|
||||||
|
|
||||||
self.callbacks = {}
|
|
||||||
|
|
||||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
|
||||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
|
||||||
if x_in.size(self.dim) > self.context_length:
|
|
||||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
|
||||||
if self.cond_retain_index_list:
|
|
||||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
|
||||||
if control.previous_controlnet is not None:
|
|
||||||
self.prepare_control_objects(control.previous_controlnet, device)
|
|
||||||
return control
|
|
||||||
|
|
||||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
|
|
||||||
if cond_in is None:
|
|
||||||
return None
|
|
||||||
# reuse or resize cond items to match context requirements
|
|
||||||
resized_cond = []
|
|
||||||
# if multiple conds, split based on primary region
|
|
||||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
|
||||||
region = window.get_region_index(len(cond_in))
|
|
||||||
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
|
|
||||||
cond_in = [cond_in[region]]
|
|
||||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
|
||||||
for actual_cond in cond_in:
|
|
||||||
resized_actual_cond = actual_cond.copy()
|
|
||||||
# now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
|
|
||||||
for key in actual_cond:
|
|
||||||
try:
|
|
||||||
cond_item = actual_cond[key]
|
|
||||||
if isinstance(cond_item, torch.Tensor):
|
|
||||||
# check that tensor is the expected length - x.size(0)
|
|
||||||
if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
|
|
||||||
# if so, it's subsetting time - tell controls the expected indeces so they can handle them
|
|
||||||
actual_cond_item = window.get_tensor(cond_item)
|
|
||||||
resized_actual_cond[key] = actual_cond_item.to(device)
|
|
||||||
else:
|
|
||||||
resized_actual_cond[key] = cond_item.to(device)
|
|
||||||
# look for control
|
|
||||||
elif key == "control":
|
|
||||||
resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
|
|
||||||
elif isinstance(cond_item, dict):
|
|
||||||
new_cond_item = cond_item.copy()
|
|
||||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
|
||||||
for cond_key, cond_value in new_cond_item.items():
|
|
||||||
# Allow callbacks to handle custom conditioning items
|
|
||||||
handled = False
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(
|
|
||||||
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
|
|
||||||
):
|
|
||||||
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
|
|
||||||
if result is not None:
|
|
||||||
new_cond_item[cond_key] = result
|
|
||||||
handled = True
|
|
||||||
break
|
|
||||||
if handled:
|
|
||||||
continue
|
|
||||||
if isinstance(cond_value, torch.Tensor):
|
|
||||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
|
||||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
|
||||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
|
||||||
# Handle audio_embed (temporal dim is 1)
|
|
||||||
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
|
||||||
audio_cond = cond_value.cond
|
|
||||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
|
||||||
# if has cond that is a Tensor, check if needs to be subset
|
|
||||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
|
||||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
|
||||||
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
|
||||||
elif cond_key == "num_video_frames": # for SVD
|
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
|
||||||
new_cond_item[cond_key].cond = window.context_length
|
|
||||||
resized_actual_cond[key] = new_cond_item
|
|
||||||
else:
|
|
||||||
resized_actual_cond[key] = cond_item
|
|
||||||
finally:
|
|
||||||
del cond_item # just in case to prevent VRAM issues
|
|
||||||
resized_cond.append(resized_actual_cond)
|
|
||||||
return resized_cond
|
|
||||||
|
|
||||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
|
||||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
|
||||||
matches = torch.nonzero(mask)
|
|
||||||
if torch.numel(matches) == 0:
|
|
||||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
|
||||||
self._step = int(matches[0].item())
|
|
||||||
|
|
||||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
|
||||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
|
||||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
|
||||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
|
||||||
return context_windows
|
|
||||||
|
|
||||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
|
||||||
self.set_step(timestep, model_options)
|
|
||||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
|
||||||
enumerated_context_windows = list(enumerate(context_windows))
|
|
||||||
|
|
||||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
|
||||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
|
||||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
|
||||||
else:
|
|
||||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
|
||||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
|
||||||
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
|
||||||
callback(self, model, x_in, conds, timestep, model_options)
|
|
||||||
|
|
||||||
for enum_window in enumerated_context_windows:
|
|
||||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
|
||||||
for result in results:
|
|
||||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
|
||||||
conds_final, counts_final, biases_final)
|
|
||||||
try:
|
|
||||||
# finalize conds
|
|
||||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
|
||||||
# relative is already normalized, so return as is
|
|
||||||
del counts_final
|
|
||||||
return conds_final
|
|
||||||
else:
|
|
||||||
# normalize conds via division by context usage counts
|
|
||||||
for i in range(len(conds_final)):
|
|
||||||
conds_final[i] /= counts_final[i]
|
|
||||||
del counts_final
|
|
||||||
return conds_final
|
|
||||||
finally:
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
|
||||||
callback(self, model, x_in, conds, timestep, model_options)
|
|
||||||
|
|
||||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
|
||||||
model_options, device=None, first_device=None):
|
|
||||||
results: list[ContextResults] = []
|
|
||||||
for window_idx, window in enumerated_context_windows:
|
|
||||||
# allow processing to end between context window executions for faster Cancel
|
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
|
||||||
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
|
||||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
|
||||||
|
|
||||||
# update exposed params
|
|
||||||
model_options["transformer_options"]["context_window"] = window
|
|
||||||
# get subsections of x, timestep, conds
|
|
||||||
sub_x = window.get_tensor(x_in, device)
|
|
||||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
|
||||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
|
||||||
|
|
||||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
|
||||||
if device is not None:
|
|
||||||
for i in range(len(sub_conds_out)):
|
|
||||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
|
||||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
|
|
||||||
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
|
|
||||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
|
||||||
for pos, idx in enumerate(window.index_list):
|
|
||||||
# bias is the influence of a specific index in relation to the whole context window
|
|
||||||
bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
|
|
||||||
bias = max(1e-2, bias)
|
|
||||||
# take weighted average relative to total bias of current idx
|
|
||||||
for i in range(len(sub_conds_out)):
|
|
||||||
bias_total = biases_final[i][idx]
|
|
||||||
prev_weight = (bias_total / (bias_total + bias))
|
|
||||||
new_weight = (bias / (bias_total + bias))
|
|
||||||
# account for dims of tensors
|
|
||||||
idx_window = tuple([slice(None)] * self.dim + [idx])
|
|
||||||
pos_window = tuple([slice(None)] * self.dim + [pos])
|
|
||||||
# apply new values
|
|
||||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
|
||||||
biases_final[i][idx] = bias_total + bias
|
|
||||||
else:
|
|
||||||
# add conds and counts based on weights of fuse method
|
|
||||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
|
||||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
|
||||||
for i in range(len(sub_conds_out)):
|
|
||||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
|
||||||
window.add_window(counts_final[i], weights_tensor)
|
|
||||||
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
|
|
||||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
|
||||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
|
||||||
model_options = kwargs.get("model_options", None)
|
|
||||||
if model_options is None:
|
|
||||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
|
||||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
|
||||||
if handler is not None:
|
|
||||||
noise_shape = list(noise_shape)
|
|
||||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
|
||||||
return executor(model, noise_shape, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
|
||||||
model.add_wrapper_with_key(
|
|
||||||
comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
|
|
||||||
"ContextWindows_prepare_sampling",
|
|
||||||
_prepare_sampling_wrapper
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
|
|
||||||
model_options = extra_args.get("model_options", None)
|
|
||||||
if model_options is None:
|
|
||||||
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
|
||||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
|
||||||
if handler is None:
|
|
||||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
|
||||||
if not handler.freenoise:
|
|
||||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
|
||||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
|
||||||
|
|
||||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
|
||||||
model.add_wrapper_with_key(
|
|
||||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
|
||||||
"ContextWindows_sampler_sample",
|
|
||||||
_sampler_sample_wrapper
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
|
||||||
total_dims = len(x_in.shape)
|
|
||||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
|
||||||
for _ in range(dim):
|
|
||||||
weights_tensor = weights_tensor.unsqueeze(0)
|
|
||||||
for _ in range(total_dims - dim - 1):
|
|
||||||
weights_tensor = weights_tensor.unsqueeze(-1)
|
|
||||||
return weights_tensor
|
|
||||||
|
|
||||||
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
|
|
||||||
total_dims = len(x_in.shape)
|
|
||||||
shape = []
|
|
||||||
for _ in range(dim):
|
|
||||||
shape.append(1)
|
|
||||||
shape.append(x_in.shape[dim])
|
|
||||||
for _ in range(total_dims - dim - 1):
|
|
||||||
shape.append(1)
|
|
||||||
return shape
|
|
||||||
|
|
||||||
class ContextSchedules:
|
|
||||||
UNIFORM_LOOPED = "looped_uniform"
|
|
||||||
UNIFORM_STANDARD = "standard_uniform"
|
|
||||||
STATIC_STANDARD = "standard_static"
|
|
||||||
BATCHED = "batched"
|
|
||||||
|
|
||||||
|
|
||||||
# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
|
|
||||||
def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
|
||||||
windows = []
|
|
||||||
if num_frames < handler.context_length:
|
|
||||||
windows.append(list(range(num_frames)))
|
|
||||||
return windows
|
|
||||||
|
|
||||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
|
||||||
# obtain uniform windows as normal, looping and all
|
|
||||||
for context_step in 1 << np.arange(context_stride):
|
|
||||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
|
||||||
for j in range(
|
|
||||||
int(ordered_halving(handler._step) * context_step) + pad,
|
|
||||||
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
|
|
||||||
(handler.context_length * context_step - handler.context_overlap),
|
|
||||||
):
|
|
||||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
|
||||||
|
|
||||||
return windows
|
|
||||||
|
|
||||||
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
|
||||||
# unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
|
|
||||||
# instead, they get shifted to the corresponding end of the frames.
|
|
||||||
# in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
|
|
||||||
windows = []
|
|
||||||
if num_frames <= handler.context_length:
|
|
||||||
windows.append(list(range(num_frames)))
|
|
||||||
return windows
|
|
||||||
|
|
||||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
|
||||||
# first, obtain uniform windows as normal, looping and all
|
|
||||||
for context_step in 1 << np.arange(context_stride):
|
|
||||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
|
||||||
for j in range(
|
|
||||||
int(ordered_halving(handler._step) * context_step) + pad,
|
|
||||||
num_frames + pad + (-handler.context_overlap),
|
|
||||||
(handler.context_length * context_step - handler.context_overlap),
|
|
||||||
):
|
|
||||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
|
||||||
|
|
||||||
# now that windows are created, shift any windows that loop, and delete duplicate windows
|
|
||||||
delete_idxs = []
|
|
||||||
win_i = 0
|
|
||||||
while win_i < len(windows):
|
|
||||||
# if window is rolls over itself, need to shift it
|
|
||||||
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
|
|
||||||
if is_roll:
|
|
||||||
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
|
|
||||||
shift_window_to_end(windows[win_i], num_frames=num_frames)
|
|
||||||
# check if next window (cyclical) is missing roll_val
|
|
||||||
if roll_val not in windows[(win_i+1) % len(windows)]:
|
|
||||||
# need to insert new window here - just insert window starting at roll_val
|
|
||||||
windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
|
|
||||||
# delete window if it's not unique
|
|
||||||
for pre_i in range(0, win_i):
|
|
||||||
if windows[win_i] == windows[pre_i]:
|
|
||||||
delete_idxs.append(win_i)
|
|
||||||
break
|
|
||||||
win_i += 1
|
|
||||||
|
|
||||||
# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
|
|
||||||
delete_idxs.reverse()
|
|
||||||
for i in delete_idxs:
|
|
||||||
windows.pop(i)
|
|
||||||
|
|
||||||
return windows
|
|
||||||
|
|
||||||
|
|
||||||
def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
|
||||||
windows = []
|
|
||||||
if num_frames <= handler.context_length:
|
|
||||||
windows.append(list(range(num_frames)))
|
|
||||||
return windows
|
|
||||||
# always return the same set of windows
|
|
||||||
delta = handler.context_length - handler.context_overlap
|
|
||||||
for start_idx in range(0, num_frames, delta):
|
|
||||||
# if past the end of frames, move start_idx back to allow same context_length
|
|
||||||
ending = start_idx + handler.context_length
|
|
||||||
if ending >= num_frames:
|
|
||||||
final_delta = ending - num_frames
|
|
||||||
final_start_idx = start_idx - final_delta
|
|
||||||
windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
|
|
||||||
break
|
|
||||||
windows.append(list(range(start_idx, start_idx + handler.context_length)))
|
|
||||||
return windows
|
|
||||||
|
|
||||||
|
|
||||||
def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
|
||||||
windows = []
|
|
||||||
if num_frames <= handler.context_length:
|
|
||||||
windows.append(list(range(num_frames)))
|
|
||||||
return windows
|
|
||||||
# always return the same set of windows;
|
|
||||||
# no overlap, just cut up based on context_length;
|
|
||||||
# last window size will be different if num_frames % opts.context_length != 0
|
|
||||||
for start_idx in range(0, num_frames, handler.context_length):
|
|
||||||
windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
|
|
||||||
return windows
|
|
||||||
|
|
||||||
|
|
||||||
def create_windows_default(num_frames: int, handler: IndexListContextHandler):
|
|
||||||
return [list(range(num_frames))]
|
|
||||||
|
|
||||||
|
|
||||||
CONTEXT_MAPPING = {
|
|
||||||
ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
|
|
||||||
ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
|
|
||||||
ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
|
|
||||||
ContextSchedules.BATCHED: create_windows_batched,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
|
||||||
func = CONTEXT_MAPPING.get(context_schedule, None)
|
|
||||||
if func is None:
|
|
||||||
raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
|
|
||||||
return ContextSchedule(context_schedule, func)
|
|
||||||
|
|
||||||
|
|
||||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
|
|
||||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
|
||||||
|
|
||||||
|
|
||||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
|
||||||
# weight is the same for all
|
|
||||||
return [1.0] * length
|
|
||||||
|
|
||||||
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
|
||||||
# weight is based on the distance away from the edge of the context window;
|
|
||||||
# based on weighted average concept in FreeNoise paper
|
|
||||||
if length % 2 == 0:
|
|
||||||
max_weight = length // 2
|
|
||||||
weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
|
|
||||||
else:
|
|
||||||
max_weight = (length + 1) // 2
|
|
||||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
|
||||||
return weight_sequence
|
|
||||||
|
|
||||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
|
|
||||||
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
|
||||||
# only expected overlap is given different weights
|
|
||||||
weights_torch = torch.ones((length))
|
|
||||||
# blend left-side on all except first window
|
|
||||||
if min(idxs) > 0:
|
|
||||||
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
|
|
||||||
weights_torch[:handler.context_overlap] = ramp_up
|
|
||||||
# blend right-side on all except last window
|
|
||||||
if max(idxs) < full_length-1:
|
|
||||||
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
|
|
||||||
weights_torch[-handler.context_overlap:] = ramp_down
|
|
||||||
return weights_torch
|
|
||||||
|
|
||||||
class ContextFuseMethods:
|
|
||||||
FLAT = "flat"
|
|
||||||
PYRAMID = "pyramid"
|
|
||||||
RELATIVE = "relative"
|
|
||||||
OVERLAP_LINEAR = "overlap-linear"
|
|
||||||
|
|
||||||
LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
|
|
||||||
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
|
|
||||||
|
|
||||||
|
|
||||||
FUSE_MAPPING = {
|
|
||||||
ContextFuseMethods.FLAT: create_weights_flat,
|
|
||||||
ContextFuseMethods.PYRAMID: create_weights_pyramid,
|
|
||||||
ContextFuseMethods.RELATIVE: create_weights_pyramid,
|
|
||||||
ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
|
|
||||||
func = FUSE_MAPPING.get(fuse_method, None)
|
|
||||||
if func is None:
|
|
||||||
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
|
|
||||||
return ContextFuseMethod(fuse_method, func)
|
|
||||||
|
|
||||||
# Returns fraction that has denominator that is a power of 2
|
|
||||||
def ordered_halving(val):
|
|
||||||
# get binary value, padded with 0s for 64 bits
|
|
||||||
bin_str = f"{val:064b}"
|
|
||||||
# flip binary value, padding included
|
|
||||||
bin_flip = bin_str[::-1]
|
|
||||||
# convert binary to int
|
|
||||||
as_int = int(bin_flip, 2)
|
|
||||||
# divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
|
|
||||||
# or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's)
|
|
||||||
return as_int / (1 << 64)
|
|
||||||
|
|
||||||
|
|
||||||
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
|
|
||||||
all_indexes = list(range(num_frames))
|
|
||||||
for w in windows:
|
|
||||||
for val in w:
|
|
||||||
try:
|
|
||||||
all_indexes.remove(val)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
return all_indexes
|
|
||||||
|
|
||||||
|
|
||||||
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
|
|
||||||
prev_val = -1
|
|
||||||
for i, val in enumerate(window):
|
|
||||||
val = val % num_frames
|
|
||||||
if val < prev_val:
|
|
||||||
return True, i
|
|
||||||
prev_val = val
|
|
||||||
return False, -1
|
|
||||||
|
|
||||||
|
|
||||||
def shift_window_to_start(window: list[int], num_frames: int):
|
|
||||||
start_val = window[0]
|
|
||||||
for i in range(len(window)):
|
|
||||||
# 1) subtract each element by start_val to move vals relative to the start of all frames
|
|
||||||
# 2) add num_frames and take modulus to get adjusted vals
|
|
||||||
window[i] = ((window[i] - start_val) + num_frames) % num_frames
|
|
||||||
|
|
||||||
|
|
||||||
def shift_window_to_end(window: list[int], num_frames: int):
|
|
||||||
# 1) shift window to start
|
|
||||||
shift_window_to_start(window, num_frames)
|
|
||||||
end_val = window[-1]
|
|
||||||
end_delta = num_frames - end_val - 1
|
|
||||||
for i in range(len(window)):
|
|
||||||
# 2) add end_delta to each val to slide windows to end
|
|
||||||
window[i] = window[i] + end_delta
|
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
|
||||||
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
|
||||||
logging.info("Context windows: Applying FreeNoise")
|
|
||||||
generator = torch.Generator(device='cpu').manual_seed(seed)
|
|
||||||
latent_video_length = noise.shape[dim]
|
|
||||||
delta = context_length - context_overlap
|
|
||||||
|
|
||||||
for start_idx in range(0, latent_video_length - context_length, delta):
|
|
||||||
place_idx = start_idx + context_length
|
|
||||||
|
|
||||||
actual_delta = min(delta, latent_video_length - place_idx)
|
|
||||||
if actual_delta <= 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
|
||||||
|
|
||||||
source_slice = [slice(None)] * noise.ndim
|
|
||||||
source_slice[dim] = list_idx
|
|
||||||
target_slice = [slice(None)] * noise.ndim
|
|
||||||
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
|
||||||
|
|
||||||
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
|
||||||
|
|
||||||
return noise
|
|
||||||
@ -1,884 +0,0 @@
|
|||||||
"""
|
|
||||||
This file is part of ComfyUI.
|
|
||||||
Copyright (C) 2024 Comfy
|
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
|
||||||
it under the terms of the GNU General Public License as published by
|
|
||||||
the Free Software Foundation, either version 3 of the License, or
|
|
||||||
(at your option) any later version.
|
|
||||||
|
|
||||||
This program is distributed in the hope that it will be useful,
|
|
||||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
GNU General Public License for more details.
|
|
||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from enum import Enum
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.model_management
|
|
||||||
import comfy.model_detection
|
|
||||||
import comfy.model_patcher
|
|
||||||
import comfy.ops
|
|
||||||
import comfy.latent_formats
|
|
||||||
import comfy.model_base
|
|
||||||
|
|
||||||
import comfy.cldm.cldm
|
|
||||||
import comfy.t2i_adapter.adapter
|
|
||||||
import comfy.ldm.cascade.controlnet
|
|
||||||
import comfy.cldm.mmdit
|
|
||||||
import comfy.ldm.hydit.controlnet
|
|
||||||
import comfy.ldm.flux.controlnet
|
|
||||||
import comfy.ldm.qwen_image.controlnet
|
|
||||||
import comfy.cldm.dit_embedder
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from comfy.hooks import HookGroup
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
|
||||||
current_batch_size = tensor.shape[0]
|
|
||||||
if current_batch_size == 1:
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
per_batch = target_batch_size // batched_number
|
|
||||||
tensor = tensor[:per_batch]
|
|
||||||
|
|
||||||
if per_batch > tensor.shape[0]:
|
|
||||||
tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
|
|
||||||
|
|
||||||
current_batch_size = tensor.shape[0]
|
|
||||||
if current_batch_size == target_batch_size:
|
|
||||||
return tensor
|
|
||||||
else:
|
|
||||||
return torch.cat([tensor] * batched_number, dim=0)
|
|
||||||
|
|
||||||
class StrengthType(Enum):
|
|
||||||
CONSTANT = 1
|
|
||||||
LINEAR_UP = 2
|
|
||||||
|
|
||||||
class ControlBase:
|
|
||||||
def __init__(self):
|
|
||||||
self.cond_hint_original = None
|
|
||||||
self.cond_hint = None
|
|
||||||
self.strength = 1.0
|
|
||||||
self.timestep_percent_range = (0.0, 1.0)
|
|
||||||
self.latent_format = None
|
|
||||||
self.vae = None
|
|
||||||
self.global_average_pooling = False
|
|
||||||
self.timestep_range = None
|
|
||||||
self.compression_ratio = 8
|
|
||||||
self.upscale_algorithm = 'nearest-exact'
|
|
||||||
self.extra_args = {}
|
|
||||||
self.previous_controlnet = None
|
|
||||||
self.extra_conds = []
|
|
||||||
self.strength_type = StrengthType.CONSTANT
|
|
||||||
self.concat_mask = False
|
|
||||||
self.extra_concat_orig = []
|
|
||||||
self.extra_concat = None
|
|
||||||
self.extra_hooks: HookGroup = None
|
|
||||||
self.preprocess_image = lambda a: a
|
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
|
||||||
self.cond_hint_original = cond_hint
|
|
||||||
self.strength = strength
|
|
||||||
self.timestep_percent_range = timestep_percent_range
|
|
||||||
if self.latent_format is not None:
|
|
||||||
if vae is None:
|
|
||||||
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
|
||||||
self.vae = vae
|
|
||||||
self.extra_concat_orig = extra_concat.copy()
|
|
||||||
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
|
||||||
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
|
||||||
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
self.previous_controlnet.pre_run(model, percent_to_timestep_function)
|
|
||||||
|
|
||||||
def set_previous_controlnet(self, controlnet):
|
|
||||||
self.previous_controlnet = controlnet
|
|
||||||
return self
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
self.previous_controlnet.cleanup()
|
|
||||||
|
|
||||||
self.cond_hint = None
|
|
||||||
self.extra_concat = None
|
|
||||||
self.timestep_range = None
|
|
||||||
|
|
||||||
def get_models(self):
|
|
||||||
out = []
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
out += self.previous_controlnet.get_models()
|
|
||||||
return out
|
|
||||||
|
|
||||||
def get_extra_hooks(self):
|
|
||||||
out = []
|
|
||||||
if self.extra_hooks is not None:
|
|
||||||
out.append(self.extra_hooks)
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
out += self.previous_controlnet.get_extra_hooks()
|
|
||||||
return out
|
|
||||||
|
|
||||||
def copy_to(self, c):
|
|
||||||
c.cond_hint_original = self.cond_hint_original
|
|
||||||
c.strength = self.strength
|
|
||||||
c.timestep_percent_range = self.timestep_percent_range
|
|
||||||
c.global_average_pooling = self.global_average_pooling
|
|
||||||
c.compression_ratio = self.compression_ratio
|
|
||||||
c.upscale_algorithm = self.upscale_algorithm
|
|
||||||
c.latent_format = self.latent_format
|
|
||||||
c.extra_args = self.extra_args.copy()
|
|
||||||
c.vae = self.vae
|
|
||||||
c.extra_conds = self.extra_conds.copy()
|
|
||||||
c.strength_type = self.strength_type
|
|
||||||
c.concat_mask = self.concat_mask
|
|
||||||
c.extra_concat_orig = self.extra_concat_orig.copy()
|
|
||||||
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
|
|
||||||
c.preprocess_image = self.preprocess_image
|
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
return self.previous_controlnet.inference_memory_requirements(dtype)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def control_merge(self, control, control_prev, output_dtype):
|
|
||||||
out = {'input':[], 'middle':[], 'output': []}
|
|
||||||
|
|
||||||
for key in control:
|
|
||||||
control_output = control[key]
|
|
||||||
applied_to = set()
|
|
||||||
for i in range(len(control_output)):
|
|
||||||
x = control_output[i]
|
|
||||||
if x is not None:
|
|
||||||
if self.global_average_pooling:
|
|
||||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
|
||||||
|
|
||||||
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
|
|
||||||
applied_to.add(x)
|
|
||||||
if self.strength_type == StrengthType.CONSTANT:
|
|
||||||
x *= self.strength
|
|
||||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
|
||||||
x *= (self.strength ** float(len(control_output) - i))
|
|
||||||
|
|
||||||
if output_dtype is not None and x.dtype != output_dtype:
|
|
||||||
x = x.to(output_dtype)
|
|
||||||
|
|
||||||
out[key].append(x)
|
|
||||||
|
|
||||||
if control_prev is not None:
|
|
||||||
for x in ['input', 'middle', 'output']:
|
|
||||||
o = out[x]
|
|
||||||
for i in range(len(control_prev[x])):
|
|
||||||
prev_val = control_prev[x][i]
|
|
||||||
if i >= len(o):
|
|
||||||
o.append(prev_val)
|
|
||||||
elif prev_val is not None:
|
|
||||||
if o[i] is None:
|
|
||||||
o[i] = prev_val
|
|
||||||
else:
|
|
||||||
if o[i].shape[0] < prev_val.shape[0]:
|
|
||||||
o[i] = prev_val + o[i]
|
|
||||||
else:
|
|
||||||
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
|
|
||||||
return out
|
|
||||||
|
|
||||||
def set_extra_arg(self, argument, value=None):
|
|
||||||
self.extra_args[argument] = value
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
|
|
||||||
super().__init__()
|
|
||||||
self.control_model = control_model
|
|
||||||
self.load_device = load_device
|
|
||||||
if control_model is not None:
|
|
||||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
|
||||||
|
|
||||||
self.compression_ratio = compression_ratio
|
|
||||||
self.global_average_pooling = global_average_pooling
|
|
||||||
self.model_sampling_current = None
|
|
||||||
self.manual_cast_dtype = manual_cast_dtype
|
|
||||||
self.latent_format = latent_format
|
|
||||||
self.extra_conds += extra_conds
|
|
||||||
self.strength_type = strength_type
|
|
||||||
self.concat_mask = concat_mask
|
|
||||||
self.preprocess_image = preprocess_image
|
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
|
||||||
control_prev = None
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
|
||||||
|
|
||||||
if self.timestep_range is not None:
|
|
||||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
|
||||||
if control_prev is not None:
|
|
||||||
return control_prev
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
dtype = self.control_model.dtype
|
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
|
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
|
||||||
if self.cond_hint is not None:
|
|
||||||
del self.cond_hint
|
|
||||||
self.cond_hint = None
|
|
||||||
compression_ratio = self.compression_ratio
|
|
||||||
if self.vae is not None:
|
|
||||||
compression_ratio *= self.vae.spacial_compression_encode()
|
|
||||||
else:
|
|
||||||
if self.latent_format is not None:
|
|
||||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
|
||||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
|
|
||||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
|
||||||
if self.vae is not None:
|
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
|
||||||
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
|
||||||
comfy.model_management.load_models_gpu(loaded_models)
|
|
||||||
if self.latent_format is not None:
|
|
||||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
|
||||||
if len(self.extra_concat_orig) > 0:
|
|
||||||
to_concat = []
|
|
||||||
for c in self.extra_concat_orig:
|
|
||||||
c = c.to(self.cond_hint.device)
|
|
||||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
|
|
||||||
if c.ndim < self.cond_hint.ndim:
|
|
||||||
c = c.unsqueeze(2)
|
|
||||||
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
|
|
||||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
|
||||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
|
||||||
|
|
||||||
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
|
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
|
||||||
|
|
||||||
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
|
||||||
extra = self.extra_args.copy()
|
|
||||||
for c in self.extra_conds:
|
|
||||||
temp = cond.get(c, None)
|
|
||||||
if temp is not None:
|
|
||||||
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
|
|
||||||
|
|
||||||
timestep = self.model_sampling_current.timestep(t)
|
|
||||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
|
||||||
|
|
||||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
|
|
||||||
return self.control_merge(control, control_prev, output_dtype=None)
|
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
|
||||||
c.control_model = self.control_model
|
|
||||||
c.control_model_wrapped = self.control_model_wrapped
|
|
||||||
self.copy_to(c)
|
|
||||||
return c
|
|
||||||
|
|
||||||
def get_models(self):
|
|
||||||
out = super().get_models()
|
|
||||||
out.append(self.control_model_wrapped)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
|
||||||
super().pre_run(model, percent_to_timestep_function)
|
|
||||||
self.model_sampling_current = model.model_sampling
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
self.model_sampling_current = None
|
|
||||||
super().cleanup()
|
|
||||||
|
|
||||||
class ControlLoraOps:
|
|
||||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
|
||||||
device=None, dtype=None) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.in_features = in_features
|
|
||||||
self.out_features = out_features
|
|
||||||
self.weight = None
|
|
||||||
self.up = None
|
|
||||||
self.down = None
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
|
||||||
if self.up is not None:
|
|
||||||
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
|
||||||
else:
|
|
||||||
x = torch.nn.functional.linear(input, weight, bias)
|
|
||||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
|
||||||
return x
|
|
||||||
|
|
||||||
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
bias=True,
|
|
||||||
padding_mode='zeros',
|
|
||||||
device=None,
|
|
||||||
dtype=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.stride = stride
|
|
||||||
self.padding = padding
|
|
||||||
self.dilation = dilation
|
|
||||||
self.transposed = False
|
|
||||||
self.output_padding = 0
|
|
||||||
self.groups = groups
|
|
||||||
self.padding_mode = padding_mode
|
|
||||||
|
|
||||||
self.weight = None
|
|
||||||
self.bias = None
|
|
||||||
self.up = None
|
|
||||||
self.down = None
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
|
||||||
if self.up is not None:
|
|
||||||
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
|
||||||
else:
|
|
||||||
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
|
||||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
|
||||||
return x
|
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
|
||||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
|
||||||
ControlBase.__init__(self)
|
|
||||||
self.control_weights = control_weights
|
|
||||||
self.global_average_pooling = global_average_pooling
|
|
||||||
self.extra_conds += ["y"]
|
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
|
||||||
super().pre_run(model, percent_to_timestep_function)
|
|
||||||
controlnet_config = model.model_config.unet_config.copy()
|
|
||||||
controlnet_config.pop("out_channels")
|
|
||||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
|
||||||
self.manual_cast_dtype = model.manual_cast_dtype
|
|
||||||
dtype = model.get_dtype()
|
|
||||||
if self.manual_cast_dtype is None:
|
|
||||||
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
|
|
||||||
pass
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
|
|
||||||
controlnet_config["operations"] = control_lora_ops
|
|
||||||
controlnet_config["dtype"] = dtype
|
|
||||||
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
|
||||||
self.control_model.to(comfy.model_management.get_torch_device())
|
|
||||||
diffusion_model = model.diffusion_model
|
|
||||||
sd = diffusion_model.state_dict()
|
|
||||||
|
|
||||||
for k in sd:
|
|
||||||
weight = sd[k]
|
|
||||||
try:
|
|
||||||
comfy.utils.set_attr_param(self.control_model, k, weight)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
for k in self.control_weights:
|
|
||||||
if (k not in {"lora_controlnet"}):
|
|
||||||
if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k):
|
|
||||||
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
|
||||||
self.copy_to(c)
|
|
||||||
return c
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
del self.control_model
|
|
||||||
self.control_model = None
|
|
||||||
super().cleanup()
|
|
||||||
|
|
||||||
def get_models(self):
|
|
||||||
out = ControlBase.get_models(self)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
|
||||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
|
||||||
|
|
||||||
def controlnet_config(sd, model_options={}):
|
|
||||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
|
||||||
|
|
||||||
unet_dtype = model_options.get("dtype", None)
|
|
||||||
if unet_dtype is None:
|
|
||||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
|
||||||
|
|
||||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
|
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
|
||||||
|
|
||||||
operations = model_options.get("custom_operations", None)
|
|
||||||
if operations is None:
|
|
||||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
|
||||||
|
|
||||||
offload_device = comfy.model_management.unet_offload_device()
|
|
||||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
|
||||||
|
|
||||||
def controlnet_load_state_dict(control_model, sd):
|
|
||||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
|
||||||
|
|
||||||
if len(missing) > 0:
|
|
||||||
logging.warning("missing controlnet keys: {}".format(missing))
|
|
||||||
|
|
||||||
if len(unexpected) > 0:
|
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
|
||||||
return control_model
|
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_mmdit(sd, model_options={}):
|
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
|
||||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
|
||||||
for k in sd:
|
|
||||||
new_sd[k] = sd[k]
|
|
||||||
|
|
||||||
concat_mask = False
|
|
||||||
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
|
||||||
if control_latent_channels == 17: #inpaint controlnet
|
|
||||||
concat_mask = True
|
|
||||||
|
|
||||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SD3()
|
|
||||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
|
||||||
return control
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetSD35(ControlNet):
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
|
||||||
if self.control_model.double_y_emb:
|
|
||||||
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
|
|
||||||
else:
|
|
||||||
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
|
|
||||||
super().pre_run(model, percent_to_timestep_function)
|
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
|
||||||
c.control_model = self.control_model
|
|
||||||
c.control_model_wrapped = self.control_model_wrapped
|
|
||||||
self.copy_to(c)
|
|
||||||
return c
|
|
||||||
|
|
||||||
def load_controlnet_sd35(sd, model_options={}):
|
|
||||||
control_type = -1
|
|
||||||
if "control_type" in sd:
|
|
||||||
control_type = round(sd.pop("control_type").item())
|
|
||||||
|
|
||||||
# blur_cnet = control_type == 0
|
|
||||||
canny_cnet = control_type == 1
|
|
||||||
depth_cnet = control_type == 2
|
|
||||||
|
|
||||||
new_sd = {}
|
|
||||||
for k in comfy.utils.MMDIT_MAP_BASIC:
|
|
||||||
if k[1] in sd:
|
|
||||||
new_sd[k[0]] = sd.pop(k[1])
|
|
||||||
for k in sd:
|
|
||||||
new_sd[k] = sd[k]
|
|
||||||
sd = new_sd
|
|
||||||
|
|
||||||
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
|
|
||||||
depth = y_emb_shape[0] // 64
|
|
||||||
hidden_size = 64 * depth
|
|
||||||
num_heads = depth
|
|
||||||
head_dim = hidden_size // num_heads
|
|
||||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
|
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
|
||||||
offload_device = comfy.model_management.unet_offload_device()
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
|
|
||||||
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
|
||||||
|
|
||||||
operations = model_options.get("custom_operations", None)
|
|
||||||
if operations is None:
|
|
||||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
|
||||||
|
|
||||||
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
|
|
||||||
patch_size=2,
|
|
||||||
in_chans=16,
|
|
||||||
num_layers=num_blocks,
|
|
||||||
main_model_double=depth,
|
|
||||||
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
|
|
||||||
attention_head_dim=head_dim,
|
|
||||||
num_attention_heads=num_heads,
|
|
||||||
adm_in_channels=2048,
|
|
||||||
device=offload_device,
|
|
||||||
dtype=unet_dtype,
|
|
||||||
operations=operations)
|
|
||||||
|
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SD3()
|
|
||||||
preprocess_image = lambda a: a
|
|
||||||
if canny_cnet:
|
|
||||||
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
|
|
||||||
elif depth_cnet:
|
|
||||||
preprocess_image = lambda a: 1.0 - a
|
|
||||||
|
|
||||||
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
|
|
||||||
return control
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
|
||||||
|
|
||||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
|
||||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SDXL()
|
|
||||||
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
|
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
|
||||||
return control
|
|
||||||
|
|
||||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
|
||||||
extra_conds = ['y', 'guidance']
|
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
|
||||||
return control
|
|
||||||
|
|
||||||
def load_controlnet_flux_instantx(sd, model_options={}):
|
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
|
||||||
for k in sd:
|
|
||||||
new_sd[k] = sd[k]
|
|
||||||
|
|
||||||
num_union_modes = 0
|
|
||||||
union_cnet = "controlnet_mode_embedder.weight"
|
|
||||||
if union_cnet in new_sd:
|
|
||||||
num_union_modes = new_sd[union_cnet].shape[0]
|
|
||||||
|
|
||||||
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
|
||||||
concat_mask = False
|
|
||||||
if control_latent_channels == 17:
|
|
||||||
concat_mask = True
|
|
||||||
|
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.Flux()
|
|
||||||
extra_conds = ['y', 'guidance']
|
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
|
||||||
return control
|
|
||||||
|
|
||||||
def load_controlnet_qwen_instantx(sd, model_options={}):
|
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
|
||||||
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
|
|
||||||
|
|
||||||
extra_condition_channels = 0
|
|
||||||
concat_mask = False
|
|
||||||
if control_latent_channels == 68: #inpaint controlnet
|
|
||||||
extra_condition_channels = control_latent_channels - 64
|
|
||||||
concat_mask = True
|
|
||||||
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
|
||||||
latent_format = comfy.latent_formats.Wan21()
|
|
||||||
extra_conds = []
|
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
|
||||||
return control
|
|
||||||
|
|
||||||
def convert_mistoline(sd):
|
|
||||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
|
||||||
controlnet_data = state_dict
|
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
|
||||||
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
|
||||||
|
|
||||||
if "lora_controlnet" in controlnet_data:
|
|
||||||
return ControlLora(controlnet_data, model_options=model_options)
|
|
||||||
|
|
||||||
controlnet_config = None
|
|
||||||
supported_inference_dtypes = None
|
|
||||||
|
|
||||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
|
||||||
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
|
|
||||||
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
|
||||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
|
||||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
loop = True
|
|
||||||
while loop:
|
|
||||||
suffix = [".weight", ".bias"]
|
|
||||||
for s in suffix:
|
|
||||||
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
|
||||||
k_out = "zero_convs.{}.0{}".format(count, s)
|
|
||||||
if k_in not in controlnet_data:
|
|
||||||
loop = False
|
|
||||||
break
|
|
||||||
diffusers_keys[k_in] = k_out
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
loop = True
|
|
||||||
while loop:
|
|
||||||
suffix = [".weight", ".bias"]
|
|
||||||
for s in suffix:
|
|
||||||
if count == 0:
|
|
||||||
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
|
||||||
else:
|
|
||||||
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
|
||||||
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
|
||||||
if k_in not in controlnet_data:
|
|
||||||
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
|
||||||
loop = False
|
|
||||||
diffusers_keys[k_in] = k_out
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
new_sd = {}
|
|
||||||
for k in diffusers_keys:
|
|
||||||
if k in controlnet_data:
|
|
||||||
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
|
||||||
|
|
||||||
if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
|
|
||||||
controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
|
|
||||||
for k in list(controlnet_data.keys()):
|
|
||||||
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
|
|
||||||
new_sd[new_k] = controlnet_data.pop(k)
|
|
||||||
|
|
||||||
leftover_keys = controlnet_data.keys()
|
|
||||||
if len(leftover_keys) > 0:
|
|
||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
|
||||||
controlnet_data = new_sd
|
|
||||||
elif "controlnet_blocks.0.weight" in controlnet_data:
|
|
||||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
|
||||||
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
|
||||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
|
||||||
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
|
|
||||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
|
||||||
else:
|
|
||||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
|
||||||
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
|
|
||||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
|
||||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
|
||||||
|
|
||||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
|
||||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
|
||||||
pth = False
|
|
||||||
key = 'zero_convs.0.0.weight'
|
|
||||||
if pth_key in controlnet_data:
|
|
||||||
pth = True
|
|
||||||
key = pth_key
|
|
||||||
prefix = "control_model."
|
|
||||||
elif key in controlnet_data:
|
|
||||||
prefix = ""
|
|
||||||
else:
|
|
||||||
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
|
||||||
if net is None:
|
|
||||||
logging.error("error could not detect control model type.")
|
|
||||||
return net
|
|
||||||
|
|
||||||
if controlnet_config is None:
|
|
||||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
|
||||||
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
|
||||||
controlnet_config = model_config.unet_config
|
|
||||||
|
|
||||||
unet_dtype = model_options.get("dtype", None)
|
|
||||||
if unet_dtype is None:
|
|
||||||
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
|
||||||
|
|
||||||
if supported_inference_dtypes is None:
|
|
||||||
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
|
||||||
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
|
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
|
||||||
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
|
||||||
operations = model_options.get("custom_operations", None)
|
|
||||||
if operations is None:
|
|
||||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
|
||||||
|
|
||||||
controlnet_config["operations"] = operations
|
|
||||||
controlnet_config["dtype"] = unet_dtype
|
|
||||||
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
|
||||||
controlnet_config.pop("out_channels")
|
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
|
||||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
|
||||||
|
|
||||||
if pth:
|
|
||||||
if 'difference' in controlnet_data:
|
|
||||||
if model is not None:
|
|
||||||
comfy.model_management.load_models_gpu([model])
|
|
||||||
model_sd = model.model_state_dict()
|
|
||||||
for x in controlnet_data:
|
|
||||||
c_m = "control_model."
|
|
||||||
if x.startswith(c_m):
|
|
||||||
sd_key = "diffusion_model.{}".format(x[len(c_m):])
|
|
||||||
if sd_key in model_sd:
|
|
||||||
cd = controlnet_data[x]
|
|
||||||
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
|
||||||
else:
|
|
||||||
logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
|
||||||
|
|
||||||
class WeightsLoader(torch.nn.Module):
|
|
||||||
pass
|
|
||||||
w = WeightsLoader()
|
|
||||||
w.control_model = control_model
|
|
||||||
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
|
||||||
else:
|
|
||||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
|
||||||
|
|
||||||
if len(missing) > 0:
|
|
||||||
logging.warning("missing controlnet keys: {}".format(missing))
|
|
||||||
|
|
||||||
if len(unexpected) > 0:
|
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
|
||||||
|
|
||||||
global_average_pooling = model_options.get("global_average_pooling", False)
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
|
||||||
return control
|
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None, model_options={}):
|
|
||||||
model_options = model_options.copy()
|
|
||||||
if "global_average_pooling" not in model_options:
|
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
|
||||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
|
||||||
model_options["global_average_pooling"] = True
|
|
||||||
|
|
||||||
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
|
||||||
if cnet is None:
|
|
||||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
|
||||||
return cnet
|
|
||||||
|
|
||||||
class T2IAdapter(ControlBase):
|
|
||||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
|
||||||
super().__init__()
|
|
||||||
self.t2i_model = t2i_model
|
|
||||||
self.channels_in = channels_in
|
|
||||||
self.control_input = None
|
|
||||||
self.compression_ratio = compression_ratio
|
|
||||||
self.upscale_algorithm = upscale_algorithm
|
|
||||||
if device is None:
|
|
||||||
device = comfy.model_management.get_torch_device()
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def scale_image_to(self, width, height):
|
|
||||||
unshuffle_amount = self.t2i_model.unshuffle_amount
|
|
||||||
width = math.ceil(width / unshuffle_amount) * unshuffle_amount
|
|
||||||
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
|
||||||
return width, height
|
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
|
||||||
control_prev = None
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
|
|
||||||
|
|
||||||
if self.timestep_range is not None:
|
|
||||||
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
|
||||||
if control_prev is not None:
|
|
||||||
return control_prev
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
|
||||||
if self.cond_hint is not None:
|
|
||||||
del self.cond_hint
|
|
||||||
self.control_input = None
|
|
||||||
self.cond_hint = None
|
|
||||||
width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
|
|
||||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
|
|
||||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
|
||||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
|
||||||
if self.control_input is None:
|
|
||||||
self.t2i_model.to(x_noisy.dtype)
|
|
||||||
self.t2i_model.to(self.device)
|
|
||||||
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
|
||||||
self.t2i_model.cpu()
|
|
||||||
|
|
||||||
control_input = {}
|
|
||||||
for k in self.control_input:
|
|
||||||
control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))
|
|
||||||
|
|
||||||
return self.control_merge(control_input, control_prev, x_noisy.dtype)
|
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
|
|
||||||
self.copy_to(c)
|
|
||||||
return c
|
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
|
||||||
compression_ratio = 8
|
|
||||||
upscale_algorithm = 'nearest-exact'
|
|
||||||
|
|
||||||
if 'adapter' in t2i_data:
|
|
||||||
t2i_data = t2i_data['adapter']
|
|
||||||
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
|
|
||||||
prefix_replace = {}
|
|
||||||
for i in range(4):
|
|
||||||
for j in range(2):
|
|
||||||
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
|
|
||||||
prefix_replace["adapter.body.{}.".format(i, )] = "body.{}.".format(i * 2)
|
|
||||||
prefix_replace["adapter."] = ""
|
|
||||||
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
|
|
||||||
keys = t2i_data.keys()
|
|
||||||
|
|
||||||
if "body.0.in_conv.weight" in keys:
|
|
||||||
cin = t2i_data['body.0.in_conv.weight'].shape[1]
|
|
||||||
model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
|
|
||||||
elif 'conv_in.weight' in keys:
|
|
||||||
cin = t2i_data['conv_in.weight'].shape[1]
|
|
||||||
channel = t2i_data['conv_in.weight'].shape[0]
|
|
||||||
ksize = t2i_data['body.0.block2.weight'].shape[2]
|
|
||||||
use_conv = False
|
|
||||||
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
|
|
||||||
if len(down_opts) > 0:
|
|
||||||
use_conv = True
|
|
||||||
xl = False
|
|
||||||
if cin == 256 or cin == 768:
|
|
||||||
xl = True
|
|
||||||
model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
|
|
||||||
elif "backbone.0.0.weight" in keys:
|
|
||||||
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
|
||||||
compression_ratio = 32
|
|
||||||
upscale_algorithm = 'bilinear'
|
|
||||||
elif "backbone.10.blocks.0.weight" in keys:
|
|
||||||
model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
|
||||||
compression_ratio = 1
|
|
||||||
upscale_algorithm = 'nearest-exact'
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
|
||||||
if len(missing) > 0:
|
|
||||||
logging.warning("t2i missing {}".format(missing))
|
|
||||||
|
|
||||||
if len(unexpected) > 0:
|
|
||||||
logging.debug("t2i unexpected {}".format(unexpected))
|
|
||||||
|
|
||||||
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
|
|
||||||
@ -1,9 +1,116 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
from comfy.ldm.util import instantiate_from_config
|
||||||
|
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
|
||||||
|
import os.path as osp
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
import logging
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||||
|
|
||||||
|
# =================#
|
||||||
|
# UNet Conversion #
|
||||||
|
# =================#
|
||||||
|
|
||||||
|
unet_conversion_map = [
|
||||||
|
# (stable-diffusion, HF Diffusers)
|
||||||
|
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||||
|
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||||
|
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||||
|
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
||||||
|
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||||
|
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||||
|
("out.0.weight", "conv_norm_out.weight"),
|
||||||
|
("out.0.bias", "conv_norm_out.bias"),
|
||||||
|
("out.2.weight", "conv_out.weight"),
|
||||||
|
("out.2.bias", "conv_out.bias"),
|
||||||
|
]
|
||||||
|
|
||||||
|
unet_conversion_map_resnet = [
|
||||||
|
# (stable-diffusion, HF Diffusers)
|
||||||
|
("in_layers.0", "norm1"),
|
||||||
|
("in_layers.2", "conv1"),
|
||||||
|
("out_layers.0", "norm2"),
|
||||||
|
("out_layers.3", "conv2"),
|
||||||
|
("emb_layers.1", "time_emb_proj"),
|
||||||
|
("skip_connection", "conv_shortcut"),
|
||||||
|
]
|
||||||
|
|
||||||
|
unet_conversion_map_layer = []
|
||||||
|
# hardcoded number of downblocks and resnets/attentions...
|
||||||
|
# would need smarter logic for other networks.
|
||||||
|
for i in range(4):
|
||||||
|
# loop over downblocks/upblocks
|
||||||
|
|
||||||
|
for j in range(2):
|
||||||
|
# loop over resnets/attentions for downblocks
|
||||||
|
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||||
|
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||||
|
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
# no attention layers in down_blocks.3
|
||||||
|
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||||
|
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||||
|
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||||
|
|
||||||
|
for j in range(3):
|
||||||
|
# loop over resnets/attentions for upblocks
|
||||||
|
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||||
|
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||||
|
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||||
|
|
||||||
|
if i > 0:
|
||||||
|
# no attention layers in up_blocks.0
|
||||||
|
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||||
|
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
||||||
|
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
# no downsample in down_blocks.3
|
||||||
|
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||||
|
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||||
|
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||||
|
|
||||||
|
# no upsample in up_blocks.3
|
||||||
|
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||||
|
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||||
|
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||||
|
|
||||||
|
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||||
|
sd_mid_atn_prefix = "middle_block.1."
|
||||||
|
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||||
|
|
||||||
|
for j in range(2):
|
||||||
|
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||||
|
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||||
|
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||||
|
|
||||||
|
|
||||||
|
def convert_unet_state_dict(unet_state_dict):
|
||||||
|
# buyer beware: this is a *brittle* function,
|
||||||
|
# and correct output requires that all of these pieces interact in
|
||||||
|
# the exact order in which I have arranged them.
|
||||||
|
mapping = {k: k for k in unet_state_dict.keys()}
|
||||||
|
for sd_name, hf_name in unet_conversion_map:
|
||||||
|
mapping[hf_name] = sd_name
|
||||||
|
for k, v in mapping.items():
|
||||||
|
if "resnets" in k:
|
||||||
|
for sd_part, hf_part in unet_conversion_map_resnet:
|
||||||
|
v = v.replace(hf_part, sd_part)
|
||||||
|
mapping[k] = v
|
||||||
|
for k, v in mapping.items():
|
||||||
|
for sd_part, hf_part in unet_conversion_map_layer:
|
||||||
|
v = v.replace(hf_part, sd_part)
|
||||||
|
mapping[k] = v
|
||||||
|
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
||||||
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
# ================#
|
# ================#
|
||||||
# VAE Conversion #
|
# VAE Conversion #
|
||||||
# ================#
|
# ================#
|
||||||
@ -50,31 +157,20 @@ vae_conversion_map_attn = [
|
|||||||
("q.", "query."),
|
("q.", "query."),
|
||||||
("k.", "key."),
|
("k.", "key."),
|
||||||
("v.", "value."),
|
("v.", "value."),
|
||||||
("q.", "to_q."),
|
|
||||||
("k.", "to_k."),
|
|
||||||
("v.", "to_v."),
|
|
||||||
("proj_out.", "to_out.0."),
|
|
||||||
("proj_out.", "proj_attn."),
|
("proj_out.", "proj_attn."),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def reshape_weight_for_sd(w, conv3d=False):
|
def reshape_weight_for_sd(w):
|
||||||
# convert HF linear weights to SD conv2d weights
|
# convert HF linear weights to SD conv2d weights
|
||||||
if conv3d:
|
return w.reshape(*w.shape, 1, 1)
|
||||||
return w.reshape(*w.shape, 1, 1, 1)
|
|
||||||
else:
|
|
||||||
return w.reshape(*w.shape, 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_vae_state_dict(vae_state_dict):
|
def convert_vae_state_dict(vae_state_dict):
|
||||||
mapping = {k: k for k in vae_state_dict.keys()}
|
mapping = {k: k for k in vae_state_dict.keys()}
|
||||||
conv3d = False
|
|
||||||
for k, v in mapping.items():
|
for k, v in mapping.items():
|
||||||
for sd_part, hf_part in vae_conversion_map:
|
for sd_part, hf_part in vae_conversion_map:
|
||||||
v = v.replace(hf_part, sd_part)
|
v = v.replace(hf_part, sd_part)
|
||||||
if v.endswith(".conv.weight"):
|
|
||||||
if not conv3d and vae_state_dict[k].ndim == 5:
|
|
||||||
conv3d = True
|
|
||||||
mapping[k] = v
|
mapping[k] = v
|
||||||
for k, v in mapping.items():
|
for k, v in mapping.items():
|
||||||
if "attentions" in k:
|
if "attentions" in k:
|
||||||
@ -86,8 +182,8 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
for k, v in new_state_dict.items():
|
for k, v in new_state_dict.items():
|
||||||
for weight_name in weights_to_convert:
|
for weight_name in weights_to_convert:
|
||||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||||
logging.debug(f"Reshaping {k} for SD format")
|
print(f"Reshaping {k} for SD format")
|
||||||
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
|
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
@ -115,30 +211,11 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|||||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||||
|
|
||||||
|
|
||||||
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
def convert_text_enc_state_dict_v20(text_enc_dict):
|
||||||
def cat_tensors(tensors):
|
|
||||||
x = 0
|
|
||||||
for t in tensors:
|
|
||||||
x += t.shape[0]
|
|
||||||
|
|
||||||
shape = [x] + list(tensors[0].shape)[1:]
|
|
||||||
out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
|
|
||||||
|
|
||||||
x = 0
|
|
||||||
for t in tensors:
|
|
||||||
out[x:x + t.shape[0]] = t
|
|
||||||
x += t.shape[0]
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
capture_qkv_weight = {}
|
capture_qkv_weight = {}
|
||||||
capture_qkv_bias = {}
|
capture_qkv_bias = {}
|
||||||
for k, v in text_enc_dict.items():
|
for k, v in text_enc_dict.items():
|
||||||
if not k.startswith(prefix):
|
|
||||||
continue
|
|
||||||
if (
|
if (
|
||||||
k.endswith(".self_attn.q_proj.weight")
|
k.endswith(".self_attn.q_proj.weight")
|
||||||
or k.endswith(".self_attn.k_proj.weight")
|
or k.endswith(".self_attn.k_proj.weight")
|
||||||
@ -163,27 +240,123 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
|||||||
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
||||||
continue
|
continue
|
||||||
|
|
||||||
text_proj = "transformer.text_projection.weight"
|
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
||||||
if k.endswith(text_proj):
|
new_state_dict[relabelled_key] = v
|
||||||
new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
|
|
||||||
else:
|
|
||||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
|
||||||
new_state_dict[relabelled_key] = v
|
|
||||||
|
|
||||||
for k_pre, tensors in capture_qkv_weight.items():
|
for k_pre, tensors in capture_qkv_weight.items():
|
||||||
if None in tensors:
|
if None in tensors:
|
||||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||||
new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
|
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
||||||
|
|
||||||
for k_pre, tensors in capture_qkv_bias.items():
|
for k_pre, tensors in capture_qkv_bias.items():
|
||||||
if None in tensors:
|
if None in tensors:
|
||||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||||
new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
|
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
||||||
|
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
def convert_text_enc_state_dict(text_enc_dict):
|
def convert_text_enc_state_dict(text_enc_dict):
|
||||||
return text_enc_dict
|
return text_enc_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
|
||||||
|
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
|
||||||
|
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
|
||||||
|
|
||||||
|
# magic
|
||||||
|
v2 = diffusers_unet_conf["sample_size"] == 96
|
||||||
|
if 'prediction_type' in diffusers_scheduler_conf:
|
||||||
|
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
|
||||||
|
|
||||||
|
if v2:
|
||||||
|
if v_pred:
|
||||||
|
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
|
||||||
|
else:
|
||||||
|
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
|
||||||
|
else:
|
||||||
|
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
|
||||||
|
|
||||||
|
with open(config_path, 'r') as stream:
|
||||||
|
config = yaml.safe_load(stream)
|
||||||
|
|
||||||
|
model_config_params = config['model']['params']
|
||||||
|
clip_config = model_config_params['cond_stage_config']
|
||||||
|
scale_factor = model_config_params['scale_factor']
|
||||||
|
vae_config = model_config_params['first_stage_config']
|
||||||
|
vae_config['scale_factor'] = scale_factor
|
||||||
|
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
|
||||||
|
|
||||||
|
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||||
|
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||||
|
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
|
||||||
|
|
||||||
|
# Load models from safetensors if it exists, if it doesn't pytorch
|
||||||
|
if osp.exists(unet_path):
|
||||||
|
unet_state_dict = load_file(unet_path, device="cpu")
|
||||||
|
else:
|
||||||
|
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
||||||
|
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||||
|
|
||||||
|
if osp.exists(vae_path):
|
||||||
|
vae_state_dict = load_file(vae_path, device="cpu")
|
||||||
|
else:
|
||||||
|
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
||||||
|
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||||
|
|
||||||
|
if osp.exists(text_enc_path):
|
||||||
|
text_enc_dict = load_file(text_enc_path, device="cpu")
|
||||||
|
else:
|
||||||
|
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
||||||
|
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||||
|
|
||||||
|
# Convert the UNet model
|
||||||
|
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
||||||
|
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||||
|
|
||||||
|
# Convert the VAE model
|
||||||
|
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||||
|
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||||
|
|
||||||
|
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||||
|
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||||
|
|
||||||
|
if is_v20_model:
|
||||||
|
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
|
||||||
|
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
||||||
|
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
|
||||||
|
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
||||||
|
else:
|
||||||
|
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
||||||
|
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||||
|
|
||||||
|
# Put together new checkpoint
|
||||||
|
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||||
|
|
||||||
|
clip = None
|
||||||
|
vae = None
|
||||||
|
|
||||||
|
class WeightsLoader(torch.nn.Module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
w = WeightsLoader()
|
||||||
|
load_state_dict_to = []
|
||||||
|
if output_vae:
|
||||||
|
vae = VAE(scale_factor=scale_factor, config=vae_config)
|
||||||
|
w.first_stage_model = vae.first_stage_model
|
||||||
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
|
if output_clip:
|
||||||
|
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
||||||
|
w.cond_stage_model = clip.cond_stage_model
|
||||||
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
|
model = instantiate_from_config(config["model"])
|
||||||
|
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||||
|
|
||||||
|
if fp16:
|
||||||
|
model = model.half()
|
||||||
|
|
||||||
|
return ModelPatcher(model), clip, vae
|
||||||
|
|||||||
@ -1,36 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import comfy.sd
|
|
||||||
|
|
||||||
def first_file(path, filenames):
|
|
||||||
for f in filenames:
|
|
||||||
p = os.path.join(path, f)
|
|
||||||
if os.path.exists(p):
|
|
||||||
return p
|
|
||||||
return None
|
|
||||||
|
|
||||||
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
|
|
||||||
diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
|
|
||||||
unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
|
|
||||||
vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
|
|
||||||
|
|
||||||
text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
|
|
||||||
text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
|
|
||||||
text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
|
|
||||||
|
|
||||||
text_encoder_paths = [text_encoder1_path]
|
|
||||||
if text_encoder2_path is not None:
|
|
||||||
text_encoder_paths.append(text_encoder2_path)
|
|
||||||
|
|
||||||
unet = comfy.sd.load_diffusion_model(unet_path)
|
|
||||||
|
|
||||||
clip = None
|
|
||||||
if output_clip:
|
|
||||||
clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
|
|
||||||
|
|
||||||
vae = None
|
|
||||||
if output_vae:
|
|
||||||
sd = comfy.utils.load_torch_file(vae_path)
|
|
||||||
vae = comfy.sd.VAE(sd=sd)
|
|
||||||
|
|
||||||
return (unet, clip, vae)
|
|
||||||
@ -1,10 +1,10 @@
|
|||||||
#code taken from: https://github.com/wl-zhao/UniPC and modified
|
#code taken from: https://github.com/wl-zhao/UniPC and modified
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import math
|
import math
|
||||||
import logging
|
|
||||||
|
|
||||||
from tqdm.auto import trange
|
from tqdm.auto import trange, tqdm
|
||||||
|
|
||||||
|
|
||||||
class NoiseScheduleVP:
|
class NoiseScheduleVP:
|
||||||
@ -16,7 +16,7 @@ class NoiseScheduleVP:
|
|||||||
continuous_beta_0=0.1,
|
continuous_beta_0=0.1,
|
||||||
continuous_beta_1=20.,
|
continuous_beta_1=20.,
|
||||||
):
|
):
|
||||||
r"""Create a wrapper class for the forward SDE (VP type).
|
"""Create a wrapper class for the forward SDE (VP type).
|
||||||
|
|
||||||
***
|
***
|
||||||
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
||||||
@ -80,7 +80,7 @@ class NoiseScheduleVP:
|
|||||||
'linear' or 'cosine' for continuous-time DPMs.
|
'linear' or 'cosine' for continuous-time DPMs.
|
||||||
Returns:
|
Returns:
|
||||||
A wrapper object of the forward SDE (VP type).
|
A wrapper object of the forward SDE (VP type).
|
||||||
|
|
||||||
===============================================================
|
===============================================================
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -180,6 +180,7 @@ class NoiseScheduleVP:
|
|||||||
|
|
||||||
def model_wrapper(
|
def model_wrapper(
|
||||||
model,
|
model,
|
||||||
|
sampling_function,
|
||||||
noise_schedule,
|
noise_schedule,
|
||||||
model_type="noise",
|
model_type="noise",
|
||||||
model_kwargs={},
|
model_kwargs={},
|
||||||
@ -208,7 +209,7 @@ def model_wrapper(
|
|||||||
arXiv preprint arXiv:2202.00512 (2022).
|
arXiv preprint arXiv:2202.00512 (2022).
|
||||||
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
||||||
arXiv preprint arXiv:2210.02303 (2022).
|
arXiv preprint arXiv:2210.02303 (2022).
|
||||||
|
|
||||||
4. "score": marginal score function. (Trained by denoising score matching).
|
4. "score": marginal score function. (Trained by denoising score matching).
|
||||||
Note that the score function and the noise prediction model follows a simple relationship:
|
Note that the score function and the noise prediction model follows a simple relationship:
|
||||||
```
|
```
|
||||||
@ -226,7 +227,7 @@ def model_wrapper(
|
|||||||
The input `model` has the following format:
|
The input `model` has the following format:
|
||||||
``
|
``
|
||||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||||
``
|
``
|
||||||
|
|
||||||
The input `classifier_fn` has the following format:
|
The input `classifier_fn` has the following format:
|
||||||
``
|
``
|
||||||
@ -240,12 +241,12 @@ def model_wrapper(
|
|||||||
The input `model` has the following format:
|
The input `model` has the following format:
|
||||||
``
|
``
|
||||||
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
||||||
``
|
``
|
||||||
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
||||||
|
|
||||||
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
||||||
arXiv preprint arXiv:2207.12598 (2022).
|
arXiv preprint arXiv:2207.12598 (2022).
|
||||||
|
|
||||||
|
|
||||||
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
||||||
or continuous-time labels (i.e. epsilon to T).
|
or continuous-time labels (i.e. epsilon to T).
|
||||||
@ -254,7 +255,7 @@ def model_wrapper(
|
|||||||
``
|
``
|
||||||
def model_fn(x, t_continuous) -> noise:
|
def model_fn(x, t_continuous) -> noise:
|
||||||
t_input = get_model_input_time(t_continuous)
|
t_input = get_model_input_time(t_continuous)
|
||||||
return noise_pred(model, x, t_input, **model_kwargs)
|
return noise_pred(model, x, t_input, **model_kwargs)
|
||||||
``
|
``
|
||||||
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
||||||
|
|
||||||
@ -294,7 +295,7 @@ def model_wrapper(
|
|||||||
if t_continuous.reshape((-1,)).shape[0] == 1:
|
if t_continuous.reshape((-1,)).shape[0] == 1:
|
||||||
t_continuous = t_continuous.expand((x.shape[0]))
|
t_continuous = t_continuous.expand((x.shape[0]))
|
||||||
t_input = get_model_input_time(t_continuous)
|
t_input = get_model_input_time(t_continuous)
|
||||||
output = model(x, t_input, **model_kwargs)
|
output = sampling_function(model, x, t_input, **model_kwargs)
|
||||||
if model_type == "noise":
|
if model_type == "noise":
|
||||||
return output
|
return output
|
||||||
elif model_type == "x_start":
|
elif model_type == "x_start":
|
||||||
@ -358,8 +359,11 @@ class UniPC:
|
|||||||
thresholding=False,
|
thresholding=False,
|
||||||
max_val=1.,
|
max_val=1.,
|
||||||
variant='bh1',
|
variant='bh1',
|
||||||
|
noise_mask=None,
|
||||||
|
masked_image=None,
|
||||||
|
noise=None,
|
||||||
):
|
):
|
||||||
"""Construct a UniPC.
|
"""Construct a UniPC.
|
||||||
|
|
||||||
We support both data_prediction and noise_prediction.
|
We support both data_prediction and noise_prediction.
|
||||||
"""
|
"""
|
||||||
@ -369,10 +373,13 @@ class UniPC:
|
|||||||
self.predict_x0 = predict_x0
|
self.predict_x0 = predict_x0
|
||||||
self.thresholding = thresholding
|
self.thresholding = thresholding
|
||||||
self.max_val = max_val
|
self.max_val = max_val
|
||||||
|
self.noise_mask = noise_mask
|
||||||
|
self.masked_image = masked_image
|
||||||
|
self.noise = noise
|
||||||
|
|
||||||
def dynamic_thresholding_fn(self, x0, t=None):
|
def dynamic_thresholding_fn(self, x0, t=None):
|
||||||
"""
|
"""
|
||||||
The dynamic thresholding method.
|
The dynamic thresholding method.
|
||||||
"""
|
"""
|
||||||
dims = x0.dim()
|
dims = x0.dim()
|
||||||
p = self.dynamic_thresholding_ratio
|
p = self.dynamic_thresholding_ratio
|
||||||
@ -385,7 +392,10 @@ class UniPC:
|
|||||||
"""
|
"""
|
||||||
Return the noise prediction model.
|
Return the noise prediction model.
|
||||||
"""
|
"""
|
||||||
return self.model(x, t)
|
if self.noise_mask is not None:
|
||||||
|
return self.model(x, t) * self.noise_mask
|
||||||
|
else:
|
||||||
|
return self.model(x, t)
|
||||||
|
|
||||||
def data_prediction_fn(self, x, t):
|
def data_prediction_fn(self, x, t):
|
||||||
"""
|
"""
|
||||||
@ -400,11 +410,13 @@ class UniPC:
|
|||||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||||
x0 = torch.clamp(x0, -s, s) / s
|
x0 = torch.clamp(x0, -s, s) / s
|
||||||
|
if self.noise_mask is not None:
|
||||||
|
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
|
||||||
return x0
|
return x0
|
||||||
|
|
||||||
def model_fn(self, x, t):
|
def model_fn(self, x, t):
|
||||||
"""
|
"""
|
||||||
Convert the model to the noise prediction model or the data prediction model.
|
Convert the model to the noise prediction model or the data prediction model.
|
||||||
"""
|
"""
|
||||||
if self.predict_x0:
|
if self.predict_x0:
|
||||||
return self.data_prediction_fn(x, t)
|
return self.data_prediction_fn(x, t)
|
||||||
@ -461,7 +473,7 @@ class UniPC:
|
|||||||
|
|
||||||
def denoise_to_zero_fn(self, x, s):
|
def denoise_to_zero_fn(self, x, s):
|
||||||
"""
|
"""
|
||||||
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
||||||
"""
|
"""
|
||||||
return self.data_prediction_fn(x, s)
|
return self.data_prediction_fn(x, s)
|
||||||
|
|
||||||
@ -475,7 +487,7 @@ class UniPC:
|
|||||||
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||||
|
|
||||||
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
||||||
logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||||
ns = self.noise_schedule
|
ns = self.noise_schedule
|
||||||
assert order <= len(model_prev_list)
|
assert order <= len(model_prev_list)
|
||||||
|
|
||||||
@ -510,7 +522,7 @@ class UniPC:
|
|||||||
col = torch.ones_like(rks)
|
col = torch.ones_like(rks)
|
||||||
for k in range(1, K + 1):
|
for k in range(1, K + 1):
|
||||||
C.append(col)
|
C.append(col)
|
||||||
col = col * rks / (k + 1)
|
col = col * rks / (k + 1)
|
||||||
C = torch.stack(C, dim=1)
|
C = torch.stack(C, dim=1)
|
||||||
|
|
||||||
if len(D1s) > 0:
|
if len(D1s) > 0:
|
||||||
@ -519,6 +531,7 @@ class UniPC:
|
|||||||
A_p = C_inv_p
|
A_p = C_inv_p
|
||||||
|
|
||||||
if use_corrector:
|
if use_corrector:
|
||||||
|
print('using corrector')
|
||||||
C_inv = torch.linalg.inv(C)
|
C_inv = torch.linalg.inv(C)
|
||||||
A_c = C_inv
|
A_c = C_inv
|
||||||
|
|
||||||
@ -621,12 +634,12 @@ class UniPC:
|
|||||||
B_h = torch.expm1(hh)
|
B_h = torch.expm1(hh)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
for i in range(1, order + 1):
|
for i in range(1, order + 1):
|
||||||
R.append(torch.pow(rks, i - 1))
|
R.append(torch.pow(rks, i - 1))
|
||||||
b.append(h_phi_k * factorial_i / B_h)
|
b.append(h_phi_k * factorial_i / B_h)
|
||||||
factorial_i *= (i + 1)
|
factorial_i *= (i + 1)
|
||||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||||
|
|
||||||
R = torch.stack(R)
|
R = torch.stack(R)
|
||||||
b = torch.tensor(b, device=x.device)
|
b = torch.tensor(b, device=x.device)
|
||||||
@ -661,7 +674,7 @@ class UniPC:
|
|||||||
|
|
||||||
if x_t is None:
|
if x_t is None:
|
||||||
if use_predictor:
|
if use_predictor:
|
||||||
pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||||
else:
|
else:
|
||||||
pred_res = 0
|
pred_res = 0
|
||||||
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
|
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
|
||||||
@ -669,14 +682,14 @@ class UniPC:
|
|||||||
if use_corrector:
|
if use_corrector:
|
||||||
model_t = self.model_fn(x_t, t)
|
model_t = self.model_fn(x_t, t)
|
||||||
if D1s is not None:
|
if D1s is not None:
|
||||||
corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||||
else:
|
else:
|
||||||
corr_res = 0
|
corr_res = 0
|
||||||
D1_t = (model_t - model_prev_0)
|
D1_t = (model_t - model_prev_0)
|
||||||
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
||||||
else:
|
else:
|
||||||
x_t_ = (
|
x_t_ = (
|
||||||
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x
|
||||||
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
||||||
)
|
)
|
||||||
if x_t is None:
|
if x_t is None:
|
||||||
@ -701,8 +714,9 @@ class UniPC:
|
|||||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||||
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
|
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
|
||||||
):
|
):
|
||||||
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||||
# t_T = self.noise_schedule.T if t_start is None else t_start
|
t_T = self.noise_schedule.T if t_start is None else t_start
|
||||||
|
device = x.device
|
||||||
steps = len(timesteps) - 1
|
steps = len(timesteps) - 1
|
||||||
if method == 'multistep':
|
if method == 'multistep':
|
||||||
assert steps >= order
|
assert steps >= order
|
||||||
@ -710,6 +724,8 @@ class UniPC:
|
|||||||
assert timesteps.shape[0] - 1 == steps
|
assert timesteps.shape[0] - 1 == steps
|
||||||
# with torch.no_grad():
|
# with torch.no_grad():
|
||||||
for step_index in trange(steps, disable=disable_pbar):
|
for step_index in trange(steps, disable=disable_pbar):
|
||||||
|
if self.noise_mask is not None:
|
||||||
|
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
|
||||||
if step_index == 0:
|
if step_index == 0:
|
||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
model_prev_list = [self.model_fn(x, vec_t)]
|
model_prev_list = [self.model_fn(x, vec_t)]
|
||||||
@ -751,11 +767,11 @@ class UniPC:
|
|||||||
model_x = self.model_fn(x, vec_t)
|
model_x = self.model_fn(x, vec_t)
|
||||||
model_prev_list[-1] = model_x
|
model_prev_list[-1] = model_x
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
|
callback(step_index, model_prev_list[-1], x, steps)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
# if denoise_to_zero:
|
if denoise_to_zero:
|
||||||
# x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -818,56 +834,52 @@ def expand_dims(v, dims):
|
|||||||
return v[(...,) + (None,)*(dims - 1)]
|
return v[(...,) + (None,)*(dims - 1)]
|
||||||
|
|
||||||
|
|
||||||
class SigmaConvert:
|
|
||||||
schedule = ""
|
|
||||||
def marginal_log_mean_coeff(self, sigma):
|
|
||||||
return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
|
|
||||||
|
|
||||||
def marginal_alpha(self, t):
|
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||||
return torch.exp(self.marginal_log_mean_coeff(t))
|
to_zero = False
|
||||||
|
|
||||||
def marginal_std(self, t):
|
|
||||||
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
|
||||||
|
|
||||||
def marginal_lambda(self, t):
|
|
||||||
"""
|
|
||||||
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
|
||||||
"""
|
|
||||||
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
|
||||||
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
|
||||||
return log_mean_coeff - log_std
|
|
||||||
|
|
||||||
def predict_eps_sigma(model, input, sigma_in, **kwargs):
|
|
||||||
sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
|
|
||||||
input = input * ((sigma ** 2 + 1.0) ** 0.5)
|
|
||||||
return (input - model(input, sigma_in, **kwargs)) / sigma
|
|
||||||
|
|
||||||
|
|
||||||
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
|
|
||||||
timesteps = sigmas.clone()
|
|
||||||
if sigmas[-1] == 0:
|
if sigmas[-1] == 0:
|
||||||
timesteps = sigmas[:]
|
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
||||||
timesteps[-1] = 0.001
|
to_zero = True
|
||||||
else:
|
else:
|
||||||
timesteps = sigmas.clone()
|
timesteps = sigmas.clone()
|
||||||
ns = SigmaConvert()
|
|
||||||
|
|
||||||
noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
|
for s in range(timesteps.shape[0]):
|
||||||
model_type = "noise"
|
timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
|
||||||
|
|
||||||
|
ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
|
||||||
|
|
||||||
|
if image is not None:
|
||||||
|
img = image * ns.marginal_alpha(timesteps[0])
|
||||||
|
if max_denoise:
|
||||||
|
noise_mult = 1.0
|
||||||
|
else:
|
||||||
|
noise_mult = ns.marginal_std(timesteps[0])
|
||||||
|
img += noise * noise_mult
|
||||||
|
else:
|
||||||
|
img = noise
|
||||||
|
|
||||||
|
if to_zero:
|
||||||
|
timesteps[-1] = (1 / len(model.sigmas))
|
||||||
|
|
||||||
|
device = noise.device
|
||||||
|
|
||||||
|
if model.parameterization == "v":
|
||||||
|
model_type = "v"
|
||||||
|
else:
|
||||||
|
model_type = "noise"
|
||||||
|
|
||||||
model_fn = model_wrapper(
|
model_fn = model_wrapper(
|
||||||
lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
|
model.inner_model.inner_model.apply_model,
|
||||||
|
sampling_function,
|
||||||
ns,
|
ns,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
guidance_type="uncond",
|
guidance_type="uncond",
|
||||||
model_kwargs=extra_args,
|
model_kwargs=extra_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
order = min(3, len(timesteps) - 2)
|
order = min(3, len(timesteps) - 1)
|
||||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||||
x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
||||||
x /= ns.marginal_alpha(timesteps[-1])
|
if not to_zero:
|
||||||
|
x /= ns.marginal_alpha(timesteps[-1])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
|
|
||||||
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|
|
||||||
|
|||||||
@ -1,67 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
|
||||||
mantissa_scaled = torch.where(
|
|
||||||
normal_mask,
|
|
||||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
|
||||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
|
||||||
)
|
|
||||||
|
|
||||||
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
|
|
||||||
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
|
|
||||||
|
|
||||||
#Not 100% sure about this
|
|
||||||
def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
|
||||||
if dtype == torch.float8_e4m3fn:
|
|
||||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
|
||||||
elif dtype == torch.float8_e5m2:
|
|
||||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported dtype")
|
|
||||||
|
|
||||||
x = x.half()
|
|
||||||
sign = torch.sign(x)
|
|
||||||
abs_x = x.abs()
|
|
||||||
sign = torch.where(abs_x == 0, 0, sign)
|
|
||||||
|
|
||||||
# Combine exponent calculation and clamping
|
|
||||||
exponent = torch.clamp(
|
|
||||||
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
|
||||||
0, 2**EXPONENT_BITS - 1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Combine mantissa calculation and rounding
|
|
||||||
normal_mask = ~(exponent == 0)
|
|
||||||
|
|
||||||
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
|
|
||||||
|
|
||||||
sign *= torch.where(
|
|
||||||
normal_mask,
|
|
||||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
|
||||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
|
||||||
)
|
|
||||||
|
|
||||||
inf = torch.finfo(dtype)
|
|
||||||
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
|
|
||||||
return sign
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def stochastic_rounding(value, dtype, seed=0):
|
|
||||||
if dtype == torch.float32:
|
|
||||||
return value.to(dtype=torch.float32)
|
|
||||||
if dtype == torch.float16:
|
|
||||||
return value.to(dtype=torch.float16)
|
|
||||||
if dtype == torch.bfloat16:
|
|
||||||
return value.to(dtype=torch.bfloat16)
|
|
||||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
|
||||||
generator = torch.Generator(device=value.device)
|
|
||||||
generator.manual_seed(seed)
|
|
||||||
output = torch.empty_like(value, dtype=dtype)
|
|
||||||
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
|
||||||
slice_size = max(1, round(value.shape[0] / num_slices))
|
|
||||||
for i in range(0, value.shape[0], slice_size):
|
|
||||||
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
|
||||||
return output
|
|
||||||
|
|
||||||
return value.to(dtype=dtype)
|
|
||||||
121
comfy/gligen.py
121
comfy/gligen.py
@ -1,9 +1,52 @@
|
|||||||
import math
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn, einsum
|
||||||
from .ldm.modules.attention import CrossAttention, FeedForward
|
from .ldm.modules.attention import CrossAttention
|
||||||
import comfy.ops
|
from inspect import isfunction
|
||||||
ops = comfy.ops.manual_cast
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def uniq(arr):
|
||||||
|
return{el: True for el in arr}.keys()
|
||||||
|
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
if exists(val):
|
||||||
|
return val
|
||||||
|
return d() if isfunction(d) else d
|
||||||
|
|
||||||
|
|
||||||
|
# feedforward
|
||||||
|
class GEGLU(nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||||
|
return x * torch.nn.functional.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
dim_out = default(dim_out, dim)
|
||||||
|
project_in = nn.Sequential(
|
||||||
|
nn.Linear(dim, inner_dim),
|
||||||
|
nn.GELU()
|
||||||
|
) if not glu else GEGLU(dim, inner_dim)
|
||||||
|
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
project_in,
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(inner_dim, dim_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
class GatedCrossAttentionDense(nn.Module):
|
class GatedCrossAttentionDense(nn.Module):
|
||||||
@ -14,12 +57,11 @@ class GatedCrossAttentionDense(nn.Module):
|
|||||||
query_dim=query_dim,
|
query_dim=query_dim,
|
||||||
context_dim=context_dim,
|
context_dim=context_dim,
|
||||||
heads=n_heads,
|
heads=n_heads,
|
||||||
dim_head=d_head,
|
dim_head=d_head)
|
||||||
operations=ops)
|
|
||||||
self.ff = FeedForward(query_dim, glu=True)
|
self.ff = FeedForward(query_dim, glu=True)
|
||||||
|
|
||||||
self.norm1 = ops.LayerNorm(query_dim)
|
self.norm1 = nn.LayerNorm(query_dim)
|
||||||
self.norm2 = ops.LayerNorm(query_dim)
|
self.norm2 = nn.LayerNorm(query_dim)
|
||||||
|
|
||||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||||
@ -45,18 +87,17 @@ class GatedSelfAttentionDense(nn.Module):
|
|||||||
|
|
||||||
# we need a linear projection since we need cat visual feature and obj
|
# we need a linear projection since we need cat visual feature and obj
|
||||||
# feature
|
# feature
|
||||||
self.linear = ops.Linear(context_dim, query_dim)
|
self.linear = nn.Linear(context_dim, query_dim)
|
||||||
|
|
||||||
self.attn = CrossAttention(
|
self.attn = CrossAttention(
|
||||||
query_dim=query_dim,
|
query_dim=query_dim,
|
||||||
context_dim=query_dim,
|
context_dim=query_dim,
|
||||||
heads=n_heads,
|
heads=n_heads,
|
||||||
dim_head=d_head,
|
dim_head=d_head)
|
||||||
operations=ops)
|
|
||||||
self.ff = FeedForward(query_dim, glu=True)
|
self.ff = FeedForward(query_dim, glu=True)
|
||||||
|
|
||||||
self.norm1 = ops.LayerNorm(query_dim)
|
self.norm1 = nn.LayerNorm(query_dim)
|
||||||
self.norm2 = ops.LayerNorm(query_dim)
|
self.norm2 = nn.LayerNorm(query_dim)
|
||||||
|
|
||||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||||
@ -85,14 +126,14 @@ class GatedSelfAttentionDense2(nn.Module):
|
|||||||
|
|
||||||
# we need a linear projection since we need cat visual feature and obj
|
# we need a linear projection since we need cat visual feature and obj
|
||||||
# feature
|
# feature
|
||||||
self.linear = ops.Linear(context_dim, query_dim)
|
self.linear = nn.Linear(context_dim, query_dim)
|
||||||
|
|
||||||
self.attn = CrossAttention(
|
self.attn = CrossAttention(
|
||||||
query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
|
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
|
||||||
self.ff = FeedForward(query_dim, glu=True)
|
self.ff = FeedForward(query_dim, glu=True)
|
||||||
|
|
||||||
self.norm1 = ops.LayerNorm(query_dim)
|
self.norm1 = nn.LayerNorm(query_dim)
|
||||||
self.norm2 = ops.LayerNorm(query_dim)
|
self.norm2 = nn.LayerNorm(query_dim)
|
||||||
|
|
||||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||||
@ -160,11 +201,11 @@ class PositionNet(nn.Module):
|
|||||||
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
|
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
|
||||||
|
|
||||||
self.linears = nn.Sequential(
|
self.linears = nn.Sequential(
|
||||||
ops.Linear(self.in_dim + self.position_dim, 512),
|
nn.Linear(self.in_dim + self.position_dim, 512),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
ops.Linear(512, 512),
|
nn.Linear(512, 512),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
ops.Linear(512, out_dim),
|
nn.Linear(512, out_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.null_positive_feature = torch.nn.Parameter(
|
self.null_positive_feature = torch.nn.Parameter(
|
||||||
@ -175,14 +216,13 @@ class PositionNet(nn.Module):
|
|||||||
def forward(self, boxes, masks, positive_embeddings):
|
def forward(self, boxes, masks, positive_embeddings):
|
||||||
B, N, _ = boxes.shape
|
B, N, _ = boxes.shape
|
||||||
masks = masks.unsqueeze(-1)
|
masks = masks.unsqueeze(-1)
|
||||||
positive_embeddings = positive_embeddings
|
|
||||||
|
|
||||||
# embedding position (it may includes padding as placeholder)
|
# embedding position (it may includes padding as placeholder)
|
||||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
||||||
|
|
||||||
# learnable null embedding
|
# learnable null embedding
|
||||||
positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
positive_null = self.null_positive_feature.view(1, 1, -1)
|
||||||
xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
||||||
|
|
||||||
# replace padding with learnable null embedding
|
# replace padding with learnable null embedding
|
||||||
positive_embeddings = positive_embeddings * \
|
positive_embeddings = positive_embeddings * \
|
||||||
@ -202,15 +242,28 @@ class Gligen(nn.Module):
|
|||||||
self.position_net = position_net
|
self.position_net = position_net
|
||||||
self.key_dim = key_dim
|
self.key_dim = key_dim
|
||||||
self.max_objs = 30
|
self.max_objs = 30
|
||||||
self.current_device = torch.device("cpu")
|
self.lowvram = False
|
||||||
|
|
||||||
def _set_position(self, boxes, masks, positive_embeddings):
|
def _set_position(self, boxes, masks, positive_embeddings):
|
||||||
|
if self.lowvram == True:
|
||||||
|
self.position_net.to(boxes.device)
|
||||||
|
|
||||||
objs = self.position_net(boxes, masks, positive_embeddings)
|
objs = self.position_net(boxes, masks, positive_embeddings)
|
||||||
def func(x, extra_options):
|
|
||||||
key = extra_options["transformer_index"]
|
if self.lowvram == True:
|
||||||
module = self.module_list[key]
|
self.position_net.cpu()
|
||||||
return module(x, objs.to(device=x.device, dtype=x.dtype))
|
def func_lowvram(key, x):
|
||||||
return func
|
module = self.module_list[key]
|
||||||
|
module.to(x.device)
|
||||||
|
r = module(x, objs)
|
||||||
|
module.cpu()
|
||||||
|
return r
|
||||||
|
return func_lowvram
|
||||||
|
else:
|
||||||
|
def func(key, x):
|
||||||
|
module = self.module_list[key]
|
||||||
|
return module(x, objs)
|
||||||
|
return func
|
||||||
|
|
||||||
def set_position(self, latent_image_shape, position_params, device):
|
def set_position(self, latent_image_shape, position_params, device):
|
||||||
batch, c, h, w = latent_image_shape
|
batch, c, h, w = latent_image_shape
|
||||||
@ -255,6 +308,14 @@ class Gligen(nn.Module):
|
|||||||
masks.to(device),
|
masks.to(device),
|
||||||
conds.to(device))
|
conds.to(device))
|
||||||
|
|
||||||
|
def set_lowvram(self, value=True):
|
||||||
|
self.lowvram = value
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
self.lowvram = False
|
||||||
|
|
||||||
|
def get_models(self):
|
||||||
|
return [self]
|
||||||
|
|
||||||
def load_gligen(sd):
|
def load_gligen(sd):
|
||||||
sd_k = sd.keys()
|
sd_k = sd.keys()
|
||||||
|
|||||||
785
comfy/hooks.py
785
comfy/hooks.py
@ -1,785 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from typing import TYPE_CHECKING, Callable
|
|
||||||
import enum
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import itertools
|
|
||||||
import logging
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from comfy.model_patcher import ModelPatcher, PatcherInjection
|
|
||||||
from comfy.model_base import BaseModel
|
|
||||||
from comfy.sd import CLIP
|
|
||||||
import comfy.lora
|
|
||||||
import comfy.model_management
|
|
||||||
import comfy.patcher_extension
|
|
||||||
from node_helpers import conditioning_set_values
|
|
||||||
|
|
||||||
# #######################################################################################################
|
|
||||||
# Hooks explanation
|
|
||||||
# -------------------
|
|
||||||
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
|
|
||||||
# make explicit special cases like it does for ControlNet and GLIGEN.
|
|
||||||
#
|
|
||||||
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
|
|
||||||
# that should run special code when a 'marked' cond is used in sampling.
|
|
||||||
# #######################################################################################################
|
|
||||||
|
|
||||||
class EnumHookMode(enum.Enum):
|
|
||||||
'''
|
|
||||||
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
|
|
||||||
|
|
||||||
MinVram: No caching will occur for any operations related to hooks.
|
|
||||||
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
|
|
||||||
'''
|
|
||||||
MinVram = "minvram"
|
|
||||||
MaxSpeed = "maxspeed"
|
|
||||||
|
|
||||||
class EnumHookType(enum.Enum):
|
|
||||||
'''
|
|
||||||
Hook types, each of which has different expected behavior.
|
|
||||||
'''
|
|
||||||
Weight = "weight"
|
|
||||||
ObjectPatch = "object_patch"
|
|
||||||
AdditionalModels = "add_models"
|
|
||||||
TransformerOptions = "transformer_options"
|
|
||||||
Injections = "add_injections"
|
|
||||||
|
|
||||||
class EnumWeightTarget(enum.Enum):
|
|
||||||
Model = "model"
|
|
||||||
Clip = "clip"
|
|
||||||
|
|
||||||
class EnumHookScope(enum.Enum):
|
|
||||||
'''
|
|
||||||
Determines if hook should be limited in its influence over sampling.
|
|
||||||
|
|
||||||
AllConditioning: hook will affect all conds used in sampling.
|
|
||||||
HookedOnly: hook will only affect the conds it was attached to.
|
|
||||||
'''
|
|
||||||
AllConditioning = "all_conditioning"
|
|
||||||
HookedOnly = "hooked_only"
|
|
||||||
|
|
||||||
|
|
||||||
class _HookRef:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
|
||||||
'''Example for how custom_should_register function can look like.'''
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
|
|
||||||
'''Creates base dictionary for use with Hooks' target param.'''
|
|
||||||
d = {}
|
|
||||||
if target is not None:
|
|
||||||
d['target'] = target
|
|
||||||
d.update(kwargs)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
class Hook:
|
|
||||||
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
|
|
||||||
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
|
|
||||||
self.hook_type = hook_type
|
|
||||||
'''Enum identifying the general class of this hook.'''
|
|
||||||
self.hook_ref = hook_ref if hook_ref else _HookRef()
|
|
||||||
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
|
|
||||||
self.hook_id = hook_id
|
|
||||||
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
|
|
||||||
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
|
|
||||||
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
|
|
||||||
self.hook_scope = hook_scope
|
|
||||||
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
|
||||||
self.custom_should_register = default_should_register
|
|
||||||
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
|
||||||
|
|
||||||
@property
|
|
||||||
def strength(self):
|
|
||||||
return self.hook_keyframe.strength
|
|
||||||
|
|
||||||
def initialize_timesteps(self, model: BaseModel):
|
|
||||||
self.reset()
|
|
||||||
self.hook_keyframe.initialize_timesteps(model)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.hook_keyframe.reset()
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c: Hook = self.__class__()
|
|
||||||
c.hook_type = self.hook_type
|
|
||||||
c.hook_ref = self.hook_ref
|
|
||||||
c.hook_id = self.hook_id
|
|
||||||
c.hook_keyframe = self.hook_keyframe
|
|
||||||
c.hook_scope = self.hook_scope
|
|
||||||
c.custom_should_register = self.custom_should_register
|
|
||||||
return c
|
|
||||||
|
|
||||||
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
|
||||||
return self.custom_should_register(self, model, model_options, target_dict, registered)
|
|
||||||
|
|
||||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
|
||||||
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
|
|
||||||
|
|
||||||
def __eq__(self, other: Hook):
|
|
||||||
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash(self.hook_ref)
|
|
||||||
|
|
||||||
class WeightHook(Hook):
|
|
||||||
'''
|
|
||||||
Hook responsible for tracking weights to be applied to some model/clip.
|
|
||||||
|
|
||||||
Note, value of hook_scope is ignored and is treated as HookedOnly.
|
|
||||||
'''
|
|
||||||
def __init__(self, strength_model=1.0, strength_clip=1.0):
|
|
||||||
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
|
|
||||||
self.weights: dict = None
|
|
||||||
self.weights_clip: dict = None
|
|
||||||
self.need_weight_init = True
|
|
||||||
self._strength_model = strength_model
|
|
||||||
self._strength_clip = strength_clip
|
|
||||||
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
|
|
||||||
|
|
||||||
@property
|
|
||||||
def strength_model(self):
|
|
||||||
return self._strength_model * self.strength
|
|
||||||
|
|
||||||
@property
|
|
||||||
def strength_clip(self):
|
|
||||||
return self._strength_clip * self.strength
|
|
||||||
|
|
||||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
|
||||||
if not self.should_register(model, model_options, target_dict, registered):
|
|
||||||
return False
|
|
||||||
weights = None
|
|
||||||
|
|
||||||
target = target_dict.get('target', None)
|
|
||||||
if target == EnumWeightTarget.Clip:
|
|
||||||
strength = self._strength_clip
|
|
||||||
else:
|
|
||||||
strength = self._strength_model
|
|
||||||
|
|
||||||
if self.need_weight_init:
|
|
||||||
key_map = {}
|
|
||||||
if target == EnumWeightTarget.Clip:
|
|
||||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
|
||||||
else:
|
|
||||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
|
||||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
|
||||||
else:
|
|
||||||
if target == EnumWeightTarget.Clip:
|
|
||||||
weights = self.weights_clip
|
|
||||||
else:
|
|
||||||
weights = self.weights
|
|
||||||
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
|
|
||||||
registered.add(self)
|
|
||||||
return True
|
|
||||||
# TODO: add logs about any keys that were not applied
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c: WeightHook = super().clone()
|
|
||||||
c.weights = self.weights
|
|
||||||
c.weights_clip = self.weights_clip
|
|
||||||
c.need_weight_init = self.need_weight_init
|
|
||||||
c._strength_model = self._strength_model
|
|
||||||
c._strength_clip = self._strength_clip
|
|
||||||
return c
|
|
||||||
|
|
||||||
class ObjectPatchHook(Hook):
|
|
||||||
def __init__(self, object_patches: dict[str]=None,
|
|
||||||
hook_scope=EnumHookScope.AllConditioning):
|
|
||||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
|
||||||
self.object_patches = object_patches
|
|
||||||
self.hook_scope = hook_scope
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c: ObjectPatchHook = super().clone()
|
|
||||||
c.object_patches = self.object_patches
|
|
||||||
return c
|
|
||||||
|
|
||||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
|
||||||
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
|
|
||||||
|
|
||||||
class AdditionalModelsHook(Hook):
|
|
||||||
'''
|
|
||||||
Hook responsible for telling model management any additional models that should be loaded.
|
|
||||||
|
|
||||||
Note, value of hook_scope is ignored and is treated as AllConditioning.
|
|
||||||
'''
|
|
||||||
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
|
|
||||||
super().__init__(hook_type=EnumHookType.AdditionalModels)
|
|
||||||
self.models = models
|
|
||||||
self.key = key
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c: AdditionalModelsHook = super().clone()
|
|
||||||
c.models = self.models.copy() if self.models else self.models
|
|
||||||
c.key = self.key
|
|
||||||
return c
|
|
||||||
|
|
||||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
|
||||||
if not self.should_register(model, model_options, target_dict, registered):
|
|
||||||
return False
|
|
||||||
registered.add(self)
|
|
||||||
return True
|
|
||||||
|
|
||||||
class TransformerOptionsHook(Hook):
|
|
||||||
'''
|
|
||||||
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
|
|
||||||
'''
|
|
||||||
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
|
|
||||||
hook_scope=EnumHookScope.AllConditioning):
|
|
||||||
super().__init__(hook_type=EnumHookType.TransformerOptions)
|
|
||||||
self.transformers_dict = transformers_dict
|
|
||||||
self.hook_scope = hook_scope
|
|
||||||
self._skip_adding = False
|
|
||||||
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c: TransformerOptionsHook = super().clone()
|
|
||||||
c.transformers_dict = self.transformers_dict
|
|
||||||
c._skip_adding = self._skip_adding
|
|
||||||
return c
|
|
||||||
|
|
||||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
|
||||||
if not self.should_register(model, model_options, target_dict, registered):
|
|
||||||
return False
|
|
||||||
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
|
|
||||||
self._skip_adding = False
|
|
||||||
if self.hook_scope == EnumHookScope.AllConditioning:
|
|
||||||
add_model_options = {"transformer_options": self.transformers_dict,
|
|
||||||
"to_load_options": self.transformers_dict}
|
|
||||||
# skip_adding if included in AllConditioning to avoid double loading
|
|
||||||
self._skip_adding = True
|
|
||||||
else:
|
|
||||||
add_model_options = {"to_load_options": self.transformers_dict}
|
|
||||||
registered.add(self)
|
|
||||||
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
|
|
||||||
if not self._skip_adding:
|
|
||||||
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
|
|
||||||
|
|
||||||
WrapperHook = TransformerOptionsHook
|
|
||||||
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
|
|
||||||
|
|
||||||
class InjectionsHook(Hook):
|
|
||||||
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
|
|
||||||
hook_scope=EnumHookScope.AllConditioning):
|
|
||||||
super().__init__(hook_type=EnumHookType.Injections)
|
|
||||||
self.key = key
|
|
||||||
self.injections = injections
|
|
||||||
self.hook_scope = hook_scope
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c: InjectionsHook = super().clone()
|
|
||||||
c.key = self.key
|
|
||||||
c.injections = self.injections.copy() if self.injections else self.injections
|
|
||||||
return c
|
|
||||||
|
|
||||||
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
|
||||||
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
|
|
||||||
|
|
||||||
class HookGroup:
|
|
||||||
'''
|
|
||||||
Stores groups of hooks, and allows them to be queried by type.
|
|
||||||
|
|
||||||
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
|
|
||||||
always use the provided functions on HookGroup.
|
|
||||||
'''
|
|
||||||
def __init__(self):
|
|
||||||
self.hooks: list[Hook] = []
|
|
||||||
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.hooks)
|
|
||||||
|
|
||||||
def add(self, hook: Hook):
|
|
||||||
if hook not in self.hooks:
|
|
||||||
self.hooks.append(hook)
|
|
||||||
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
|
|
||||||
|
|
||||||
def remove(self, hook: Hook):
|
|
||||||
if hook in self.hooks:
|
|
||||||
self.hooks.remove(hook)
|
|
||||||
self._hook_dict[hook.hook_type].remove(hook)
|
|
||||||
|
|
||||||
def get_type(self, hook_type: EnumHookType):
|
|
||||||
return self._hook_dict.get(hook_type, [])
|
|
||||||
|
|
||||||
def contains(self, hook: Hook):
|
|
||||||
return hook in self.hooks
|
|
||||||
|
|
||||||
def is_subset_of(self, other: HookGroup):
|
|
||||||
self_hooks = set(self.hooks)
|
|
||||||
other_hooks = set(other.hooks)
|
|
||||||
return self_hooks.issubset(other_hooks)
|
|
||||||
|
|
||||||
def new_with_common_hooks(self, other: HookGroup):
|
|
||||||
c = HookGroup()
|
|
||||||
for hook in self.hooks:
|
|
||||||
if other.contains(hook):
|
|
||||||
c.add(hook.clone())
|
|
||||||
return c
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c = HookGroup()
|
|
||||||
for hook in self.hooks:
|
|
||||||
c.add(hook.clone())
|
|
||||||
return c
|
|
||||||
|
|
||||||
def clone_and_combine(self, other: HookGroup):
|
|
||||||
c = self.clone()
|
|
||||||
if other is not None:
|
|
||||||
for hook in other.hooks:
|
|
||||||
c.add(hook.clone())
|
|
||||||
return c
|
|
||||||
|
|
||||||
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
|
|
||||||
if hook_kf is None:
|
|
||||||
hook_kf = HookKeyframeGroup()
|
|
||||||
else:
|
|
||||||
hook_kf = hook_kf.clone()
|
|
||||||
for hook in self.hooks:
|
|
||||||
hook.hook_keyframe = hook_kf
|
|
||||||
|
|
||||||
def get_hooks_for_clip_schedule(self):
|
|
||||||
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
|
|
||||||
# only care about WeightHooks, for now
|
|
||||||
for hook in self.get_type(EnumHookType.Weight):
|
|
||||||
hook: WeightHook
|
|
||||||
hook_schedule = []
|
|
||||||
# if no hook keyframes, assign default value
|
|
||||||
if len(hook.hook_keyframe.keyframes) == 0:
|
|
||||||
hook_schedule.append(((0.0, 1.0), None))
|
|
||||||
scheduled_hooks[hook] = hook_schedule
|
|
||||||
continue
|
|
||||||
# find ranges of values
|
|
||||||
prev_keyframe = hook.hook_keyframe.keyframes[0]
|
|
||||||
for keyframe in hook.hook_keyframe.keyframes:
|
|
||||||
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
|
|
||||||
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
|
|
||||||
prev_keyframe = keyframe
|
|
||||||
elif keyframe.start_percent == prev_keyframe.start_percent:
|
|
||||||
prev_keyframe = keyframe
|
|
||||||
# create final range, assuming last start_percent was not 1.0
|
|
||||||
if not math.isclose(prev_keyframe.start_percent, 1.0):
|
|
||||||
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
|
|
||||||
scheduled_hooks[hook] = hook_schedule
|
|
||||||
# hooks should not have their schedules in a list of tuples
|
|
||||||
all_ranges: list[tuple[float, float]] = []
|
|
||||||
for range_kfs in scheduled_hooks.values():
|
|
||||||
for t_range, keyframe in range_kfs:
|
|
||||||
all_ranges.append(t_range)
|
|
||||||
# turn list of ranges into boundaries
|
|
||||||
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
|
|
||||||
boundaries_set.add(0.0)
|
|
||||||
boundaries = sorted(boundaries_set)
|
|
||||||
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
|
|
||||||
# with real ranges defined, give appropriate hooks w/ keyframes for each range
|
|
||||||
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
|
|
||||||
for t_range in real_ranges:
|
|
||||||
hooks_schedule = []
|
|
||||||
for hook, val in scheduled_hooks.items():
|
|
||||||
keyframe = None
|
|
||||||
# check if is a keyframe that works for the current t_range
|
|
||||||
for stored_range, stored_kf in val:
|
|
||||||
# if stored start is less than current end, then fits - give it assigned keyframe
|
|
||||||
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
|
|
||||||
keyframe = stored_kf
|
|
||||||
break
|
|
||||||
hooks_schedule.append((hook, keyframe))
|
|
||||||
scheduled_keyframes.append((t_range, hooks_schedule))
|
|
||||||
return scheduled_keyframes
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
for hook in self.hooks:
|
|
||||||
hook.reset()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
|
|
||||||
actual: list[HookGroup] = []
|
|
||||||
for group in hooks_list:
|
|
||||||
if group is not None:
|
|
||||||
actual.append(group)
|
|
||||||
if len(actual) < require_count:
|
|
||||||
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
|
|
||||||
# if no hooks, then return None
|
|
||||||
if len(actual) == 0:
|
|
||||||
return None
|
|
||||||
# if only 1 hook, just return itself without cloning
|
|
||||||
elif len(actual) == 1:
|
|
||||||
return actual[0]
|
|
||||||
final_hook: HookGroup = None
|
|
||||||
for hook in actual:
|
|
||||||
if final_hook is None:
|
|
||||||
final_hook = hook.clone()
|
|
||||||
else:
|
|
||||||
final_hook = final_hook.clone_and_combine(hook)
|
|
||||||
return final_hook
|
|
||||||
|
|
||||||
|
|
||||||
class HookKeyframe:
|
|
||||||
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
|
|
||||||
self.strength = strength
|
|
||||||
# scheduling
|
|
||||||
self.start_percent = float(start_percent)
|
|
||||||
self.start_t = 999999999.9
|
|
||||||
self.guarantee_steps = guarantee_steps
|
|
||||||
|
|
||||||
def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
|
|
||||||
'''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
|
|
||||||
if self.start_t > max_sigma:
|
|
||||||
return 0
|
|
||||||
return self.guarantee_steps
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c = HookKeyframe(strength=self.strength,
|
|
||||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
|
||||||
c.start_t = self.start_t
|
|
||||||
return c
|
|
||||||
|
|
||||||
class HookKeyframeGroup:
|
|
||||||
def __init__(self):
|
|
||||||
self.keyframes: list[HookKeyframe] = []
|
|
||||||
self._current_keyframe: HookKeyframe = None
|
|
||||||
self._current_used_steps = 0
|
|
||||||
self._current_index = 0
|
|
||||||
self._current_strength = None
|
|
||||||
self._curr_t = -1.
|
|
||||||
|
|
||||||
# properties shadow those of HookWeightsKeyframe
|
|
||||||
@property
|
|
||||||
def strength(self):
|
|
||||||
if self._current_keyframe is not None:
|
|
||||||
return self._current_keyframe.strength
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self._current_keyframe = None
|
|
||||||
self._current_used_steps = 0
|
|
||||||
self._current_index = 0
|
|
||||||
self._current_strength = None
|
|
||||||
self.curr_t = -1.
|
|
||||||
self._set_first_as_current()
|
|
||||||
|
|
||||||
def add(self, keyframe: HookKeyframe):
|
|
||||||
# add to end of list, then sort
|
|
||||||
self.keyframes.append(keyframe)
|
|
||||||
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
|
|
||||||
self._set_first_as_current()
|
|
||||||
|
|
||||||
def _set_first_as_current(self):
|
|
||||||
if len(self.keyframes) > 0:
|
|
||||||
self._current_keyframe = self.keyframes[0]
|
|
||||||
else:
|
|
||||||
self._current_keyframe = None
|
|
||||||
|
|
||||||
def has_guarantee_steps(self):
|
|
||||||
for kf in self.keyframes:
|
|
||||||
if kf.guarantee_steps > 0:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def has_index(self, index: int):
|
|
||||||
return index >= 0 and index < len(self.keyframes)
|
|
||||||
|
|
||||||
def is_empty(self):
|
|
||||||
return len(self.keyframes) == 0
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
c = HookKeyframeGroup()
|
|
||||||
for keyframe in self.keyframes:
|
|
||||||
c.keyframes.append(keyframe.clone())
|
|
||||||
c._set_first_as_current()
|
|
||||||
return c
|
|
||||||
|
|
||||||
def initialize_timesteps(self, model: BaseModel):
|
|
||||||
for keyframe in self.keyframes:
|
|
||||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
|
||||||
|
|
||||||
def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
|
|
||||||
if self.is_empty():
|
|
||||||
return False
|
|
||||||
if curr_t == self._curr_t:
|
|
||||||
return False
|
|
||||||
max_sigma = torch.max(transformer_options["sample_sigmas"])
|
|
||||||
prev_index = self._current_index
|
|
||||||
prev_strength = self._current_strength
|
|
||||||
# if met guaranteed steps, look for next keyframe in case need to switch
|
|
||||||
if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
|
|
||||||
# if has next index, loop through and see if need to switch
|
|
||||||
if self.has_index(self._current_index+1):
|
|
||||||
for i in range(self._current_index+1, len(self.keyframes)):
|
|
||||||
eval_c = self.keyframes[i]
|
|
||||||
# check if start_t is greater or equal to curr_t
|
|
||||||
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
|
|
||||||
if eval_c.start_t >= curr_t:
|
|
||||||
self._current_index = i
|
|
||||||
self._current_strength = eval_c.strength
|
|
||||||
self._current_keyframe = eval_c
|
|
||||||
self._current_used_steps = 0
|
|
||||||
# if guarantee_steps greater than zero, stop searching for other keyframes
|
|
||||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
|
||||||
break
|
|
||||||
# if eval_c is outside the percent range, stop looking further
|
|
||||||
else: break
|
|
||||||
# update steps current context is used
|
|
||||||
self._current_used_steps += 1
|
|
||||||
# update current timestep this was performed on
|
|
||||||
self._curr_t = curr_t
|
|
||||||
# return True if keyframe changed, False if no change
|
|
||||||
return prev_index != self._current_index and prev_strength != self._current_strength
|
|
||||||
|
|
||||||
|
|
||||||
class InterpolationMethod:
|
|
||||||
LINEAR = "linear"
|
|
||||||
EASE_IN = "ease_in"
|
|
||||||
EASE_OUT = "ease_out"
|
|
||||||
EASE_IN_OUT = "ease_in_out"
|
|
||||||
|
|
||||||
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
|
|
||||||
diff = num_to - num_from
|
|
||||||
if method == cls.LINEAR:
|
|
||||||
weights = torch.linspace(num_from, num_to, length)
|
|
||||||
elif method == cls.EASE_IN:
|
|
||||||
index = torch.linspace(0, 1, length)
|
|
||||||
weights = diff * np.power(index, 2) + num_from
|
|
||||||
elif method == cls.EASE_OUT:
|
|
||||||
index = torch.linspace(0, 1, length)
|
|
||||||
weights = diff * (1 - np.power(1 - index, 2)) + num_from
|
|
||||||
elif method == cls.EASE_IN_OUT:
|
|
||||||
index = torch.linspace(0, 1, length)
|
|
||||||
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unrecognized interpolation method '{method}'.")
|
|
||||||
if reverse:
|
|
||||||
weights = weights.flip(dims=(0,))
|
|
||||||
return weights
|
|
||||||
|
|
||||||
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
|
||||||
if not objects:
|
|
||||||
return objects
|
|
||||||
elif len(objects) <= 1:
|
|
||||||
return [x for x in objects]
|
|
||||||
# now that we know we have to sort, do it following these rules:
|
|
||||||
# a) if objects have same value of attribute, maintain their relative order
|
|
||||||
# b) perform sorting of the groups of objects with same attributes
|
|
||||||
unique_attrs = {}
|
|
||||||
for o in objects:
|
|
||||||
val_attr = getattr(o, attr)
|
|
||||||
attr_list: list = unique_attrs.get(val_attr, list())
|
|
||||||
attr_list.append(o)
|
|
||||||
if val_attr not in unique_attrs:
|
|
||||||
unique_attrs[val_attr] = attr_list
|
|
||||||
# now that we have the unique attr values grouped together in relative order, sort them by key
|
|
||||||
sorted_attrs = dict(sorted(unique_attrs.items()))
|
|
||||||
# now flatten out the dict into a list to return
|
|
||||||
sorted_list = []
|
|
||||||
for object_list in sorted_attrs.values():
|
|
||||||
sorted_list.extend(object_list)
|
|
||||||
return sorted_list
|
|
||||||
|
|
||||||
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
|
|
||||||
# if no hooks or is not a ModelPatcher for sampling, return empty dict
|
|
||||||
if hooks is None or model.is_clip:
|
|
||||||
return {}
|
|
||||||
if transformer_options is None:
|
|
||||||
transformer_options = {}
|
|
||||||
for hook in hooks.get_type(EnumHookType.TransformerOptions):
|
|
||||||
hook: TransformerOptionsHook
|
|
||||||
hook.on_apply_hooks(model, transformer_options)
|
|
||||||
return transformer_options
|
|
||||||
|
|
||||||
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
|
|
||||||
hook_group = HookGroup()
|
|
||||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
|
||||||
hook_group.add(hook)
|
|
||||||
hook.weights = lora
|
|
||||||
return hook_group
|
|
||||||
|
|
||||||
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
|
|
||||||
hook_group = HookGroup()
|
|
||||||
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
|
|
||||||
hook_group.add(hook)
|
|
||||||
patches_model = None
|
|
||||||
patches_clip = None
|
|
||||||
if weights_model is not None:
|
|
||||||
patches_model = {}
|
|
||||||
for key in weights_model:
|
|
||||||
patches_model[key] = ("model_as_lora", (weights_model[key],))
|
|
||||||
if weights_clip is not None:
|
|
||||||
patches_clip = {}
|
|
||||||
for key in weights_clip:
|
|
||||||
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
|
|
||||||
hook.weights = patches_model
|
|
||||||
hook.weights_clip = patches_clip
|
|
||||||
hook.need_weight_init = False
|
|
||||||
return hook_group
|
|
||||||
|
|
||||||
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
|
|
||||||
if model is None:
|
|
||||||
return None
|
|
||||||
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
|
|
||||||
if discard_model_sampling:
|
|
||||||
# do not include ANY model_sampling components of the model that should act as a patch
|
|
||||||
for key in list(patches_model.keys()):
|
|
||||||
if key.startswith("model_sampling"):
|
|
||||||
patches_model.pop(key, None)
|
|
||||||
return patches_model
|
|
||||||
|
|
||||||
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
|
|
||||||
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
|
|
||||||
strength_model: float, strength_clip: float):
|
|
||||||
key_map = {}
|
|
||||||
if model is not None:
|
|
||||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
|
||||||
if clip is not None:
|
|
||||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
|
||||||
|
|
||||||
hook_group = HookGroup()
|
|
||||||
hook = WeightHook()
|
|
||||||
hook_group.add(hook)
|
|
||||||
loaded: dict[str] = comfy.lora.load_lora(lora, key_map)
|
|
||||||
if model is not None:
|
|
||||||
new_modelpatcher = model.clone()
|
|
||||||
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
|
|
||||||
else:
|
|
||||||
k = ()
|
|
||||||
new_modelpatcher = None
|
|
||||||
|
|
||||||
if clip is not None:
|
|
||||||
new_clip = clip.clone()
|
|
||||||
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
|
||||||
else:
|
|
||||||
k1 = ()
|
|
||||||
new_clip = None
|
|
||||||
k = set(k)
|
|
||||||
k1 = set(k1)
|
|
||||||
for x in loaded:
|
|
||||||
if (x not in k) and (x not in k1):
|
|
||||||
logging.warning(f"NOT LOADED {x}")
|
|
||||||
return (new_modelpatcher, new_clip, hook_group)
|
|
||||||
|
|
||||||
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
|
|
||||||
hooks_key = 'hooks'
|
|
||||||
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
|
|
||||||
if hooks_key not in values:
|
|
||||||
return
|
|
||||||
if hooks_key not in c_dict:
|
|
||||||
hooks_value = values.get(hooks_key, None)
|
|
||||||
if hooks_value is not None:
|
|
||||||
c_dict[hooks_key] = hooks_value
|
|
||||||
return
|
|
||||||
# otherwise, need to combine with minimum duplication via cache
|
|
||||||
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
|
|
||||||
cached_hooks = cache.get(hooks_tuple, None)
|
|
||||||
if cached_hooks is None:
|
|
||||||
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
|
|
||||||
cache[hooks_tuple] = new_hooks
|
|
||||||
c_dict[hooks_key] = new_hooks
|
|
||||||
else:
|
|
||||||
c_dict[hooks_key] = cache[hooks_tuple]
|
|
||||||
|
|
||||||
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
|
|
||||||
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
|
||||||
c = []
|
|
||||||
if cache is None:
|
|
||||||
cache = {}
|
|
||||||
for t in conditioning:
|
|
||||||
n = [t[0], t[1].copy()]
|
|
||||||
for k in values:
|
|
||||||
if append_hooks and k == 'hooks':
|
|
||||||
_combine_hooks_from_values(n[1], values, cache)
|
|
||||||
else:
|
|
||||||
n[1][k] = values[k]
|
|
||||||
c.append(n)
|
|
||||||
|
|
||||||
return c
|
|
||||||
|
|
||||||
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
|
|
||||||
if hooks is None:
|
|
||||||
return cond
|
|
||||||
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
|
|
||||||
|
|
||||||
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
|
|
||||||
if timestep_range is None:
|
|
||||||
return cond
|
|
||||||
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
|
|
||||||
"end_percent": timestep_range[1]})
|
|
||||||
|
|
||||||
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
|
|
||||||
if mask is None:
|
|
||||||
return cond
|
|
||||||
set_area_to_bounds = False
|
|
||||||
if set_cond_area != 'default':
|
|
||||||
set_area_to_bounds = True
|
|
||||||
if len(mask.shape) < 3:
|
|
||||||
mask = mask.unsqueeze(0)
|
|
||||||
return conditioning_set_values(cond, {'mask': mask,
|
|
||||||
'set_area_to_bounds': set_area_to_bounds,
|
|
||||||
'mask_strength': strength})
|
|
||||||
|
|
||||||
def combine_conditioning(conds: list):
|
|
||||||
combined_conds = []
|
|
||||||
for cond in conds:
|
|
||||||
combined_conds.extend(cond)
|
|
||||||
return combined_conds
|
|
||||||
|
|
||||||
def combine_with_new_conds(conds: list, new_conds: list):
|
|
||||||
combined_conds = []
|
|
||||||
for c, new_c in zip(conds, new_conds):
|
|
||||||
combined_conds.append(combine_conditioning([c, new_c]))
|
|
||||||
return combined_conds
|
|
||||||
|
|
||||||
def set_conds_props(conds: list, strength: float, set_cond_area: str,
|
|
||||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
|
||||||
final_conds = []
|
|
||||||
cache = {}
|
|
||||||
for c in conds:
|
|
||||||
# first, apply lora_hook to conditioning, if provided
|
|
||||||
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
|
|
||||||
# next, apply mask to conditioning
|
|
||||||
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
|
|
||||||
# apply timesteps, if present
|
|
||||||
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
|
|
||||||
# finally, apply mask to conditioning and store
|
|
||||||
final_conds.append(c)
|
|
||||||
return final_conds
|
|
||||||
|
|
||||||
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
|
|
||||||
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
|
||||||
combined_conds = []
|
|
||||||
cache = {}
|
|
||||||
for c, masked_c in zip(conds, new_conds):
|
|
||||||
# first, apply lora_hook to new conditioning, if provided
|
|
||||||
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
|
|
||||||
# next, apply mask to new conditioning, if provided
|
|
||||||
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
|
|
||||||
# apply timesteps, if present
|
|
||||||
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
|
|
||||||
# finally, combine with existing conditioning and store
|
|
||||||
combined_conds.append(combine_conditioning([c, masked_c]))
|
|
||||||
return combined_conds
|
|
||||||
|
|
||||||
def set_default_conds_and_combine(conds: list, new_conds: list,
|
|
||||||
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
|
|
||||||
combined_conds = []
|
|
||||||
cache = {}
|
|
||||||
for c, new_c in zip(conds, new_conds):
|
|
||||||
# first, apply lora_hook to new conditioning, if provided
|
|
||||||
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
|
|
||||||
# next, add default_cond key to cond so that during sampling, it can be identified
|
|
||||||
new_c = conditioning_set_values(new_c, {'default': True})
|
|
||||||
# apply timesteps, if present
|
|
||||||
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
|
|
||||||
# finally, combine with existing conditioning and store
|
|
||||||
combined_conds.append(combine_conditioning([c, new_c]))
|
|
||||||
return combined_conds
|
|
||||||
@ -1,160 +0,0 @@
|
|||||||
import torch
|
|
||||||
from comfy.text_encoders.bert import BertAttention
|
|
||||||
import comfy.model_management
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
|
||||||
|
|
||||||
|
|
||||||
class Dino2AttentionOutput(torch.nn.Module):
|
|
||||||
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.dense(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Dino2AttentionBlock(torch.nn.Module):
|
|
||||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
|
||||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
|
||||||
|
|
||||||
def forward(self, x, mask, optimized_attention):
|
|
||||||
return self.output(self.attention(x, mask, optimized_attention))
|
|
||||||
|
|
||||||
|
|
||||||
class LayerScale(torch.nn.Module):
|
|
||||||
def __init__(self, dim, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
|
||||||
|
|
||||||
class Dinov2MLP(torch.nn.Module):
|
|
||||||
def __init__(self, hidden_size: int, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
mlp_ratio = 4
|
|
||||||
hidden_features = int(hidden_size * mlp_ratio)
|
|
||||||
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
|
|
||||||
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
|
||||||
hidden_state = self.fc1(hidden_state)
|
|
||||||
hidden_state = torch.nn.functional.gelu(hidden_state)
|
|
||||||
hidden_state = self.fc2(hidden_state)
|
|
||||||
return hidden_state
|
|
||||||
|
|
||||||
class SwiGLUFFN(torch.nn.Module):
|
|
||||||
def __init__(self, dim, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
in_features = out_features = dim
|
|
||||||
hidden_features = int(dim * 4)
|
|
||||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
|
||||||
|
|
||||||
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
|
|
||||||
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.weights_in(x)
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
x = torch.nn.functional.silu(x1) * x2
|
|
||||||
return self.weights_out(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Dino2Block(torch.nn.Module):
|
|
||||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
|
||||||
super().__init__()
|
|
||||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
|
||||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
|
||||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
|
||||||
if use_swiglu_ffn:
|
|
||||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
|
||||||
else:
|
|
||||||
self.mlp = Dinov2MLP(dim, dtype, device, operations)
|
|
||||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
|
||||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, x, optimized_attention):
|
|
||||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
|
||||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Dino2Encoder(torch.nn.Module):
|
|
||||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
|
||||||
super().__init__()
|
|
||||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
|
||||||
for _ in range(num_layers)])
|
|
||||||
|
|
||||||
def forward(self, x, intermediate_output=None):
|
|
||||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
|
||||||
|
|
||||||
if intermediate_output is not None:
|
|
||||||
if intermediate_output < 0:
|
|
||||||
intermediate_output = len(self.layer) + intermediate_output
|
|
||||||
|
|
||||||
intermediate = None
|
|
||||||
for i, layer in enumerate(self.layer):
|
|
||||||
x = layer(x, optimized_attention)
|
|
||||||
if i == intermediate_output:
|
|
||||||
intermediate = x.clone()
|
|
||||||
return x, intermediate
|
|
||||||
|
|
||||||
|
|
||||||
class Dino2PatchEmbeddings(torch.nn.Module):
|
|
||||||
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.projection = operations.Conv2d(
|
|
||||||
in_channels=num_channels,
|
|
||||||
out_channels=dim,
|
|
||||||
kernel_size=patch_size,
|
|
||||||
stride=patch_size,
|
|
||||||
bias=True,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
|
||||||
return self.projection(pixel_values).flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
|
|
||||||
class Dino2Embeddings(torch.nn.Module):
|
|
||||||
def __init__(self, dim, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
patch_size = 14
|
|
||||||
image_size = 518
|
|
||||||
|
|
||||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
|
||||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
|
||||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
|
||||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
|
||||||
x = self.patch_embeddings(pixel_values)
|
|
||||||
# TODO: mask_token?
|
|
||||||
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
|
|
||||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Dinov2Model(torch.nn.Module):
|
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
|
||||||
super().__init__()
|
|
||||||
num_layers = config_dict["num_hidden_layers"]
|
|
||||||
dim = config_dict["hidden_size"]
|
|
||||||
heads = config_dict["num_attention_heads"]
|
|
||||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
|
||||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
|
||||||
|
|
||||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
|
||||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
|
||||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
|
||||||
x = self.embeddings(pixel_values)
|
|
||||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
|
||||||
x = self.layernorm(x)
|
|
||||||
pooled_output = x[:, 0, :]
|
|
||||||
return x, i, pooled_output, None
|
|
||||||
@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"attention_probs_dropout_prob": 0.0,
|
|
||||||
"drop_path_rate": 0.0,
|
|
||||||
"hidden_act": "gelu",
|
|
||||||
"hidden_dropout_prob": 0.0,
|
|
||||||
"hidden_size": 1536,
|
|
||||||
"image_size": 518,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"layer_norm_eps": 1e-06,
|
|
||||||
"layerscale_value": 1.0,
|
|
||||||
"mlp_ratio": 4,
|
|
||||||
"model_type": "dinov2",
|
|
||||||
"num_attention_heads": 24,
|
|
||||||
"num_channels": 3,
|
|
||||||
"num_hidden_layers": 40,
|
|
||||||
"patch_size": 14,
|
|
||||||
"qkv_bias": true,
|
|
||||||
"use_swiglu_ffn": true,
|
|
||||||
"image_mean": [0.485, 0.456, 0.406],
|
|
||||||
"image_std": [0.229, 0.224, 0.225]
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user