diff --git a/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat similarity index 100% rename from .ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat rename to .ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index f7d30a9a4..a046ff9ea 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -12,7 +12,7 @@ on: description: 'CUDA version' required: true type: string - default: "126" + default: "128" python_minor: description: 'Python minor version' required: true @@ -22,7 +22,7 @@ on: description: 'Python patch version' required: true type: string - default: "9" + default: "10" jobs: @@ -36,7 +36,7 @@ jobs: - uses: actions/checkout@v4 with: ref: ${{ inputs.git_tag }} - fetch-depth: 0 + fetch-depth: 150 persist-credentials: false - uses: actions/cache/restore@v4 id: cache @@ -70,7 +70,7 @@ jobs: cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd - cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ + cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/ mkdir ComfyUI_windows_portable mv python_embeded ComfyUI_windows_portable @@ -85,12 +85,14 @@ jobs: cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z cd ComfyUI_windows_portable python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu + python_embeded/python.exe -s ./update/update.py ComfyUI/ + ls - name: Upload binaries to release diff --git a/.github/workflows/update-api-stubs.yml b/.github/workflows/update-api-stubs.yml new file mode 100644 index 000000000..2ae99b673 --- /dev/null +++ b/.github/workflows/update-api-stubs.yml @@ -0,0 +1,47 @@ +name: Generate Pydantic Stubs from api.comfy.org + +on: + schedule: + - cron: '0 0 * * 1' + workflow_dispatch: + +jobs: + generate-models: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install 'datamodel-code-generator[http]' + + - name: Generate API models + run: | + datamodel-codegen --use-subclass-enum --url https://api.comfy.org/openapi --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel + + - name: Check for changes + id: git-check + run: | + git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT + + - name: Create Pull Request + if: steps.git-check.outputs.changes == 'true' + uses: peter-evans/create-pull-request@v5 + with: + commit-message: 'chore: update API models from OpenAPI spec' + title: 'Update API models from api.comfy.org' + body: | + This PR updates the API models based on the latest api.comfy.org OpenAPI specification. + + Generated automatically by the a Github workflow. + branch: update-api-stubs + delete-branch: true + base: main diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index 7a8ec5782..dfdb96d50 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -17,7 +17,7 @@ on: description: 'cuda version' required: true type: string - default: "126" + default: "128" python_minor: description: 'python minor version' @@ -29,7 +29,7 @@ on: description: 'python patch version' required: true type: string - default: "9" + default: "10" # push: # branches: # - master diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 24599249a..eb5ed9c91 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -56,7 +56,7 @@ jobs: cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd - cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ + cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/ mkdir ComfyUI_windows_portable_nightly_pytorch mv python_embeded ComfyUI_windows_portable_nightly_pytorch diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 416544f71..3926a65f3 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -7,7 +7,7 @@ on: description: 'cuda version' required: true type: string - default: "126" + default: "128" python_minor: description: 'python minor version' @@ -19,7 +19,7 @@ on: description: 'python patch version' required: true type: string - default: "9" + default: "10" # push: # branches: # - master @@ -50,7 +50,7 @@ jobs: - uses: actions/checkout@v4 with: - fetch-depth: 0 + fetch-depth: 150 persist-credentials: false - shell: bash run: | @@ -67,7 +67,7 @@ jobs: cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd - cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ + cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/ mkdir ComfyUI_windows_portable mv python_embeded ComfyUI_windows_portable @@ -82,12 +82,14 @@ jobs: cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z cd ComfyUI_windows_portable python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu + python_embeded/python.exe -s ./update/update.py ComfyUI/ + ls - name: Upload binaries to release diff --git a/CODEOWNERS b/CODEOWNERS index eeec358de..013ea8622 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -5,19 +5,20 @@ # Inlined the team members for now. # Maintainers -*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink +*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne # Python web server -/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata -/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata -/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata +/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne +/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne +/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne -# Extra nodes -/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink +# Node developers +/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne +/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne diff --git a/README.md b/README.md index a807ea9d6..0f39cfce2 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,6 @@ Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, ## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/) See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/). - ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Image Models @@ -62,6 +61,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/) - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) + - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) @@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) +- 3D Models + - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2) - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. @@ -96,6 +98,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) +## Release Process + +ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories: + +1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** + - Releases a new stable version (e.g., v0.7.0) + - Serves as the foundation for the desktop release + +2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)** + - Builds a new release using the latest stable core version + - Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0) + +3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)** + - Weekly frontend updates are merged into the core repository + - Features are frozen for the upcoming core release + - Development continues for the next release cycle + ## Shortcuts | Keybind | Explanation | @@ -146,8 +165,6 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you If you have trouble extracting it, right click the file -> properties -> unblock -If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643) - #### How do I share models between another UI and ComfyUI? See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. @@ -213,9 +230,9 @@ Additional discussion and help can be found [here](https://github.com/comfyanony Nvidia users should install stable pytorch using this command: -```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126``` +```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128``` -This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements. +This is the command to install pytorch nightly instead which might have performance improvements. ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128``` diff --git a/app/app_settings.py b/app/app_settings.py index a545df92e..c7ac73bf6 100644 --- a/app/app_settings.py +++ b/app/app_settings.py @@ -9,8 +9,14 @@ class AppSettings(): self.user_manager = user_manager def get_settings(self, request): - file = self.user_manager.get_request_user_filepath( - request, "comfy.settings.json") + try: + file = self.user_manager.get_request_user_filepath( + request, + "comfy.settings.json" + ) + except KeyError as e: + logging.error("User settings not found.") + raise web.HTTPUnauthorized() from e if os.path.isfile(file): try: with open(file) as f: diff --git a/app/custom_node_manager.py b/app/custom_node_manager.py index 42b0d75ba..27d85d9ce 100644 --- a/app/custom_node_manager.py +++ b/app/custom_node_manager.py @@ -93,16 +93,20 @@ class CustomNodeManager: def add_routes(self, routes, webapp, loadedModules): + example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"] + @routes.get("/workflow_templates") async def get_workflow_templates(request): """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted.""" - files = [ - file - for folder in folder_paths.get_folder_paths("custom_nodes") - for file in glob.glob( - os.path.join(folder, "*/example_workflows/*.json") - ) - ] + + files = [] + + for folder in folder_paths.get_folder_paths("custom_nodes"): + for folder_name in example_workflow_folder_names: + pattern = os.path.join(folder, f"*/{folder_name}/*.json") + matched_files = glob.glob(pattern) + files.extend(matched_files) + workflow_templates_dict = ( {} ) # custom_nodes folder name -> example workflow names @@ -118,15 +122,22 @@ class CustomNodeManager: # Serve workflow templates from custom nodes. for module_name, module_dir in loadedModules: - workflows_dir = os.path.join(module_dir, "example_workflows") - if os.path.exists(workflows_dir): - webapp.add_routes( - [ - web.static( - "/api/workflow_templates/" + module_name, workflows_dir - ) - ] - ) + for folder_name in example_workflow_folder_names: + workflows_dir = os.path.join(module_dir, folder_name) + + if os.path.exists(workflows_dir): + if folder_name != "example_workflows": + logging.warning( + "WARNING: Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'", + folder_name, module_name) + + webapp.add_routes( + [ + web.static( + "/api/workflow_templates/" + module_name, workflows_dir + ) + ] + ) @routes.get("/i18n") async def get_i18n(request): diff --git a/app/frontend_management.py b/app/frontend_management.py index 4b7dfbb98..7b7923b79 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -17,17 +17,26 @@ import requests from typing_extensions import NotRequired from comfy.cli_args import DEFAULT_VERSION_STRING +import app.logger # The path to the requirements.txt file req_path = Path(__file__).parents[1] / "requirements.txt" + def frontend_install_warning_message(): """The warning message to display when the frontend version is not up to date.""" extra = "" if sys.flags.no_user_site: extra = "-s " - return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem" + return f""" +Please install the updated requirements.txt file by running: +{sys.executable} {extra}-m pip install -r {req_path} + +This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. + +If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem +""".strip() def check_frontend_version(): @@ -42,7 +51,17 @@ def check_frontend_version(): with open(req_path, "r", encoding="utf-8") as f: required_frontend = parse_version(f.readline().split("=")[-1]) if frontend_version < required_frontend: - logging.warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message())) + app.logger.log_startup_warning( + f""" +________________________________________________________________________ +WARNING WARNING WARNING WARNING WARNING + +Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}. + +{frontend_install_warning_message()} +________________________________________________________________________ +""".strip() + ) else: logging.info("ComfyUI frontend version: {}".format(frontend_version_str)) except Exception as e: @@ -149,11 +168,43 @@ class FrontendManager: def default_frontend_path(cls) -> str: try: import comfyui_frontend_package + return str(importlib.resources.files(comfyui_frontend_package) / "static") except ImportError: - logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n") + logging.error( + f""" +********** ERROR *********** + +comfyui-frontend-package is not installed. + +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) sys.exit(-1) + @classmethod + def templates_path(cls) -> str: + try: + import comfyui_workflow_templates + + return str( + importlib.resources.files(comfyui_workflow_templates) / "templates" + ) + except ImportError: + logging.error( + f""" +********** ERROR *********** + +comfyui-workflow-templates is not installed. + +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) + @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ @@ -174,7 +225,9 @@ class FrontendManager: return match_result.group(1), match_result.group(2), match_result.group(3) @classmethod - def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: + def init_frontend_unsafe( + cls, version_string: str, provider: Optional[FrontEndProvider] = None + ) -> str: """ Initializes the frontend for the specified version. @@ -196,12 +249,20 @@ class FrontendManager: repo_owner, repo_name, version = cls.parse_version_string(version_string) if version.startswith("v"): - expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v")) + expected_path = str( + Path(cls.CUSTOM_FRONTENDS_ROOT) + / f"{repo_owner}_{repo_name}" + / version.lstrip("v") + ) if os.path.exists(expected_path): - logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}") + logging.info( + f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}" + ) return expected_path - logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...") + logging.info( + f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..." + ) provider = provider or FrontEndProvider(repo_owner, repo_name) release = provider.get_release(version) diff --git a/app/logger.py b/app/logger.py index 9e9f84ccf..3d26d98fe 100644 --- a/app/logger.py +++ b/app/logger.py @@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool logger.addHandler(stdout_handler) logger.addHandler(stream_handler) + + +STARTUP_WARNINGS = [] + + +def log_startup_warning(msg): + logging.warning(msg) + STARTUP_WARNINGS.append(msg) + + +def print_startup_warnings(): + for s in STARTUP_WARNINGS: + logging.warning(s) + STARTUP_WARNINGS.clear() diff --git a/app/user_manager.py b/app/user_manager.py index e7381e621..d31da5b9b 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -197,6 +197,112 @@ class UserManager(): return web.json_response(results) + @routes.get("/v2/userdata") + async def list_userdata_v2(request): + """ + List files and directories in a user's data directory. + + This endpoint provides a structured listing of contents within a specified + subdirectory of the user's data storage. + + Query Parameters: + - path (optional): The relative path within the user's data directory + to list. Defaults to the root (''). + + Returns: + - 400: If the requested path is invalid, outside the user's data directory, or is not a directory. + - 404: If the requested path does not exist. + - 403: If the user is invalid. + - 500: If there is an error reading the directory contents. + - 200: JSON response containing a list of file and directory objects. + Each object includes: + - name: The name of the file or directory. + - type: 'file' or 'directory'. + - path: The relative path from the user's data root. + - size (for files): The size in bytes. + - modified (for files): The last modified timestamp (Unix epoch). + """ + requested_rel_path = request.rel_url.query.get('path', '') + + # URL-decode the path parameter + try: + requested_rel_path = parse.unquote(requested_rel_path) + except Exception as e: + logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}") + return web.Response(status=400, text="Invalid characters in path parameter") + + + # Check user validity and get the absolute path for the requested directory + try: + base_user_path = self.get_request_user_filepath(request, None, create_dir=False) + + if requested_rel_path: + target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False) + else: + target_abs_path = base_user_path + + except KeyError as e: + # Invalid user detected by get_request_user_id inside get_request_user_filepath + logging.warning(f"Access denied for user: {e}") + return web.Response(status=403, text="Invalid user specified in request") + + + if not target_abs_path: + # Path traversal or other issue detected by get_request_user_filepath + return web.Response(status=400, text="Invalid path requested") + + # Handle cases where the user directory or target path doesn't exist + if not os.path.exists(target_abs_path): + # Check if it's the base user directory that's missing (new user case) + if target_abs_path == base_user_path: + # It's okay if the base user directory doesn't exist yet, return empty list + return web.json_response([]) + else: + # A specific subdirectory was requested but doesn't exist + return web.Response(status=404, text="Requested path not found") + + if not os.path.isdir(target_abs_path): + return web.Response(status=400, text="Requested path is not a directory") + + results = [] + try: + for root, dirs, files in os.walk(target_abs_path, topdown=True): + # Process directories + for dir_name in dirs: + dir_path = os.path.join(root, dir_name) + rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/') + results.append({ + "name": dir_name, + "path": rel_path, + "type": "directory" + }) + + # Process files + for file_name in files: + file_path = os.path.join(root, file_name) + rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/') + entry_info = { + "name": file_name, + "path": rel_path, + "type": "file" + } + try: + stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk + entry_info["size"] = stats.st_size + entry_info["modified"] = stats.st_mtime + except OSError as stat_error: + logging.warning(f"Could not stat file {file_path}: {stat_error}") + pass # Include file with available info + results.append(entry_info) + except OSError as e: + logging.error(f"Error listing directory {target_abs_path}: {e}") + return web.Response(status=500, text="Error reading directory contents") + + # Sort results alphabetically, directories first then files + results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower())) + + return web.json_response(results) + def get_user_data_path(request, check_exists = False, param = "file"): file = request.match_info.get(param, None) if not file: diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a864205be..f89a7aab4 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -66,6 +66,7 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16") fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") +fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.") fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") @@ -79,6 +80,7 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.") fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") +fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") @@ -100,12 +102,14 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi cache_group = parser.add_mutually_exclusive_group() cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") +cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.") +attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") @@ -124,6 +128,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") +parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") @@ -133,8 +138,9 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u class PerformanceFeature(enum.Enum): Fp16Accumulation = "fp16_accumulation" Fp8MatrixMultiplication = "fp8_matrix_mult" + CublasOps = "cublas_ops" -parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult") +parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 297b3bca3..00aab9164 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -9,6 +9,7 @@ import comfy.model_patcher import comfy.model_management import comfy.utils import comfy.clip_model +import comfy.image_encoders.dino2 class Output: def __getitem__(self, key): @@ -17,6 +18,7 @@ class Output: setattr(self, key, item) def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): + image = image[:, :, :, :3] if image.shape[3] > 3 else image mean = torch.tensor(mean, device=image.device, dtype=image.dtype) std = torch.tensor(std, device=image.device, dtype=image.dtype) image = image.movedim(-1, 1) @@ -34,6 +36,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) +IMAGE_ENCODERS = { + "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, + "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, + "dinov2": comfy.image_encoders.dino2.Dinov2Model, +} + class ClipVisionModel(): def __init__(self, json_config): with open(json_config) as f: @@ -42,10 +50,11 @@ class ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) + model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model")) self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) - self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast) + self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) @@ -102,15 +111,21 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: + embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0] if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") - elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: + if embed_shape == 729: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif embed_shape == 1024: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") + elif embed_shape == 577: if "multi_modal_projector.linear_1.bias" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") + elif "embeddings.patch_embeddings.projection.weight" in sd: + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") else: return None diff --git a/comfy/clip_vision_siglip_512.json b/comfy/clip_vision_siglip_512.json new file mode 100644 index 000000000..7fb93ce15 --- /dev/null +++ b/comfy/clip_vision_siglip_512.json @@ -0,0 +1,13 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 512, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 16, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 4967de716..2ffc9c021 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -1,7 +1,8 @@ """Comfy-specific type hinting""" from __future__ import annotations -from typing import Literal, TypedDict +from typing import Literal, TypedDict, Optional +from typing_extensions import NotRequired from abc import ABC, abstractmethod from enum import Enum @@ -26,6 +27,7 @@ class IO(StrEnum): BOOLEAN = "BOOLEAN" INT = "INT" FLOAT = "FLOAT" + COMBO = "COMBO" CONDITIONING = "CONDITIONING" SAMPLER = "SAMPLER" SIGMAS = "SIGMAS" @@ -46,6 +48,7 @@ class IO(StrEnum): FACE_ANALYSIS = "FACE_ANALYSIS" BBOX = "BBOX" SEGS = "SEGS" + VIDEO = "VIDEO" ANY = "*" """Always matches any type, but at a price. @@ -66,6 +69,7 @@ class IO(StrEnum): b = frozenset(value.split(",")) return not (b.issubset(a) or a.issubset(b)) + class RemoteInputOptions(TypedDict): route: str """The route to the remote source.""" @@ -80,6 +84,14 @@ class RemoteInputOptions(TypedDict): refresh: int """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed.""" + +class MultiSelectOptions(TypedDict): + placeholder: NotRequired[str] + """The placeholder text to display in the multi-select widget when no items are selected.""" + chip: NotRequired[bool] + """Specifies whether to use chips instead of comma separated values for the multi-select widget.""" + + class InputTypeOptions(TypedDict): """Provides type hinting for the return type of the INPUT_TYPES node function. @@ -88,68 +100,94 @@ class InputTypeOptions(TypedDict): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes """ - default: bool | str | float | int | list | tuple + default: NotRequired[bool | str | float | int | list | tuple] """The default value of the widget""" - defaultInput: bool - """Defaults to an input slot rather than a widget""" - forceInput: bool - """`defaultInput` and also don't allow converting to a widget""" - lazy: bool + defaultInput: NotRequired[bool] + """@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist. + - defaultInput on required inputs should be dropped. + - defaultInput on optional inputs should be replaced with forceInput. + Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364 + """ + forceInput: NotRequired[bool] + """Forces the input to be an input slot rather than a widget even a widget is available for the input type.""" + lazy: NotRequired[bool] """Declares that this input uses lazy evaluation""" - rawLink: bool + rawLink: NotRequired[bool] """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", ]`). Designed for node expansion.""" - tooltip: str + tooltip: NotRequired[str] """Tooltip for the input (or widget), shown on pointer hover""" + socketless: NotRequired[bool] + """All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created. + Available from frontend v1.17.5 + Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548 + """ + widgetType: NotRequired[str] + """Specifies a type to be used for widget initialization if different from the input type. + Available from frontend v1.18.0 + https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550""" # class InputTypeNumber(InputTypeOptions): # default: float | int - min: float + min: NotRequired[float] """The minimum value of a number (``FLOAT`` | ``INT``)""" - max: float + max: NotRequired[float] """The maximum value of a number (``FLOAT`` | ``INT``)""" - step: float + step: NotRequired[float] """The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)""" - round: float + round: NotRequired[float] """Floats are rounded by this value (``FLOAT``)""" # class InputTypeBoolean(InputTypeOptions): # default: bool - label_on: str + label_on: NotRequired[str] """The label to use in the UI when the bool is True (``BOOLEAN``)""" - label_off: str + label_off: NotRequired[str] """The label to use in the UI when the bool is False (``BOOLEAN``)""" # class InputTypeString(InputTypeOptions): # default: str - multiline: bool + multiline: NotRequired[bool] """Use a multiline text box (``STRING``)""" - placeholder: str + placeholder: NotRequired[str] """Placeholder text to display in the UI when empty (``STRING``)""" # Deprecated: # defaultVal: str - dynamicPrompts: bool + dynamicPrompts: NotRequired[bool] """Causes the front-end to evaluate dynamic prompts (``STRING``)""" # class InputTypeCombo(InputTypeOptions): - image_upload: bool + image_upload: NotRequired[bool] """Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`.""" - image_folder: Literal["input", "output", "temp"] + image_folder: NotRequired[Literal["input", "output", "temp"]] """Specifies which folder to get preview images from if the input has the ``image_upload`` flag. """ - remote: RemoteInputOptions - """Specifies the configuration for a remote input.""" - control_after_generate: bool + remote: NotRequired[RemoteInputOptions] + """Specifies the configuration for a remote input. + Available after ComfyUI frontend v1.9.7 + https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422""" + control_after_generate: NotRequired[bool] """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types.""" + options: NotRequired[list[str | int | float]] + """COMBO type only. Specifies the selectable options for the combo widget. + Prefer: + ["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}] + Over: + [["Option 1", "Option 2", "Option 3"]] + """ + multi_select: NotRequired[MultiSelectOptions] + """COMBO type only. Specifies the configuration for a multi-select widget. + Available after ComfyUI frontend v1.13.4 + https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987""" class HiddenInputTypeDict(TypedDict): """Provides type hinting for the hidden entry of node INPUT_TYPES.""" - node_id: Literal["UNIQUE_ID"] + node_id: NotRequired[Literal["UNIQUE_ID"]] """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" - unique_id: Literal["UNIQUE_ID"] + unique_id: NotRequired[Literal["UNIQUE_ID"]] """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" - prompt: Literal["PROMPT"] + prompt: NotRequired[Literal["PROMPT"]] """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" - extra_pnginfo: Literal["EXTRA_PNGINFO"] + extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]] """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" - dynprompt: Literal["DYNPROMPT"] + dynprompt: NotRequired[Literal["DYNPROMPT"]] """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" @@ -159,11 +197,11 @@ class InputTypeDict(TypedDict): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs """ - required: dict[str, tuple[IO, InputTypeOptions]] + required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]] """Describes all inputs that must be connected for the node to execute.""" - optional: dict[str, tuple[IO, InputTypeOptions]] + optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]] """Describes inputs which do not need to be connected.""" - hidden: HiddenInputTypeDict + hidden: NotRequired[HiddenInputTypeDict] """Offers advanced functionality and server-client communication. Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs @@ -196,6 +234,8 @@ class ComfyNodeABC(ABC): """Flags a node as experimental, informing users that it may change or not work as expected.""" DEPRECATED: bool """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" + API_NODE: Optional[bool] + """Flags a node as an API node.""" @classmethod @abstractmethod @@ -234,7 +274,7 @@ class ComfyNodeABC(ABC): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing """ - OUTPUT_IS_LIST: tuple[bool] + OUTPUT_IS_LIST: tuple[bool, ...] """A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items. Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list. @@ -253,7 +293,7 @@ class ComfyNodeABC(ABC): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing """ - RETURN_TYPES: tuple[IO] + RETURN_TYPES: tuple[IO, ...] """A tuple representing the outputs of this node. Usage:: @@ -262,12 +302,12 @@ class ComfyNodeABC(ABC): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types """ - RETURN_NAMES: tuple[str] + RETURN_NAMES: tuple[str, ...] """The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")`` Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names """ - OUTPUT_TOOLTIPS: tuple[str] + OUTPUT_TOOLTIPS: tuple[str, ...] """A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`.""" FUNCTION: str """The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"` diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ceb24c852..11483e21d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -736,6 +736,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return control def load_controlnet(ckpt_path, model=None, model_options={}): + model_options = model_options.copy() if "global_average_pooling" not in model_options: filename = os.path.splitext(ckpt_path)[0] if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py new file mode 100644 index 000000000..976f98c65 --- /dev/null +++ b/comfy/image_encoders/dino2.py @@ -0,0 +1,141 @@ +import torch +from comfy.text_encoders.bert import BertAttention +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention_for_device + + +class Dino2AttentionOutput(torch.nn.Module): + def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations): + super().__init__() + self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device) + + def forward(self, x): + return self.dense(x) + + +class Dino2AttentionBlock(torch.nn.Module): + def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): + super().__init__() + self.attention = BertAttention(embed_dim, heads, dtype, device, operations) + self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations) + + def forward(self, x, mask, optimized_attention): + return self.output(self.attention(x, mask, optimized_attention)) + + +class LayerScale(torch.nn.Module): + def __init__(self, dim, dtype, device, operations): + super().__init__() + self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + + def forward(self, x): + return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype) + + +class SwiGLUFFN(torch.nn.Module): + def __init__(self, dim, dtype, device, operations): + super().__init__() + in_features = out_features = dim + hidden_features = int(dim * 4) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype) + self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype) + + def forward(self, x): + x = self.weights_in(x) + x1, x2 = x.chunk(2, dim=-1) + x = torch.nn.functional.silu(x1) * x2 + return self.weights_out(x) + + +class Dino2Block(torch.nn.Module): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations): + super().__init__() + self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) + self.layer_scale1 = LayerScale(dim, dtype, device, operations) + self.layer_scale2 = LayerScale(dim, dtype, device, operations) + self.mlp = SwiGLUFFN(dim, dtype, device, operations) + self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + + def forward(self, x, optimized_attention): + x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention)) + x = x + self.layer_scale2(self.mlp(self.norm2(x))) + return x + + +class Dino2Encoder(torch.nn.Module): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations): + super().__init__() + self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)]) + + def forward(self, x, intermediate_output=None): + optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) + + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layer) + intermediate_output + + intermediate = None + for i, l in enumerate(self.layer): + x = l(x, optimized_attention) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class Dino2PatchEmbeddings(torch.nn.Module): + def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None): + super().__init__() + self.projection = operations.Conv2d( + in_channels=num_channels, + out_channels=dim, + kernel_size=patch_size, + stride=patch_size, + bias=True, + dtype=dtype, + device=device + ) + + def forward(self, pixel_values): + return self.projection(pixel_values).flatten(2).transpose(1, 2) + + +class Dino2Embeddings(torch.nn.Module): + def __init__(self, dim, dtype, device, operations): + super().__init__() + patch_size = 14 + image_size = 518 + + self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations) + self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device)) + self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) + self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + + def forward(self, pixel_values): + x = self.patch_embeddings(pixel_values) + # TODO: mask_token? + x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1) + x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) + return x + + +class Dinov2Model(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + num_layers = config_dict["num_hidden_layers"] + dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + layer_norm_eps = config_dict["layer_norm_eps"] + + self.embeddings = Dino2Embeddings(dim, dtype, device, operations) + self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations) + self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + + def forward(self, pixel_values, attention_mask=None, intermediate_output=None): + x = self.embeddings(pixel_values) + x, i = self.encoder(x, intermediate_output=intermediate_output) + x = self.layernorm(x) + pooled_output = x[:, 0, :] + return x, i, pooled_output, None diff --git a/comfy/image_encoders/dino2_giant.json b/comfy/image_encoders/dino2_giant.json new file mode 100644 index 000000000..f6076a4dc --- /dev/null +++ b/comfy/image_encoders/dino2_giant.json @@ -0,0 +1,21 @@ +{ + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 1536, + "image_size": 518, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "model_type": "dinov2", + "num_attention_heads": 24, + "num_channels": 3, + "num_hidden_layers": 40, + "patch_size": 14, + "qkv_bias": true, + "use_swiglu_ffn": true, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] +} diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 456679989..77ef748e8 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N if len(sigmas) <= 1: return x + extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() seed = extra_args.get("seed", None) noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler - extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() @@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if solver_type not in {'heun', 'midpoint'}: raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler - extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) old_denoised = None @@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if len(sigmas) <= 1: return x + extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler - extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) denoised_1, denoised_2 = None, None @@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if len(sigmas) <= 1: return x - + extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) @@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: return x - + extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) @@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): if len(sigmas) <= 1: return x - + extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) @@ -1345,24 +1345,202 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True) @torch.no_grad() -def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.): +def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False): """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) old_d = None + uncond_denoised = None + def post_cfg_function(args): + nonlocal uncond_denoised + uncond_denoised = args["uncond_denoised"] + return args["denoised"] + + if cfg_pp: + model_options = extra_args.get("model_options", {}).copy() + extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) - d = to_d(x, sigmas[i], denoised) + if cfg_pp: + d = to_d(x, sigmas[i], uncond_denoised) + else: + d = to_d(x, sigmas[i], denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) dt = sigmas[i + 1] - sigmas[i] if i == 0: # Euler method - x = x + d * dt + if cfg_pp: + x = denoised + d * sigmas[i + 1] + else: + x = x + d * dt else: # Gradient estimation - d_bar = ge_gamma * d + (1 - ge_gamma) * old_d - x = x + d_bar * dt + if cfg_pp: + d_bar = (ge_gamma - 1) * (d - old_d) + x = denoised + d * sigmas[i + 1] + d_bar * dt + else: + d_bar = ge_gamma * d + (1 - ge_gamma) * old_d + x = x + d_bar * dt old_d = d return x + +@torch.no_grad() +def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.): + return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True) + +@torch.no_grad() +def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3): + """ + Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169. + Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py. + """ + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + def default_noise_scaler(sigma): + return sigma * ((sigma ** 0.3).exp() + 10.0) + noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler + num_integration_points = 200.0 + point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device) + + old_denoised = None + old_denoised_d = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + stage_used = min(max_stage, i + 1) + if sigmas[i + 1] == 0: + x = denoised + elif stage_used == 1: + r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i]) + x = r * x + (1 - r) * denoised + else: + r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i]) + x = r * x + (1 - r) * denoised + + dt = sigmas[i + 1] - sigmas[i] + sigma_step_size = -dt / num_integration_points + sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size + scaled_pos = noise_scaler(sigma_pos) + + # Stage 2 + s = torch.sum(1 / scaled_pos) * sigma_step_size + denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1]) + x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d + + if stage_used >= 3: + # Stage 3 + s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size + denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2) + x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u + old_denoised_d = denoised_d + + if s_noise != 0 and sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0) + old_denoised = denoised + return x + +@torch.no_grad() +def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): + ''' + SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2 + Arxiv: https://arxiv.org/abs/2305.14267 + ''' + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + inject_noise = eta > 0 and s_noise > 0 + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + x = denoised + else: + t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() + h = t_next - t + h_eta = h * (eta + 1) + s = t + r * h + fac = 1 / (2 * r) + sigma_s = s.neg().exp() + + coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() + if inject_noise: + noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() + noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() + noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1]) + + # Step 1 + x_2 = (coeff_1 + 1) * x - coeff_1 * denoised + if inject_noise: + x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise + denoised_2 = model(x_2, sigma_s * s_in, **extra_args) + + # Step 2 + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = (coeff_2 + 1) * x - coeff_2 * denoised_d + if inject_noise: + x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + return x + +@torch.no_grad() +def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): + ''' + SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3 + Arxiv: https://arxiv.org/abs/2305.14267 + ''' + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + inject_noise = eta > 0 and s_noise > 0 + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + x = denoised + else: + t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() + h = t_next - t + h_eta = h * (eta + 1) + s_1 = t + r_1 * h + s_2 = t + r_2 * h + sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp() + + coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() + if inject_noise: + noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() + noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt() + noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() + noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + + # Step 1 + x_2 = (coeff_1 + 1) * x - coeff_1 * denoised + if inject_noise: + x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + + # Step 2 + x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) + if inject_noise: + x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) + + # Step 3 + x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) + if inject_noise: + x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise + return x diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 622c1df54..556c39512 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -456,3 +456,13 @@ class Wan21(LatentFormat): latents_mean = self.latents_mean.to(latent.device, latent.dtype) latents_std = self.latents_std.to(latent.device, latent.dtype) return latent * latents_std / self.scale_factor + latents_mean + +class Hunyuan3Dv2(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + scale_factor = 0.9990943042622529 + +class Hunyuan3Dv2mini(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + scale_factor = 1.0188137142395404 diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py new file mode 100644 index 000000000..35da91ee2 --- /dev/null +++ b/comfy/ldm/chroma/layers.py @@ -0,0 +1,183 @@ +import torch +from torch import Tensor, nn + +from comfy.ldm.flux.math import attention +from comfy.ldm.flux.layers import ( + MLPEmbedder, + RMSNorm, + QKNorm, + SelfAttention, + ModulationOut, +) + + + +class ChromaModulationOut(ModulationOut): + @classmethod + def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut: + return cls( + shift=tensor[:, offset : offset + 1, :], + scale=tensor[:, offset + 1 : offset + 2, :], + gate=tensor[:, offset + 2 : offset + 3, :], + ) + + + + +class Approximator(nn.Module): + def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None): + super().__init__() + self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) + self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) + self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def forward(self, x: Tensor) -> Tensor: + x = self.in_proj(x) + + for layer, norms in zip(self.layers, self.norms): + x = x + layer(norms(x)) + + x = self.out_proj(x) + + return x + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + + self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + + self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + + self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.flipped_img_txt = flipped_img_txt + + def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None): + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + attn = attention(torch.cat((txt_q, img_q), dim=2), + torch.cat((txt_k, img_k), dim=2), + torch.cat((txt_v, img_v), dim=2), + pe=pe, mask=attn_mask) + + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt += txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + + if txt.dtype == torch.float16: + txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) + + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float = None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + # proj and mlp_out + self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) + + self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) + + self.hidden_size = hidden_size + self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.mlp_act = nn.GELU(approximate="tanh") + + def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: + mod = vec + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe, mask=attn_mask) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + x += mod.gate * output + if x.dtype == torch.float16: + x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = vec + shift = shift.squeeze(1) + scale = scale.squeeze(1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py new file mode 100644 index 000000000..636748fc5 --- /dev/null +++ b/comfy/ldm/chroma/model.py @@ -0,0 +1,271 @@ +#Original code can be found on: https://github.com/black-forest-labs/flux + +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange, repeat +import comfy.ldm.common_dit + +from comfy.ldm.flux.layers import ( + EmbedND, + timestep_embedding, +) + +from .layers import ( + DoubleStreamBlock, + LastLayer, + SingleStreamBlock, + Approximator, + ChromaModulationOut, +) + + +@dataclass +class ChromaParams: + in_channels: int + out_channels: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list + theta: int + patch_size: int + qkv_bias: bool + in_dim: int + out_dim: int + hidden_dim: int + n_layers: int + + + + +class Chroma(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + params = ChromaParams(**kwargs) + self.params = params + self.patch_size = params.patch_size + self.in_channels = params.in_channels + self.out_channels = params.out_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.in_dim = params.in_dim + self.out_dim = params.out_dim + self.hidden_dim = params.hidden_dim + self.n_layers = params.n_layers + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + # set as nn identity for now, will overwrite it later. + self.distilled_guidance_layer = Approximator( + in_dim=self.in_dim, + hidden_dim=self.hidden_dim, + out_dim=self.out_dim, + n_layers=self.n_layers, + dtype=dtype, device=device, operations=operations + ) + + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + for _ in range(params.depth_single_blocks) + ] + ) + + if final_layer: + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + + self.skip_mmdit = [] + self.skip_dit = [] + self.lite = False + + def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0): + # This function slices up the modulations tensor which has the following layout: + # single : num_single_blocks * 3 elements + # double_img : num_double_blocks * 6 elements + # double_txt : num_double_blocks * 6 elements + # final : 2 elements + if block_type == "final": + return (tensor[:, -2:-1, :], tensor[:, -1:, :]) + single_block_count = self.params.depth_single_blocks + double_block_count = self.params.depth + offset = 3 * idx + if block_type == "single": + return ChromaModulationOut.from_offset(tensor, offset) + # Double block modulations are 6 elements so we double 3 * idx. + offset *= 2 + if block_type in {"double_img", "double_txt"}: + # Advance past the single block modulations. + offset += 3 * single_block_count + if block_type == "double_txt": + # Advance past the double block img modulations. + offset += 6 * double_block_count + return ( + ChromaModulationOut.from_offset(tensor, offset), + ChromaModulationOut.from_offset(tensor, offset + 3), + ) + raise ValueError("Bad block_type") + + + def forward_orig( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + guidance: Tensor = None, + control = None, + transformer_options={}, + attn_mask: Tensor = None, + ) -> Tensor: + patches_replace = transformer_options.get("patches_replace", {}) + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + + # distilled vector guidance + mod_index_length = 344 + distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype) + # guidance = guidance * + distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype) + + # get all modulation index + modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype) + # and we need to broadcast timestep and guidance along too + timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype) + + mod_vectors = self.distilled_guidance_layer(input_vec) + + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.double_blocks): + if i not in self.skip_mmdit: + double_mod = ( + self.get_modulations(mod_vectors, "double_img", idx=i), + self.get_modulations(mod_vectors, "double_txt", idx=i), + ) + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + "vec": double_mod, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=double_mod, + pe=pe, + attn_mask=attn_mask) + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + img += add + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + if i not in self.skip_dit: + single_mod = self.get_modulations(mod_vectors, "single", idx=i) + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("single_block", i)]({"img": img, + "vec": single_mod, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) + + if control is not None: # Controlnet + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img[:, txt.shape[1] :, ...] += add + + img = img[:, txt.shape[1] :, ...] + final_mod = self.get_modulations(mod_vectors, "final") + img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) + return img + + def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): + bs, c, h, w = x.shape + patch_size = 2 + x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index e0f3057f7..f7f56b72c 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -1,5 +1,6 @@ import torch -import comfy.ops +import comfy.rmsnorm + def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()): @@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): return torch.nn.functional.pad(img, pad, mode=padding_mode) -try: - rms_norm_torch = torch.nn.functional.rms_norm -except: - rms_norm_torch = None -def rms_norm(x, weight=None, eps=1e-6): - if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): - if weight is None: - return rms_norm_torch(x, (x.shape[-1],), eps=eps) - else: - return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) - else: - r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) - if weight is None: - return r - else: - return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device) +rms_norm = comfy.rmsnorm.rms_norm diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 36b67931c..3e0978176 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: q_shape = q.shape k_shape = k.shape - q = q.float().reshape(*q.shape[:-1], -1, 1, 2) - k = k.float().reshape(*k.shape[:-1], -1, 1, 2) - q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) - k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) + if pe is not None: + q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) + k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) + q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) + k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) @@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index cc34f7585..ef4ba4106 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -115,8 +115,11 @@ class Flux(nn.Module): vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) txt = self.txt_in(txt) - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) + if img_ids is not None: + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + else: + pe = None blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.double_blocks): diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py new file mode 100644 index 000000000..0305747bf --- /dev/null +++ b/comfy/ldm/hidream/model.py @@ -0,0 +1,802 @@ +from typing import Optional, Tuple, List + +import torch +import torch.nn as nn +import einops +from einops import repeat + +from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps +import torch.nn.functional as F + +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.flux.layers import LastLayer + +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management +import comfy.ldm.common_dit + + +# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +class EmbedND(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(2) + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size=2, + in_channels=4, + out_channels=1024, + dtype=None, device=None, operations=None + ): + super().__init__() + self.patch_size = patch_size + self.out_channels = out_channels + self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, latent): + latent = self.proj(latent) + return latent + + +class PooledEmbed(nn.Module): + def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations) + + def forward(self, pooled_embed): + return self.pooled_embedder(pooled_embed) + + +class TimestepEmbed(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + super().__init__() + self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations) + + def forward(self, timesteps, wdtype): + t_emb = self.time_proj(timesteps).to(dtype=wdtype) + t_emb = self.timestep_embedder(t_emb) + return t_emb + + +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): + return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) + + +class HiDreamAttnProcessor_flashattn: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __call__( + self, + attn, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + dtype = image_tokens.dtype + batch_size = image_tokens.shape[0] + + query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype) + key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype) + value_i = attn.to_v(image_tokens) + + inner_dim = key_i.shape[-1] + head_dim = inner_dim // attn.heads + + query_i = query_i.view(batch_size, -1, attn.heads, head_dim) + key_i = key_i.view(batch_size, -1, attn.heads, head_dim) + value_i = value_i.view(batch_size, -1, attn.heads, head_dim) + if image_tokens_masks is not None: + key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1) + + if not attn.single: + query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype) + key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype) + value_t = attn.to_v_t(text_tokens) + + query_t = query_t.view(batch_size, -1, attn.heads, head_dim) + key_t = key_t.view(batch_size, -1, attn.heads, head_dim) + value_t = value_t.view(batch_size, -1, attn.heads, head_dim) + + num_image_tokens = query_i.shape[1] + num_text_tokens = query_t.shape[1] + query = torch.cat([query_i, query_t], dim=1) + key = torch.cat([key_i, key_t], dim=1) + value = torch.cat([value_i, value_t], dim=1) + else: + query = query_i + key = key_i + value = value_i + + if query.shape[-1] == rope.shape[-3] * 2: + query, key = apply_rope(query, key, rope) + else: + query_1, query_2 = query.chunk(2, dim=-1) + key_1, key_2 = key.chunk(2, dim=-1) + query_1, key_1 = apply_rope(query_1, key_1, rope) + query = torch.cat([query_1, query_2], dim=-1) + key = torch.cat([key_1, key_2], dim=-1) + + hidden_states = attention(query, key, value) + + if not attn.single: + hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) + hidden_states_i = attn.to_out(hidden_states_i) + hidden_states_t = attn.to_out_t(hidden_states_t) + return hidden_states_i, hidden_states_t + else: + hidden_states = attn.to_out(hidden_states) + return hidden_states + +class HiDreamAttention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + upcast_softmax: bool = False, + scale_qk: bool = True, + eps: float = 1e-5, + processor = None, + out_dim: int = None, + single: bool = False, + dtype=None, device=None, operations=None + ): + # super(Attention, self).__init__() + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.out_dim = out_dim if out_dim is not None else query_dim + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.single = single + + linear_cls = operations.Linear + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device) + self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device) + self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + + if not single: + self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device) + self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device) + self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + + self.processor = processor + + def forward( + self, + norm_image_tokens: torch.FloatTensor, + image_tokens_masks: torch.FloatTensor = None, + norm_text_tokens: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.Tensor: + return self.processor( + self, + image_tokens = norm_image_tokens, + image_tokens_masks = image_tokens_masks, + text_tokens = norm_text_tokens, + rope = rope, + ) + + +class FeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + dtype=None, device=None, operations=None + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of + ) + + self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device) + self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MoEGate(nn.Module): + def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None): + super().__init__() + self.top_k = num_activated_experts + self.n_routed_experts = num_routed_experts + + self.scoring_func = 'softmax' + self.alpha = aux_loss_alpha + self.seq_aux = False + + # topk selection algorithm + self.norm_topk_prob = False + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device)) + self.reset_parameters() + + def reset_parameters(self) -> None: + pass + # import torch.nn.init as init + # init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None) + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') + + ### select top-k experts + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MOEFeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_routed_experts: int, + num_activated_experts: int, + dtype=None, device=None, operations=None + ): + super().__init__() + self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations) + self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)]) + self.gate = MoEGate( + embed_dim = dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + self.num_activated_experts = num_activated_experts + + def forward(self, x): + wtype = x.dtype + identity = x + orig_shape = x.shape + topk_idx, topk_weight, aux_loss = self.gate(x) + x = x.view(-1, x.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if True: # self.training: # TODO: check which branch performs faster + x = x.repeat_interleave(self.num_activated_experts, dim=0) + y = torch.empty_like(x, dtype=wtype) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape).to(dtype=wtype) + #y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.num_activated_experts + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # for fp16 and other dtype + expert_cache = expert_cache.to(expert_out.dtype) + expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') + return expert_cache + + +class TextProjection(nn.Module): + def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device) + + def forward(self, caption): + hidden_states = self.linear(caption) + return hidden_states + + +class BlockType: + TransformerBlock = 1 + SingleTransformerBlock = 2 + + +class HiDreamImageSingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + dtype=None, device=None, operations=None + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device) + ) + + # 1. Attention + self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = True, + dtype=dtype, device=device, operations=operations + ) + + # 3. Feed-forward + self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + attn_output_i = self.attn1( + norm_image_tokens, + image_tokens_masks, + rope = rope, + ) + image_tokens = gate_msa_i * attn_output_i + image_tokens + + # 2. Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype)) + image_tokens = ff_output_i + image_tokens + return image_tokens + + +class HiDreamImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + dtype=None, device=None, operations=None + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device) + ) + # nn.init.zeros_(self.adaLN_modulation[1].weight) + # nn.init.zeros_(self.adaLN_modulation[1].bias) + + # 1. Attention + self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = False, + dtype=dtype, device=device, operations=operations + ) + + # 3. Feed-forward + self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ + shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t + + attn_output_i, attn_output_t = self.attn1( + norm_image_tokens, + image_tokens_masks, + norm_text_tokens, + rope = rope, + ) + + image_tokens = gate_msa_i * attn_output_i + image_tokens + text_tokens = gate_msa_t * attn_output_t + text_tokens + + # 2. Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t + + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens) + ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens) + image_tokens = ff_output_i + image_tokens + text_tokens = ff_output_t + text_tokens + return image_tokens, text_tokens + + +class HiDreamImageBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + block_type: BlockType = BlockType.TransformerBlock, + dtype=None, device=None, operations=None + ): + super().__init__() + block_classes = { + BlockType.TransformerBlock: HiDreamImageTransformerBlock, + BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock, + } + self.block = block_classes[block_type]( + dim, + num_attention_heads, + attention_head_dim, + num_routed_experts, + num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + return self.block( + image_tokens, + image_tokens_masks, + text_tokens, + adaln_input, + rope, + ) + + +class HiDreamImageTransformer2DModel(nn.Module): + def __init__( + self, + patch_size: Optional[int] = None, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 16, + num_single_layers: int = 32, + attention_head_dim: int = 128, + num_attention_heads: int = 20, + caption_channels: List[int] = None, + text_emb_dim: int = 2048, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + axes_dims_rope: Tuple[int, int] = (32, 32), + max_resolution: Tuple[int, int] = (128, 128), + llama_layers: List[int] = None, + image_model=None, + dtype=None, device=None, operations=None + ): + self.patch_size = patch_size + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.num_layers = num_layers + self.num_single_layers = num_single_layers + + self.gradient_checkpointing = False + + super().__init__() + self.dtype = dtype + self.out_channels = out_channels or in_channels + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.llama_layers = llama_layers + + self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations) + self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.x_embedder = PatchEmbed( + patch_size = patch_size, + in_channels = in_channels, + out_channels = self.inner_dim, + dtype=dtype, device=device, operations=operations + ) + self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope) + + self.double_stream_blocks = nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.num_attention_heads, + attention_head_dim = self.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.TransformerBlock, + dtype=dtype, device=device, operations=operations + ) + for i in range(self.num_layers) + ] + ) + + self.single_stream_blocks = nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.num_attention_heads, + attention_head_dim = self.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.SingleTransformerBlock, + dtype=dtype, device=device, operations=operations + ) + for i in range(self.num_single_layers) + ] + ) + + self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) + + caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ] + caption_projection = [] + for caption_channel in caption_channels: + caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations)) + self.caption_projection = nn.ModuleList(caption_projection) + self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size) + + def expand_timesteps(self, timesteps, batch_size, device): + if not torch.is_tensor(timesteps): + is_mps = device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(batch_size) + return timesteps + + def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]: + x_arr = [] + for i, img_size in enumerate(img_sizes): + pH, pW = img_size + x_arr.append( + einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)', + p1=self.patch_size, p2=self.patch_size) + ) + x = torch.cat(x_arr, dim=0) + return x + + def patchify(self, x, max_seq, img_sizes=None): + pz2 = self.patch_size * self.patch_size + if isinstance(x, torch.Tensor): + B = x.shape[0] + device = x.device + dtype = x.dtype + else: + B = len(x) + device = x[0].device + dtype = x[0].dtype + x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device) + + if img_sizes is not None: + for i, img_size in enumerate(img_sizes): + x_masks[i, 0:img_size[0] * img_size[1]] = 1 + x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2) + elif isinstance(x, torch.Tensor): + pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size) + img_sizes = [[pH, pW]] * B + x_masks = None + else: + raise NotImplementedError + return x, x_masks, img_sizes + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + encoder_hidden_states_llama3=None, + image_cond=None, + control = None, + transformer_options = {}, + ) -> torch.Tensor: + bs, c, h, w = x.shape + if image_cond is not None: + x = torch.cat([x, image_cond], dim=-1) + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + timesteps = t + pooled_embeds = y + T5_encoder_hidden_states = context + + img_sizes = None + + # spatial forward + batch_size = hidden_states.shape[0] + hidden_states_type = hidden_states.dtype + + # 0. time + timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) + timesteps = self.t_embedder(timesteps, hidden_states_type) + p_embedder = self.p_embedder(pooled_embeds) + adaln_input = timesteps + p_embedder + + hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) + if image_tokens_masks is None: + pH, pW = img_sizes[0] + img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + hidden_states = self.x_embedder(hidden_states) + + # T5_encoder_hidden_states = encoder_hidden_states[0] + encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) + encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] + + if self.caption_projection is not None: + new_encoder_hidden_states = [] + for i, enc_hidden_state in enumerate(encoder_hidden_states): + enc_hidden_state = self.caption_projection[i](enc_hidden_state) + enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) + new_encoder_hidden_states.append(enc_hidden_state) + encoder_hidden_states = new_encoder_hidden_states + T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) + T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + encoder_hidden_states.append(T5_encoder_hidden_states) + + txt_ids = torch.zeros( + batch_size, + encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], + 3, + device=img_ids.device, dtype=img_ids.dtype + ) + ids = torch.cat((img_ids, txt_ids), dim=1) + rope = self.pe_embedder(ids) + + # 2. Blocks + block_id = 0 + initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) + initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] + for bid, block in enumerate(self.double_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states, initial_encoder_hidden_states = block( + image_tokens = hidden_states, + image_tokens_masks = image_tokens_masks, + text_tokens = cur_encoder_hidden_states, + adaln_input = adaln_input, + rope = rope, + ) + initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] + block_id += 1 + + image_tokens_seq_len = hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) + hidden_states_seq_len = hidden_states.shape[1] + if image_tokens_masks is not None: + encoder_attention_mask_ones = torch.ones( + (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), + device=image_tokens_masks.device, dtype=image_tokens_masks.dtype + ) + image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) + + for bid, block in enumerate(self.single_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states = block( + image_tokens=hidden_states, + image_tokens_masks=image_tokens_masks, + text_tokens=None, + adaln_input=adaln_input, + rope=rope, + ) + hidden_states = hidden_states[:, :hidden_states_seq_len] + block_id += 1 + + hidden_states = hidden_states[:, :image_tokens_seq_len, ...] + output = self.final_layer(hidden_states, adaln_input) + output = self.unpatchify(output, img_sizes) + return -output[:, :, :h, :w] diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py new file mode 100644 index 000000000..4e18358f0 --- /dev/null +++ b/comfy/ldm/hunyuan3d/model.py @@ -0,0 +1,135 @@ +import torch +from torch import nn +from comfy.ldm.flux.layers import ( + DoubleStreamBlock, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + + +class Hunyuan3Dv2(nn.Module): + def __init__( + self, + in_channels=64, + context_in_dim=1536, + hidden_size=1024, + mlp_ratio=4.0, + num_heads=16, + depth=16, + depth_single_blocks=32, + qkv_bias=True, + guidance_embed=False, + image_model=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dtype = dtype + + if hidden_size % num_heads != 0: + raise ValueError( + f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}" + ) + + self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead + self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None + ) + self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device) + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(depth) + ] + ) + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + dtype=dtype, device=device, operations=operations + ) + for _ in range(depth_single_blocks) + ] + ) + self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): + x = x.movedim(-1, -2) + timestep = 1.0 - timestep + txt = context + img = self.latent_in(x) + + vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype)) + if self.guidance_in is not None: + if guidance is not None: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype)) + + txt = self.cond_in(txt) + pe = None + attn_mask = None + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.double_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=vec, + pe=pe, + attn_mask=attn_mask) + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("single_block", i)]({"img": img, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + + img = img[:, txt.shape[1]:, ...] + img = self.final_layer(img, vec) + return img.movedim(-2, -1) * (-1.0) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py new file mode 100644 index 000000000..5eb2c6548 --- /dev/null +++ b/comfy/ldm/hunyuan3d/vae.py @@ -0,0 +1,587 @@ +# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py +# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from typing import Union, Tuple, List, Callable, Optional + +import numpy as np +from einops import repeat, rearrange +from tqdm import tqdm +import logging + +import comfy.ops +ops = comfy.ops.disable_weight_init + +def generate_dense_grid_points( + bbox_min: np.ndarray, + bbox_max: np.ndarray, + octree_resolution: int, + indexing: str = "ij", +): + length = bbox_max - bbox_min + num_cells = octree_resolution + + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + + +class VanillaVolumeDecoder: + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: Callable, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + octree_resolution: int = None, + enable_pbar: bool = True, + **kwargs, + ): + device = latents.device + dtype = latents.dtype + batch_size = latents.shape[0] + + # 1. generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=octree_resolution, + indexing="ij" + ) + xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + + # 2. latents to 3d volume + batch_logits = [] + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding", + disable=not enable_pbar): + chunk_queries = xyz_samples[start: start + num_chunks, :] + chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) + logits = geo_decoder(queries=chunk_queries, latents=latents) + batch_logits.append(logits) + + grid_logits = torch.cat(batch_logits, dim=1) + grid_logits = grid_logits.view((batch_size, *grid_size)).float() + + return grid_logits + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. + """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class CrossAttentionProcessor: + def __call__(self, attn, q, k, v): + out = F.scaled_dot_product_attention(q, k, v) + return out + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if self.drop_prob == 0. or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob, 3):0.3f}' + + +class MLP(nn.Module): + def __init__( + self, *, + width: int, + expand_ratio: int = 4, + output_width: int = None, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.c_fc = ops.Linear(width, width * expand_ratio) + self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width) + self.gelu = nn.GELU() + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm + ): + super().__init__() + self.heads = heads + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + self.attn_processor = CrossAttentionProcessor() + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = self.attn_processor(self, q, k, v) + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + data_width: Optional[int] = None, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + kv_cache: bool = False, + ): + super().__init__() + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = ops.Linear(width, width, bias=qkv_bias) + self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias) + self.c_proj = ops.Linear(width, width) + self.attention = QKVMultiheadCrossAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.kv_cache = kv_cache + self.data = None + + def forward(self, x, data): + x = self.c_q(x) + if self.kv_cache: + if self.data is None: + self.data = self.c_kv(data) + logging.info('Save kv cache,this should be called only once for one mesh') + data = self.data + else: + data = self.c_kv(data) + x = self.attention(x, data) + x = self.c_proj(x) + return x + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + data_width: Optional[int] = None, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + width=width, + heads=heads, + data_width=data_width, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6) + self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__( + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm + ): + super().__init__() + self.heads = heads + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.heads = heads + self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias) + self.c_proj = ops.Linear(width, width) + self.attention = QKVMultiheadAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + x = self.c_qkv(x) + x = self.attention(x) + x = self.drop_path(self.c_proj(x)) + return x + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.attn = MultiheadAttention( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, drop_path_rate=drop_path_rate) + self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6) + + def forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + *, + width: int, + layers: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x + + +class CrossAttentionDecoder(nn.Module): + + def __init__( + self, + *, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + downsample_ratio: int = 1, + enable_ln_post: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary" + ): + super().__init__() + + self.enable_ln_post = enable_ln_post + self.fourier_embedder = fourier_embedder + self.downsample_ratio = downsample_ratio + self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width) + if self.downsample_ratio != 1: + self.latents_proj = ops.Linear(width * downsample_ratio, width) + if self.enable_ln_post == False: + qk_norm = False + self.cross_attn_decoder = ResidualCrossAttentionBlock( + width=width, + mlp_expand_ratio=mlp_expand_ratio, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm + ) + + if self.enable_ln_post: + self.ln_post = ops.LayerNorm(width) + self.output_proj = ops.Linear(width, out_channels) + self.label_type = label_type + self.count = 0 + + def forward(self, queries=None, query_embeddings=None, latents=None): + if query_embeddings is None: + query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype)) + self.count += query_embeddings.shape[1] + if self.downsample_ratio != 1: + latents = self.latents_proj(latents) + x = self.cross_attn_decoder(query_embeddings, latents) + if self.enable_ln_post: + x = self.ln_post(x) + occ = self.output_proj(x) + return occ + + +class ShapeVAE(nn.Module): + def __init__( + self, + *, + embed_dim: int, + width: int, + heads: int, + num_decoder_layers: int, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + include_pi: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary", + drop_path_rate: float = 0.0, + scale_factor: float = 1.0, + ): + super().__init__() + self.geo_decoder_ln_post = geo_decoder_ln_post + + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + self.post_kl = ops.Linear(embed_dim, width) + + self.transformer = Transformer( + width=width, + layers=num_decoder_layers, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + + self.geo_decoder = CrossAttentionDecoder( + fourier_embedder=self.fourier_embedder, + out_channels=1, + mlp_expand_ratio=geo_decoder_mlp_expand_ratio, + downsample_ratio=geo_decoder_downsample_ratio, + enable_ln_post=self.geo_decoder_ln_post, + width=width // geo_decoder_downsample_ratio, + heads=heads // geo_decoder_downsample_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + label_type=label_type, + ) + + self.volume_decoder = VanillaVolumeDecoder() + self.scale_factor = scale_factor + + def decode(self, latents, **kwargs): + latents = self.post_kl(latents.movedim(-2, -1)) + latents = self.transformer(latents) + + bounds = kwargs.get("bounds", 1.01) + num_chunks = kwargs.get("num_chunks", 8000) + octree_resolution = kwargs.get("octree_resolution", 256) + enable_pbar = kwargs.get("enable_pbar", True) + + grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar) + return grid_logits.movedim(-2, -1) + + def encode(self, x): + return None diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 2758f9508..45f9e311e 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -24,6 +24,13 @@ if model_management.sage_attention_enabled(): logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") exit(-1) +if model_management.flash_attention_enabled(): + try: + from flash_attn import flash_attn_func + except ModuleNotFoundError: + logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") + exit(-1) + from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init @@ -464,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape - tensor_layout="HND" + tensor_layout = "HND" else: b, _, dim_head = q.shape dim_head //= heads @@ -472,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= lambda t: t.view(b, -1, heads, dim_head), (q, k, v), ) - tensor_layout="NHD" + tensor_layout = "NHD" if mask is not None: # add a batch dimension if there isn't already one @@ -482,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= if mask.ndim == 3: mask = mask.unsqueeze(1) - out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + try: + out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + except Exception as e: + logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + if tensor_layout == "NHD": + q, k, v = map( + lambda t: t.transpose(1, 2), + (q, k, v), + ) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + if tensor_layout == "HND": if not skip_output_reshape: out = ( @@ -496,6 +513,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= return out +try: + @torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) + def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: + return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) + + + @flash_attn_wrapper.register_fake + def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False): + # Output shape is the same as q + return q.new_empty(q.shape) +except AttributeError as error: + FLASH_ATTN_ERROR = error + + def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: + assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" + + +def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): + if skip_reshape: + b, _, _, dim_head = q.shape + else: + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map( + lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), + (q, k, v), + ) + + if mask is not None: + # add a batch dimension if there isn't already one + if mask.ndim == 2: + mask = mask.unsqueeze(0) + # add a heads dimension if there isn't already one + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + try: + assert mask is None + out = flash_attn_wrapper( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + dropout_p=0.0, + causal=False, + ).transpose(1, 2) + except Exception as e: + logging.warning(f"Flash Attention failed, using default SDPA: {e}") + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + if not skip_output_reshape: + out = ( + out.transpose(1, 2).reshape(b, -1, heads * dim_head) + ) + return out + + optimized_attention = attention_basic if model_management.sage_attention_enabled(): @@ -504,6 +578,9 @@ if model_management.sage_attention_enabled(): elif model_management.xformers_enabled(): logging.info("Using xformers attention") optimized_attention = attention_xformers +elif model_management.flash_attention_enabled(): + logging.info("Using Flash Attention") + optimized_attention = attention_flash elif model_management.pytorch_attention_enabled(): logging.info("Using pytorch attention") optimized_attention = attention_pytorch @@ -770,6 +847,7 @@ class SpatialTransformer(nn.Module): if not isinstance(context, list): context = [context] * len(self.transformer_blocks) b, c, h, w = x.shape + transformer_options["activations_shape"] = list(x.shape) x_in = x x = self.norm(x) if not self.use_linear: @@ -885,6 +963,7 @@ class SpatialVideoTransformer(SpatialTransformer): transformer_options={} ) -> torch.Tensor: _, _, h, w = x.shape + transformer_options["activations_shape"] = list(x.shape) x_in = x spatial_context = None if exists(context): diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index e78d846b2..66bee7480 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context): + def forward(self, x, context, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context): + def forward(self, x, context, context_img_len): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] """ - context_img = context[:, :257] - context = context[:, 257:] + context_img = context[:, :context_img_len] + context = context[:, context_img_len:] # compute query, key, value q = self.norm_q(self.q(x)) @@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module): e, freqs, context, + context_img_len=257, ): r""" Args: @@ -213,12 +214,40 @@ class WanAttentionBlock(nn.Module): x = x + y * e[2] # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context) + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) x = x + y * e[5] return x +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0, + operation_settings={} + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) + self.block_id = block_id + if block_id == 0: + self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + return c_skip, c + + class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}): @@ -250,7 +279,7 @@ class Head(nn.Module): class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim, operation_settings={}): + def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}): super().__init__() self.proj = torch.nn.Sequential( @@ -258,7 +287,15 @@ class MLPProj(torch.nn.Module): torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + if flf_pos_embed_token_number is not None: + self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + else: + self.emb_pos = None + def forward(self, image_embeds): + if self.emb_pos is not None: + image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device) + clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens @@ -284,6 +321,7 @@ class WanModel(torch.nn.Module): qk_norm=True, cross_attn_norm=True, eps=1e-6, + flf_pos_embed_token_number=None, image_model=None, device=None, dtype=None, @@ -373,7 +411,7 @@ class WanModel(torch.nn.Module): self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings) + self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings) else: self.img_emb = None @@ -384,6 +422,8 @@ class WanModel(torch.nn.Module): context, clip_fea=None, freqs=None, + transformer_options={}, + **kwargs, ): r""" Forward pass through the diffusion model @@ -419,18 +459,25 @@ class WanModel(torch.nn.Module): # context context = self.text_embedding(context) - if clip_fea is not None and self.img_emb is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] - # arguments - kwargs = dict( - e=e0, - freqs=freqs, - context=context) - - for block in self.blocks: - x = block(x, **kwargs) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) # head x = self.head(x, e) @@ -439,7 +486,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def forward(self, x, timestep, context, clip_fea=None, **kwargs): + def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) patch_size = self.patch_size @@ -453,7 +500,7 @@ class WanModel(torch.nn.Module): img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) freqs = self.rope_embedder(img_ids).movedim(1, 2) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w] + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): r""" @@ -478,3 +525,116 @@ class WanModel(torch.nn.Module): u = torch.einsum('bfhwpqrc->bcfphqwr', u) u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) return u + + +class VaceWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='vace', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + vace_layers=None, + vace_in_dim=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + # Vace + if vace_layers is not None: + self.vace_layers = vace_layers + self.vace_in_dim = vace_in_dim + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings) + for i in range(self.vace_layers) + ]) + + self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))} + # vace patch embeddings + self.vace_patch_embedding = operations.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32 + ) + + def forward_orig( + self, + x, + t, + context, + vace_context, + vace_strength=1.0, + clip_fea=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype) + c = c.flatten(2).transpose(1, 2) + + # arguments + x_orig = x + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + + ii = self.vace_layers_mapping.get(i, None) + if ii is not None: + c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x += c_skip * vace_strength + del c_skip + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/lora.py b/comfy/lora.py index bc9f3022a..fff524be2 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -20,6 +20,7 @@ from __future__ import annotations import comfy.utils import comfy.model_management import comfy.model_base +import comfy.weight_adapter as weight_adapter import logging import torch @@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True): dora_scale = lora[dora_scale_name] loaded_keys.add(dora_scale_name) - reshape_name = "{}.reshape_weight".format(x) - reshape = None - if reshape_name in lora.keys(): - try: - reshape = lora[reshape_name].tolist() - loaded_keys.add(reshape_name) - except: - pass - - regular_lora = "{}.lora_up.weight".format(x) - diffusers_lora = "{}_lora.up.weight".format(x) - diffusers2_lora = "{}.lora_B.weight".format(x) - diffusers3_lora = "{}.lora.up.weight".format(x) - mochi_lora = "{}.lora_B".format(x) - transformers_lora = "{}.lora_linear_layer.up.weight".format(x) - A_name = None - - if regular_lora in lora.keys(): - A_name = regular_lora - B_name = "{}.lora_down.weight".format(x) - mid_name = "{}.lora_mid.weight".format(x) - elif diffusers_lora in lora.keys(): - A_name = diffusers_lora - B_name = "{}_lora.down.weight".format(x) - mid_name = None - elif diffusers2_lora in lora.keys(): - A_name = diffusers2_lora - B_name = "{}.lora_A.weight".format(x) - mid_name = None - elif diffusers3_lora in lora.keys(): - A_name = diffusers3_lora - B_name = "{}.lora.down.weight".format(x) - mid_name = None - elif mochi_lora in lora.keys(): - A_name = mochi_lora - B_name = "{}.lora_A".format(x) - mid_name = None - elif transformers_lora in lora.keys(): - A_name = transformers_lora - B_name ="{}.lora_linear_layer.down.weight".format(x) - mid_name = None - - if A_name is not None: - mid = None - if mid_name is not None and mid_name in lora.keys(): - mid = lora[mid_name] - loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)) - loaded_keys.add(A_name) - loaded_keys.add(B_name) - - - ######## loha - hada_w1_a_name = "{}.hada_w1_a".format(x) - hada_w1_b_name = "{}.hada_w1_b".format(x) - hada_w2_a_name = "{}.hada_w2_a".format(x) - hada_w2_b_name = "{}.hada_w2_b".format(x) - hada_t1_name = "{}.hada_t1".format(x) - hada_t2_name = "{}.hada_t2".format(x) - if hada_w1_a_name in lora.keys(): - hada_t1 = None - hada_t2 = None - if hada_t1_name in lora.keys(): - hada_t1 = lora[hada_t1_name] - hada_t2 = lora[hada_t2_name] - loaded_keys.add(hada_t1_name) - loaded_keys.add(hada_t2_name) - - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) - loaded_keys.add(hada_w1_a_name) - loaded_keys.add(hada_w1_b_name) - loaded_keys.add(hada_w2_a_name) - loaded_keys.add(hada_w2_b_name) - - - ######## lokr - lokr_w1_name = "{}.lokr_w1".format(x) - lokr_w2_name = "{}.lokr_w2".format(x) - lokr_w1_a_name = "{}.lokr_w1_a".format(x) - lokr_w1_b_name = "{}.lokr_w1_b".format(x) - lokr_t2_name = "{}.lokr_t2".format(x) - lokr_w2_a_name = "{}.lokr_w2_a".format(x) - lokr_w2_b_name = "{}.lokr_w2_b".format(x) - - lokr_w1 = None - if lokr_w1_name in lora.keys(): - lokr_w1 = lora[lokr_w1_name] - loaded_keys.add(lokr_w1_name) - - lokr_w2 = None - if lokr_w2_name in lora.keys(): - lokr_w2 = lora[lokr_w2_name] - loaded_keys.add(lokr_w2_name) - - lokr_w1_a = None - if lokr_w1_a_name in lora.keys(): - lokr_w1_a = lora[lokr_w1_a_name] - loaded_keys.add(lokr_w1_a_name) - - lokr_w1_b = None - if lokr_w1_b_name in lora.keys(): - lokr_w1_b = lora[lokr_w1_b_name] - loaded_keys.add(lokr_w1_b_name) - - lokr_w2_a = None - if lokr_w2_a_name in lora.keys(): - lokr_w2_a = lora[lokr_w2_a_name] - loaded_keys.add(lokr_w2_a_name) - - lokr_w2_b = None - if lokr_w2_b_name in lora.keys(): - lokr_w2_b = lora[lokr_w2_b_name] - loaded_keys.add(lokr_w2_b_name) - - lokr_t2 = None - if lokr_t2_name in lora.keys(): - lokr_t2 = lora[lokr_t2_name] - loaded_keys.add(lokr_t2_name) - - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) - - #glora - a1_name = "{}.a1.weight".format(x) - a2_name = "{}.a2.weight".format(x) - b1_name = "{}.b1.weight".format(x) - b2_name = "{}.b2.weight".format(x) - if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) - loaded_keys.add(a1_name) - loaded_keys.add(a2_name) - loaded_keys.add(b1_name) - loaded_keys.add(b2_name) + for adapter_cls in weight_adapter.adapters: + adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys) + if adapter is not None: + patch_dict[to_load[x]] = adapter + loaded_keys.update(adapter.loaded_keys) + continue w_norm_name = "{}.w_norm".format(x) b_norm_name = "{}.b_norm".format(x) @@ -405,29 +279,16 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(key_lora)] = k key_map["diffusion_model.{}".format(key_lora)] = k # Old loras + if isinstance(model, comfy.model_base.HiDream): + for k in sdk: + if k.startswith("diffusion_model."): + if k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") + key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format + return key_map -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) - lora_diff *= alpha - weight_calc = weight + function(lora_diff).type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) - - weight_calc *= (dora_scale / weight_norm).type(weight.dtype) - if strength != 1.0: - weight_calc -= weight - weight += strength * (weight_calc) - else: - weight[:] = weight_calc - return weight - def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: """ Pad a tensor to a new shape with zeros. @@ -482,6 +343,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori if isinstance(v, list): v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), ) + if isinstance(v, weight_adapter.WeightAdapterBase): + output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights) + if output is None: + logging.warning("Calculate Weight Failed: {} {}".format(v.name, key)) + else: + weight = output + if old_weight is not None: + weight = old_weight + continue + if len(v) == 1: patch_type = "diff" elif len(v) == 2: @@ -508,157 +379,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \ comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype) weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype)) - elif patch_type == "lora": #lora/locon - mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) - mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) - dora_scale = v[4] - reshape = v[5] - - if reshape is not None: - weight = pad_tensor_to_shape(weight, reshape) - - if v[2] is not None: - alpha = v[2] / mat2.shape[0] - else: - alpha = 1.0 - - if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype) - final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) - try: - lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "lokr": - w1 = v[0] - w2 = v[1] - w1_a = v[3] - w1_b = v[4] - w2_a = v[5] - w2_b = v[6] - t2 = v[7] - dora_scale = v[8] - dim = None - - if w1 is None: - dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) - else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) - - if w2 is None: - dim = w2_b.shape[0] - if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) - else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) - else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - if v[2] is not None and dim is not None: - alpha = v[2] / dim - else: - alpha = 1.0 - - try: - lora_diff = torch.kron(w1, w2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "loha": - w1a = v[0] - w1b = v[1] - if v[2] is not None: - alpha = v[2] / w1b.shape[0] - else: - alpha = 1.0 - - w2a = v[3] - w2b = v[4] - dora_scale = v[7] - if v[5] is not None: #cp decomposition - t1 = v[5] - t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) - - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) - else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) - - try: - lora_diff = (m1 * m2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "glora": - dora_scale = v[5] - - old_glora = False - if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: - rank = v[0].shape[0] - old_glora = True - - if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: - pass - else: - old_glora = False - rank = v[1].shape[0] - - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) - - if v[4] is not None: - alpha = v[4] / rank - else: - alpha = 1.0 - - try: - if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora - else: - if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - lora_diff += torch.mm(b1, b2).reshape(weight.shape) - - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: logging.warning("patch type not recognized {} {}".format(patch_type, key)) diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 05032c690..3e00b63db 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -1,4 +1,5 @@ import torch +import comfy.utils def convert_lora_bfl_control(sd): #BFL loras for Flux @@ -11,7 +12,13 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux return sd_out +def convert_lora_wan_fun(sd): #Wan Fun loras + return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"}) + + def convert_lora(sd): if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd: return convert_lora_bfl_control(sd) + if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd: + return convert_lora_wan_fun(sd) return sd diff --git a/comfy/model_base.py b/comfy/model_base.py index 976702b60..3d33086d8 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -36,6 +36,9 @@ import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model import comfy.ldm.lumina.model import comfy.ldm.wan.model +import comfy.ldm.hunyuan3d.model +import comfy.ldm.hidream.model +import comfy.ldm.chroma.model import comfy.model_management import comfy.patcher_extension @@ -58,6 +61,7 @@ class ModelType(Enum): FLOW = 6 V_PREDICTION_CONTINUOUS = 7 FLUX = 8 + IMG_TO_IMG = 9 from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV @@ -88,6 +92,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLUX: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingFlux + elif model_type == ModelType.IMG_TO_IMG: + c = comfy.model_sampling.IMG_TO_IMG class ModelSampling(s, c): pass @@ -139,6 +145,7 @@ class BaseModel(torch.nn.Module): def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t xc = self.model_sampling.calculate_input(sigma, x) + if c_concat is not None: xc = torch.cat([xc] + [c_concat], dim=1) @@ -600,6 +607,19 @@ class SDXL_instructpix2pix(IP2P, SDXL): else: self.process_ip2p_image_in = lambda image: image #diffusers ip2p +class Lotus(BaseModel): + def extra_conds(self, **kwargs): + out = {} + cross_attn = kwargs.get("cross_attn", None) + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + device = kwargs["device"] + task_emb = torch.tensor([1, 0]).float().to(device) + task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0) + out['y'] = comfy.conds.CONDRegular(task_emb) + return out + + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None): + super().__init__(model_config, model_type, device=device) class StableCascade_C(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): @@ -767,8 +787,8 @@ class PixArt(BaseModel): return out class Flux(BaseModel): - def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux) + def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux): + super().__init__(model_config, model_type, device=device, unet_model=unet_model) def concat_cond(self, **kwargs): try: @@ -974,31 +994,41 @@ class WAN21(BaseModel): def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) - if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]: + extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels == 0: return None image = kwargs.get("concat_latent_image", None) device = kwargs["device"] if image is None: - image = torch.zeros_like(noise) + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], 16): + image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + image = utils.resize_to_batch_size(image, noise.shape[0]) - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - image = self.process_latent_in(image) - image = utils.resize_to_batch_size(image, noise.shape[0]) - - if not self.image_to_video: + if not self.image_to_video or extra_channels == image.shape[1]: return image + if image.shape[1] > (extra_channels - 4): + image = image[:, :(extra_channels - 4)] + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if mask is None: mask = torch.zeros_like(noise)[:, :4] else: - mask = 1.0 - torch.mean(mask, dim=1, keepdim=True) + if mask.shape[1] != 4: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") if mask.shape[-3] < noise.shape[-3]: mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) - mask = mask.repeat(1, 4, 1, 1, 1) + if mask.shape[1] == 1: + mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) return torch.cat((mask, image), dim=1) @@ -1013,3 +1043,81 @@ class WAN21(BaseModel): if clip_vision_output is not None: out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states) return out + + +class WAN21_Vace(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + noise = kwargs.get("noise", None) + noise_shape = list(noise.shape) + vace_frames = kwargs.get("vace_frames", None) + if vace_frames is None: + noise_shape[1] = 32 + vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) + + for i in range(0, vace_frames.shape[1], 16): + vace_frames = vace_frames.clone() + vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16]) + + mask = kwargs.get("vace_mask", None) + if mask is None: + noise_shape[1] = 64 + mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype) + + out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1)) + + vace_strength = kwargs.get("vace_strength", 1.0) + out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) + return out + + +class Hunyuan3Dv2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + guidance = kwargs.get("guidance", 5.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + return out + +class HiDream(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + conditioning_llama3 = kwargs.get("conditioning_llama3", None) + if conditioning_llama3 is not None: + out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3) + image_cond = kwargs.get("concat_latent_image", None) + if image_cond is not None: + out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond)) + return out + +class Chroma(Flux): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + guidance = kwargs.get("guidance", 0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 403da5855..9254843ea 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 @@ -164,7 +164,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if in_key in state_dict_keys: dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) dit_config["out_channels"] = 16 - dit_config["vec_in_dim"] = 768 + vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) + if vec_in_key in state_dict_keys: + dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] dit_config["context_in_dim"] = 4096 dit_config["hidden_size"] = 3072 dit_config["mlp_ratio"] = 4.0 @@ -174,7 +176,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["axes_dim"] = [16, 56, 56] dit_config["theta"] = 10000 dit_config["qkv_bias"] = True - dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma + dit_config["image_model"] = "chroma" + dit_config["in_channels"] = 64 + dit_config["out_channels"] = 64 + dit_config["in_dim"] = 64 + dit_config["out_dim"] = 3072 + dit_config["hidden_dim"] = 5120 + dit_config["n_layers"] = 5 + else: + dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview @@ -317,10 +328,52 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["cross_attn_norm"] = True dit_config["eps"] = 1e-6 dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1] - if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: - dit_config["model_type"] = "i2v" + if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "vace" + dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] + dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') else: - dit_config["model_type"] = "t2v" + if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "i2v" + else: + dit_config["model_type"] = "t2v" + flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) + if flf_weight is not None: + dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] + return dit_config + + if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D + in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape + dit_config = {} + dit_config["image_model"] = "hunyuan3d2" + dit_config["in_channels"] = in_shape[1] + dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1] + dit_config["hidden_size"] = in_shape[0] + dit_config["mlp_ratio"] = 4.0 + dit_config["num_heads"] = 16 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') + dit_config["qkv_bias"] = True + dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + return dit_config + + if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream + dit_config = {} + dit_config["image_model"] = "hidream" + dit_config["attention_head_dim"] = 128 + dit_config["axes_dims_rope"] = [64, 32, 32] + dit_config["caption_channels"] = [4096, 4096] + dit_config["max_resolution"] = [128, 128] + dit_config["in_channels"] = 16 + dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31] + dit_config["num_attention_heads"] = 20 + dit_config["num_routed_experts"] = 4 + dit_config["num_activated_experts"] = 2 + dit_config["num_layers"] = 16 + dit_config["num_single_layers"] = 32 + dit_config["out_channels"] = 16 + dit_config["patch_size"] = 2 + dit_config["text_emb_dim"] = 2048 return dit_config if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: @@ -667,8 +720,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'use_temporal_attention': False, 'use_temporal_resblock': False} + LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4, + 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8, + 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint] + supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint] for unet_config in supported_models: matches = True diff --git a/comfy/model_management.py b/comfy/model_management.py index 3a4c93e30..44aff3762 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -46,6 +46,32 @@ cpu_state = CPUState.GPU total_vram = 0 +def get_supported_float8_types(): + float8_types = [] + try: + float8_types.append(torch.float8_e4m3fn) + except: + pass + try: + float8_types.append(torch.float8_e4m3fnuz) + except: + pass + try: + float8_types.append(torch.float8_e5m2) + except: + pass + try: + float8_types.append(torch.float8_e5m2fnuz) + except: + pass + try: + float8_types.append(torch.float8_e8m0fnu) + except: + pass + return float8_types + +FLOAT8_TYPES = get_supported_float8_types() + xpu_available = False torch_version = "" try: @@ -186,12 +212,21 @@ def get_total_memory(dev=None, torch_total_too=False): else: return mem_total +def mac_version(): + try: + return tuple(int(n) for n in platform.mac_ver()[0].split(".")) + except: + return None + total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) try: logging.info("pytorch version: {}".format(torch_version)) + mac_ver = mac_version() + if mac_ver is not None: + logging.info("Mac Version {}".format(mac_ver)) except: pass @@ -690,13 +725,12 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor return torch.float8_e4m3fn if args.fp8_e5m2_unet: return torch.float8_e5m2 + if args.fp8_e8m0fnu_unet: + return torch.float8_e8m0fnu fp8_dtype = None - try: - if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - fp8_dtype = weight_dtype - except: - pass + if weight_dtype in FLOAT8_TYPES: + fp8_dtype = weight_dtype if fp8_dtype is not None: if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive @@ -791,6 +825,8 @@ def text_encoder_dtype(device=None): return torch.float8_e5m2 elif args.fp16_text_enc: return torch.float16 + elif args.bf16_text_enc: + return torch.bfloat16 elif args.fp32_text_enc: return torch.float32 @@ -903,15 +939,61 @@ def force_channels_last(): #TODO return False -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False): + +STREAMS = {} +NUM_STREAMS = 1 +if args.async_offload: + NUM_STREAMS = 2 + logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) + +stream_counters = {} +def get_offload_stream(device): + stream_counter = stream_counters.get(device, 0) + if NUM_STREAMS <= 1: + return None + + if device in STREAMS: + ss = STREAMS[device] + s = ss[stream_counter] + stream_counter = (stream_counter + 1) % len(ss) + if is_device_cuda(device): + ss[stream_counter].wait_stream(torch.cuda.current_stream()) + stream_counters[device] = stream_counter + return s + elif is_device_cuda(device): + ss = [] + for k in range(NUM_STREAMS): + ss.append(torch.cuda.Stream(device=device, priority=0)) + STREAMS[device] = ss + s = ss[stream_counter] + stream_counter = (stream_counter + 1) % len(ss) + stream_counters[device] = stream_counter + return s + return None + +def sync_stream(device, stream): + if stream is None: + return + if is_device_cuda(device): + torch.cuda.current_stream().wait_stream(stream) + +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: return weight + if stream is not None: + with stream: + return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy) - r = torch.empty_like(weight, dtype=dtype, device=device) - r.copy_(weight, non_blocking=non_blocking) + if stream is not None: + with stream: + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) + else: + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) return r def cast_to_device(tensor, device, dtype, copy=False): @@ -921,6 +1003,9 @@ def cast_to_device(tensor, device, dtype, copy=False): def sage_attention_enabled(): return args.use_sage_attention +def flash_attention_enabled(): + return args.use_flash_attention + def xformers_enabled(): global directml_enabled global cpu_state @@ -969,12 +1054,6 @@ def pytorch_attention_flash_attention(): return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention return False -def mac_version(): - try: - return tuple(int(n) for n in platform.mac_ver()[0].split(".")) - except: - return None - def force_upcast_attention_dtype(): upcast = args.force_upcast_attention @@ -1206,6 +1285,8 @@ def soft_empty_cache(force=False): torch.xpu.empty_cache() elif is_ascend_npu(): torch.npu.empty_cache() + elif is_mlu(): + torch.mlu.empty_cache() elif torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e291158ce..b7cb12dfc 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -747,6 +747,7 @@ class ModelPatcher: def partially_unload(self, device_to, memory_to_free=0): with self.use_ejected(): + hooks_unpatched = False memory_freed = 0 patch_counter = 0 unload_list = self._load_list() @@ -770,6 +771,10 @@ class ModelPatcher: move_weight = False break + if not hooks_unpatched: + self.unpatch_hooks() + hooks_unpatched = True + if bk.inplace_update: comfy.utils.copy_to_param(self.model, key, bk.weight) else: diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index ff27b09a8..7e7291476 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -69,6 +69,15 @@ class CONST: sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1)) return latent / (1.0 - sigma) +class X0(EPS): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + +class IMG_TO_IMG(X0): + def calculate_input(self, sigma, noise): + return noise + + class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): super().__init__() @@ -102,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module): self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end + self.zsnr = zsnr # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 - if zsnr: + if self.zsnr: sigmas = rescale_zero_terminal_snr_sigmas(sigmas) self.set_sigmas(sigmas) diff --git a/comfy/ops.py b/comfy/ops.py index ced461011..032787915 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -21,6 +21,8 @@ import logging import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float +import comfy.rmsnorm +import contextlib cast_to = comfy.model_management.cast_to #TODO: remove once no more references @@ -36,20 +38,31 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if device is None: device = input.device + offload_stream = comfy.model_management.get_offload_stream(device) + if offload_stream is not None: + wf_context = offload_stream + else: + wf_context = contextlib.nullcontext() + bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = len(s.bias_function) > 0 - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) + bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) + if has_function: - for f in s.bias_function: - bias = f(bias) + with wf_context: + for f in s.bias_function: + bias = f(bias) has_function = len(s.weight_function) > 0 - weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function) + weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: - for f in s.weight_function: - weight = f(weight) + with wf_context: + for f in s.weight_function: + weight = f(weight) + + comfy.model_management.sync_stream(device, offload_stream) return weight, bias class CastWeightBiasOp: @@ -146,6 +159,25 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) + class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp): + def reset_parameters(self): + self.bias = None + return None + + def forward_comfy_cast_weights(self, input): + if self.weight is not None: + weight, bias = cast_bias_weight(self, input) + else: + weight = None + return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated + # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None @@ -243,6 +275,9 @@ class manual_cast(disable_weight_init): class ConvTranspose1d(disable_weight_init.ConvTranspose1d): comfy_cast_weights = True + class RMSNorm(disable_weight_init.RMSNorm): + comfy_cast_weights = True + class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True @@ -357,6 +392,25 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None return scaled_fp8_op +CUBLAS_IS_AVAILABLE = False +try: + from cublas_ops import CublasLinear + CUBLAS_IS_AVAILABLE = True +except ImportError: + pass + +if CUBLAS_IS_AVAILABLE: + class cublas_ops(disable_weight_init): + class Linear(CublasLinear, disable_weight_init.Linear): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + return super().forward(input) + + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: @@ -369,6 +423,15 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ ): return fp8_ops + if ( + PerformanceFeature.CublasOps in args.fast and + CUBLAS_IS_AVAILABLE and + weight_dtype == torch.float16 and + (compute_dtype == torch.float16 or compute_dtype is None) + ): + logging.info("Using cublas ops") + return cublas_ops + if compute_dtype is None or weight_dtype == compute_dtype: return disable_weight_init diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 859758244..965958f4c 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -48,6 +48,7 @@ def get_all_callbacks(call_type: str, transformer_options: dict, is_model_option class WrappersMP: OUTER_SAMPLE = "outer_sample" + PREPARE_SAMPLING = "prepare_sampling" SAMPLER_SAMPLE = "sampler_sample" CALC_COND_BATCH = "calc_cond_batch" APPLY_MODEL = "apply_model" diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py new file mode 100644 index 000000000..9d82bee1a --- /dev/null +++ b/comfy/rmsnorm.py @@ -0,0 +1,55 @@ +import torch +import comfy.model_management +import numbers + +RMSNorm = None + +try: + rms_norm_torch = torch.nn.functional.rms_norm + RMSNorm = torch.nn.RMSNorm +except: + rms_norm_torch = None + + +def rms_norm(x, weight=None, eps=1e-6): + if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): + if weight is None: + return rms_norm_torch(x, (x.shape[-1],), eps=eps) + else: + return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) + else: + r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) + if weight is None: + return r + else: + return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device) + + +if RMSNorm is None: + class RMSNorm(torch.nn.Module): + def __init__( + self, + normalized_shape, + eps=None, + elementwise_affine=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = torch.nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.bias = None + + def forward(self, x): + return rms_norm(x, self.weight, self.eps) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 92ec7ca7a..96a3040a1 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -106,6 +106,13 @@ def cleanup_additional_models(models): def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): + executor = comfy.patcher_extension.WrapperExecutor.new_executor( + _prepare_sampling, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) + ) + return executor.execute(model, noise_shape, conds, model_options=model_options) + +def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) diff --git a/comfy/samplers.py b/comfy/samplers.py index 7578ac1ef..67ae09a25 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", - "gradient_estimation"] + "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/comfy/sd.py b/comfy/sd.py index fd98585a1..da9b36d0e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae +import comfy.ldm.hunyuan3d.vae import yaml import math @@ -40,6 +41,7 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 import comfy.text_encoders.wan +import comfy.text_encoders.hidream import comfy.model_patcher import comfy.lora @@ -118,6 +120,7 @@ class CLIP: self.layer_idx = None self.use_clip_schedule = False logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) + self.tokenizer_options = {} def clone(self): n = CLIP(no_init=True) @@ -125,6 +128,7 @@ class CLIP: n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx + n.tokenizer_options = self.tokenizer_options.copy() n.use_clip_schedule = self.use_clip_schedule n.apply_hooks_to_conds = self.apply_hooks_to_conds return n @@ -132,10 +136,18 @@ class CLIP: def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return self.patcher.add_patches(patches, strength_patch, strength_model) + def set_tokenizer_option(self, option_name, value): + self.tokenizer_options[option_name] = value + def clip_layer(self, layer_idx): self.layer_idx = layer_idx def tokenize(self, text, return_word_ids=False, **kwargs): + tokenizer_options = kwargs.get("tokenizer_options", {}) + if len(self.tokenizer_options) > 0: + tokenizer_options = {**self.tokenizer_options, **tokenizer_options} + if len(tokenizer_options) > 0: + kwargs["tokenizer_options"] = tokenizer_options return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs) def add_hooks_to_dict(self, pooled_dict: dict[str]): @@ -264,6 +276,7 @@ class VAE: self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] + self.disable_offload = False self.downscale_index_formula = None self.upscale_index_formula = None @@ -336,6 +349,7 @@ class VAE: self.process_output = lambda audio: audio self.process_input = lambda audio: audio self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.disable_offload = True elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae if "blocks.2.blocks.3.stack.5.weight" in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."}) @@ -412,6 +426,17 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: + self.latent_dim = 1 + ln_post = "geo_decoder.ln_post.weight" in sd + inner_size = sd["geo_decoder.output_proj.weight"].shape[1] + downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size + mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size + self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO + self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO + ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} + self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig) + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -440,6 +465,10 @@ class VAE: self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + def throw_exception_if_invalid(self): + if self.first_stage_model is None: + raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") + def vae_encode_crop_pixels(self, pixels): downscale_ratio = self.spacial_compression_encode() @@ -494,18 +523,19 @@ class VAE: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) - def decode(self, samples_in): + def decode(self, samples_in, vae_options={}): + self.throw_exception_if_invalid() pixel_samples = None try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) + out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float()) if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out @@ -525,8 +555,9 @@ class VAE: return pixel_samples def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) dims = samples.ndim - 2 args = {} if tile_x is not None: @@ -553,13 +584,14 @@ class VAE: return output.movedim(1, -1) def encode(self, pixel_samples): + self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) if self.latent_dim == 3 and pixel_samples.ndim < 5: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) @@ -585,6 +617,7 @@ class VAE: return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) @@ -592,7 +625,7 @@ class VAE: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) args = {} if tile_x is not None: @@ -680,6 +713,8 @@ class CLIPType(Enum): COSMOS = 11 LUMINA2 = 12 WAN = 13 + HIDREAM = 14 + CHROMA = 15 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -768,6 +803,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer @@ -781,13 +819,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.LTXV: clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer - elif clip_type == CLIPType.PIXART: + elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA: clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer elif clip_type == CLIPType.WAN: clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer @@ -804,10 +846,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.LLAMA3_8: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: + # clip_l if clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer @@ -825,12 +875,33 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_VIDEO: clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer + elif clip_type == CLIPType.HIDREAM: + # Detect + hidream_dualclip_classes = [] + for hidream_te in clip_data: + te_model = detect_te_model(hidream_te) + hidream_dualclip_classes.append(te_model) + + clip_l = TEModel.CLIP_L in hidream_dualclip_classes + clip_g = TEModel.CLIP_G in hidream_dualclip_classes + t5 = TEModel.T5_XXL in hidream_dualclip_classes + llama = TEModel.LLAMA3_8 in hidream_dualclip_classes + + # Initialize t5xxl_detect and llama_detect kwargs if needed + t5_kwargs = t5xxl_detect(clip_data) if t5 else {} + llama_kwargs = llama_detect(clip_data) if llama else {} + + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif len(clip_data) == 3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif len(clip_data) == 4: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer parameters = 0 for c in clip_data: @@ -899,7 +970,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: - return None + logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") + diffusion_model = load_diffusion_model_state_dict(sd, model_options={}) + if diffusion_model is None: + return None + return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' + unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index be21ec18d..ac61babe9 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -82,7 +82,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): LAYERS = [ "last", "pooled", - "hidden" + "hidden", + "all" ] def __init__(self, device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, @@ -93,6 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") + if "model_name" not in model_options: + model_options = {**model_options, "model_name": "clip_l"} if isinstance(textmodel_json_config, dict): config = textmodel_json_config @@ -100,6 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): with open(textmodel_json_config) as f: config = json.load(f) + te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {}) + for k, v in te_model_options.items(): + config[k] = v + operations = model_options.get("custom_operations", None) scaled_fp8 = None @@ -147,7 +154,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if layer_idx is None or abs(layer_idx) > self.num_layers: + if self.layer == "all": + pass + elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" @@ -244,7 +253,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if self.enable_attention_masks: attention_mask_model = attention_mask - outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) + if self.layer == "all": + intermediate_output = "all" + else: + intermediate_output = self.layer_idx + + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) if self.layer == "last": z = outputs[0].float() @@ -443,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) - self.max_length = max_length + self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) self.min_length = min_length self.end_token = None + self.min_padding = min_padding empty = self.tokenizer('')["input_ids"] self.tokenizer_adds_end_token = has_end_token @@ -504,13 +519,15 @@ class SDTokenizer: return (embed, leftover) - def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs): ''' Takes a prompt and converts it to a list of (token, weight, word id) elements. Tokens can both be integer tokens and pre computed CLIP tensors. Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. Returned list has the dimensions NxM where M is the input size of CLIP ''' + min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length) + min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) text = escape_important(text) parsed_weights = token_weights(text, 1.0) @@ -589,10 +606,12 @@ class SDTokenizer: #fill last batch if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) - if self.pad_to_max_length: + if min_padding is not None: + batch.extend([(self.pad_token, 1.0, 0)] * min_padding) + if self.pad_to_max_length and len(batch) < self.max_length: batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch))) - if self.min_length is not None and len(batch) < self.min_length: - batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch))) + if min_length is not None and len(batch) < min_length: + batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch))) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] @@ -620,7 +639,7 @@ class SD1Tokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids) + out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): @@ -645,6 +664,7 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) clip_model = model_options.get("{}_class".format(self.clip), clip_model) + model_options = {**model_options, "model_name": self.clip} setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs)) self.dtypes = set() diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 5b7c8a412..c8cef14e4 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel): layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + model_options = {**model_options, "model_name": "clip_g"} super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options) @@ -17,19 +18,18 @@ class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data) class SDXLTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): @@ -41,8 +41,7 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__() - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options) self.dtypes = set([dtype]) @@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data) class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): class StableCascadeClipG(sd1_clip.SDClipModel): def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + model_options = {**model_options, "model_name": "clip_g"} super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b4d7bfe20..d5210cfac 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -506,6 +506,22 @@ class SDXL_instructpix2pix(SDXL): def get_model(self, state_dict, prefix="", device=None): return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) +class LotusD(SD20): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "use_temporal_attention": False, + "adm_in_channels": 4, + "in_channels": 4, + } + + unet_extra_config = { + "num_classes": 'sequential' + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Lotus(self, device=device) + class SD3(supported_models_base.BASE): unet_config = { "in_channels": 16, @@ -953,12 +969,133 @@ class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", "model_type": "i2v", + "in_dim": 36, } def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, image_to_video=True, device=device) return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] +class WAN21_FunControl2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "i2v", + "in_dim": 48, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21(self, image_to_video=False, device=device) + return out + +class WAN21_Vace(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "vace", + } + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = 1.2 * self.memory_usage_factor + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Vace(self, image_to_video=False, device=device) + return out + +class Hunyuan3Dv2(supported_models_base.BASE): + unet_config = { + "image_model": "hunyuan3d2", + } + + unet_extra_config = {} + + sampling_settings = { + "multiplier": 1.0, + "shift": 1.0, + } + + memory_usage_factor = 3.5 + + clip_vision_prefix = "conditioner.main_image_encoder.model." + vae_key_prefix = ["vae."] + + latent_format = latent_formats.Hunyuan3Dv2 + + def process_unet_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "model."} + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Hunyuan3Dv2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None + +class Hunyuan3Dv2mini(Hunyuan3Dv2): + unet_config = { + "image_model": "hunyuan3d2", + "depth": 8, + } + + latent_format = latent_formats.Hunyuan3Dv2mini + +class HiDream(supported_models_base.BASE): + unet_config = { + "image_model": "hidream", + } + + sampling_settings = { + "shift": 3.0, + } + + sampling_settings = { + } + + # memory_usage_factor = 1.2 # TODO + + unet_extra_config = {} + latent_format = latent_formats.Flux + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HiDream(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None # TODO + +class Chroma(supported_models_base.BASE): + unet_config = { + "image_model": "chroma", + } + + unet_extra_config = { + } + + sampling_settings = { + "multiplier": 1.0, + } + + latent_format = comfy.latent_formats.Flux + + memory_usage_factor = 3.2 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Chroma(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma] models += [SVD_img2vid] diff --git a/comfy/text_encoders/aura_t5.py b/comfy/text_encoders/aura_t5.py index e9ad45a7f..cf4252eea 100644 --- a/comfy/text_encoders/aura_t5.py +++ b/comfy/text_encoders/aura_t5.py @@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel): class PT5XlTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data) class AuraT5Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index 5441c8952..a1adb5242 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data) class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index a12995ec0..d61ef6668 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -9,19 +9,18 @@ import os class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) class FluxTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) - out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): @@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module): def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}): super().__init__() dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options) self.dtypes = set([dtype, dtype_t5]) diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 45987a480..9dcf190a2 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py new file mode 100644 index 000000000..dbcf52784 --- /dev/null +++ b/comfy/text_encoders/hidream.py @@ -0,0 +1,155 @@ +from . import hunyuan_video +from . import sd3_clip +from comfy import sd1_clip +from comfy import sdxl_clip +import comfy.model_management +import torch +import logging + + +class HiDreamTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data) + self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) + out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens + out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + return self.clip_g.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + +class HiDreamTEModel(torch.nn.Module): + def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = set() + if clip_l: + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options) + self.dtypes.add(dtype) + else: + self.clip_l = None + + if clip_g: + self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options) + self.dtypes.add(dtype) + else: + self.clip_g = None + + if t5: + dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) + self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True) + self.dtypes.add(dtype_t5) + else: + self.t5xxl = None + + if llama: + dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device) + if "vocab_size" not in model_options: + model_options["vocab_size"] = 128256 + self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009}) + self.dtypes.add(dtype_llama) + else: + self.llama = None + + logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama)) + + def set_clip_options(self, options): + if self.clip_l is not None: + self.clip_l.set_clip_options(options) + if self.clip_g is not None: + self.clip_g.set_clip_options(options) + if self.t5xxl is not None: + self.t5xxl.set_clip_options(options) + if self.llama is not None: + self.llama.set_clip_options(options) + + def reset_clip_options(self): + if self.clip_l is not None: + self.clip_l.reset_clip_options() + if self.clip_g is not None: + self.clip_g.reset_clip_options() + if self.t5xxl is not None: + self.t5xxl.reset_clip_options() + if self.llama is not None: + self.llama.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_l = token_weight_pairs["l"] + token_weight_pairs_g = token_weight_pairs["g"] + token_weight_pairs_t5 = token_weight_pairs["t5xxl"] + token_weight_pairs_llama = token_weight_pairs["llama"] + lg_out = None + pooled = None + extra = {} + + if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: + if self.clip_l is not None: + lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) + else: + l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device()) + + if self.clip_g is not None: + g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) + else: + g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device()) + + pooled = torch.cat((l_pooled, g_pooled), dim=-1) + + if self.t5xxl is not None: + t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5) + t5_out, t5_pooled = t5_output[:2] + else: + t5_out = None + + if self.llama is not None: + ll_output = self.llama.encode_token_weights(token_weight_pairs_llama) + ll_out, ll_pooled = ll_output[:2] + ll_out = ll_out[:, 1:] + else: + ll_out = None + + if t5_out is None: + t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device()) + + if ll_out is None: + ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device()) + + if pooled is None: + pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device()) + + extra["conditioning_llama3"] = ll_out + return t5_out, pooled, extra + + def load_sd(self, sd): + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + return self.clip_g.load_sd(sd) + elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: + return self.t5xxl.load_sd(sd) + else: + return self.llama.load_sd(sd) + + +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): + class HiDreamTEModel_(HiDreamTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["llama_scaled_fp8"] = llama_scaled_fp8 + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return HiDreamTEModel_ diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index dbb259e54..b02148b33 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -21,36 +21,41 @@ def llama_detect(state_dict, prefix=""): class LLAMA3Tokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data) class LLAMAModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): + def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) if llama_scaled_fp8 is not None: model_options = model_options.copy() model_options["scaled_fp8"] = llama_scaled_fp8 - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + textmodel_json_config = {} + vocab_size = model_options.get("vocab_size", None) + if vocab_size is not None: + textmodel_json_config["vocab_size"] = vocab_size + + model_options = {**model_options, "model_name": "llama"} + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class HunyuanVideoTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens - self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) + self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs): out = {} - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) if llama_template is None: llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) - llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids) + llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs) embed_count = 0 for r in llama_text_tokens: for i in range(len(r)): @@ -72,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): super().__init__() dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device) - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options) self.dtypes = set([dtype, dtype_llama]) diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py index 7da3e9fc5..ac6994529 100644 --- a/comfy/text_encoders/hydit.py +++ b/comfy/text_encoders/hydit.py @@ -9,24 +9,26 @@ import torch class HyditBertModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json") + model_options = {**model_options, "model_name": "hydit_clip"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) class HyditBertTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data) class MT5XLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json") + model_options = {**model_options, "model_name": "mt5xl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) class MT5XLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): #tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model") tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} @@ -35,12 +37,12 @@ class HyditTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None) self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory) - self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) + self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids) - out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids) + out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs) + out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 58710b2bf..34eb870e3 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -268,11 +268,17 @@ class Llama2_(nn.Module): optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None + all_intermediate = None if intermediate_output is not None: - if intermediate_output < 0: + if intermediate_output == "all": + all_intermediate = [] + intermediate_output = None + elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output for i, layer in enumerate(self.layers): + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) x = layer( x=x, attention_mask=mask, @@ -283,6 +289,12 @@ class Llama2_(nn.Module): intermediate = x.clone() x = self.norm(x) + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) + + if all_intermediate is not None: + intermediate = torch.cat(all_intermediate, dim=1) + if intermediate is not None and final_layer_norm_intermediate: intermediate = self.norm(intermediate) diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py index b81912cb3..8d4c7619d 100644 --- a/comfy/text_encoders/long_clipl.py +++ b/comfy/text_encoders/long_clipl.py @@ -1,30 +1,27 @@ -from comfy import sd1_clip -import os -class LongClipTokenizer_(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - -class LongClipModel_(sd1_clip.SDClipModel): - def __init__(self, *args, **kwargs): - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json") - super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs) - -class LongClipTokenizer(sd1_clip.SD1Tokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_) - -class LongClipModel(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): - super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs) def model_options_long_clip(sd, tokenizer_data, model_options): w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None) + if w is None: + w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None) + else: + model_name = "clip_g" + if w is None: w = sd.get("text_model.embeddings.position_embedding.weight", None) - if w is not None and w.shape[0] == 248: + if w is not None: + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + model_name = "clip_g" + elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + model_name = "clip_l" + else: + model_name = "clip_l" + + if w is not None: tokenizer_data = tokenizer_data.copy() model_options = model_options.copy() - tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_ - model_options["clip_l_class"] = LongClipModel_ + model_config = model_options.get("model_config", {}) + model_config["max_position_embeddings"] = w.shape[0] + model_options["{}_model_config".format(model_name)] = model_config + tokenizer_data["{}_max_length".format(model_name)] = w.shape[0] return tokenizer_data, model_options diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 5c2ce583f..48ea67e67 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,7 +6,7 @@ import comfy.text_encoders.genmo class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128? + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128? class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index a7b1d702b..674461b75 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -6,7 +6,7 @@ import comfy.text_encoders.llama class Gemma2BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index d56d57f1b..b8de6bc4e 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sa_t5.py b/comfy/text_encoders/sa_t5.py index 7778ce47a..2803926ac 100644 --- a/comfy/text_encoders/sa_t5.py +++ b/comfy/text_encoders/sa_t5.py @@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel): class T5BaseTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) class SAT5Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sd2_clip.py b/comfy/text_encoders/sd2_clip.py index 31fc89869..700a23bf0 100644 --- a/comfy/text_encoders/sd2_clip.py +++ b/comfy/text_encoders/sd2_clip.py @@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel): class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data) class SD2Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 3ad2ed93a..ff5d412db 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel): model_options = model_options.copy() model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -31,23 +32,22 @@ def t5_xxl_detect(state_dict, prefix=""): return out class T5XXLTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data) class SD3Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) - self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) - out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): @@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module): super().__init__() self.dtypes = set() if clip_l: - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) self.dtypes.add(dtype) else: self.clip_l = None diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py index 21df4f863..caccb3ca2 100644 --- a/comfy/text_encoders/spiece_tokenizer.py +++ b/comfy/text_encoders/spiece_tokenizer.py @@ -1,4 +1,5 @@ import torch +import os class SPieceTokenizer: @staticmethod @@ -15,6 +16,8 @@ class SPieceTokenizer: if isinstance(tokenizer_path, bytes): self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) else: + if not os.path.isfile(tokenizer_path): + raise ValueError("invalid tokenizer") self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) def get_vocab(self): diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index 971ac8fa8..d50fa4b28 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel): class UMT5XXlTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0) + super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py new file mode 100644 index 000000000..d2a1d0151 --- /dev/null +++ b/comfy/weight_adapter/__init__.py @@ -0,0 +1,17 @@ +from .base import WeightAdapterBase +from .lora import LoRAAdapter +from .loha import LoHaAdapter +from .lokr import LoKrAdapter +from .glora import GLoRAAdapter +from .oft import OFTAdapter +from .boft import BOFTAdapter + + +adapters: list[type[WeightAdapterBase]] = [ + LoRAAdapter, + LoHaAdapter, + LoKrAdapter, + GLoRAAdapter, + OFTAdapter, + BOFTAdapter, +] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py new file mode 100644 index 000000000..29873519d --- /dev/null +++ b/comfy/weight_adapter/base.py @@ -0,0 +1,104 @@ +from typing import Optional + +import torch +import torch.nn as nn + +import comfy.model_management + + +class WeightAdapterBase: + name: str + loaded_keys: set[str] + weights: list[torch.Tensor] + + @classmethod + def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: + raise NotImplementedError + + def to_train(self) -> "WeightAdapterTrainBase": + raise NotImplementedError + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + raise NotImplementedError + + +class WeightAdapterTrainBase(nn.Module): + def __init__(self): + super().__init__() + + # [TODO] Collaborate with LoRA training PR #7032 + + +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): + dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) + lora_diff *= alpha + weight_calc = weight + function(lora_diff).type(weight.dtype) + + wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0] + if wd_on_output_axis: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[0], *[1] * (weight.dim() - 1)) + ) + else: + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + weight_norm = weight_norm + torch.finfo(weight.dtype).eps + + weight_calc *= (dora_scale / weight_norm).type(weight.dtype) + if strength != 1.0: + weight_calc -= weight + weight += strength * (weight_calc) + else: + weight[:] = weight_calc + return weight + + +def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: + """ + Pad a tensor to a new shape with zeros. + + Args: + tensor (torch.Tensor): The original tensor to be padded. + new_shape (List[int]): The desired shape of the padded tensor. + + Returns: + torch.Tensor: A new tensor padded with zeros to the specified shape. + + Note: + If the new shape is smaller than the original tensor in any dimension, + the original tensor will be truncated in that dimension. + """ + if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): + raise ValueError("The new shape must be larger than the original tensor in all dimensions") + + if len(new_shape) != len(tensor.shape): + raise ValueError("The new shape must have the same number of dimensions as the original tensor") + + # Create a new tensor filled with zeros + padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) + + # Create slicing tuples for both tensors + orig_slices = tuple(slice(0, dim) for dim in tensor.shape) + new_slices = tuple(slice(0, dim) for dim in tensor.shape) + + # Copy the original tensor into the new tensor + padded_tensor[new_slices] = tensor[orig_slices] + + return padded_tensor diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py new file mode 100644 index 000000000..c85adc7ab --- /dev/null +++ b/comfy/weight_adapter/boft.py @@ -0,0 +1,115 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class BOFTAdapter(WeightAdapterBase): + name = "boft" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["BOFTAdapter"]: + if loaded_keys is None: + loaded_keys = set() + blocks_name = "{}.boft_blocks".format(x) + rescale_name = "{}.rescale".format(x) + + blocks = None + if blocks_name in lora.keys(): + blocks = lora[blocks_name] + if blocks.ndim == 4: + loaded_keys.add(blocks_name) + + rescale = None + if rescale_name in lora.keys(): + rescale = lora[rescale_name] + loaded_keys.add(rescale_name) + + if blocks is not None: + weights = (blocks, rescale, alpha, dora_scale) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + blocks = v[0] + rescale = v[1] + alpha = v[2] + dora_scale = v[3] + + blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + if rescale is not None: + rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + + boft_m, block_num, boft_b, *_ = blocks.shape + + try: + # Get r + I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype) + # for Q = -Q^T + q = blocks - blocks.transpose(1, 2) + normed_q = q + if alpha > 0: # alpha in boft/bboft is for constraint + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + # use float() to prevent unsupported type in .inverse() + r = (I + normed_q) @ (I - normed_q).float().inverse() + r = r.to(original_weight) + + inp = org = original_weight + + r_b = boft_b//2 + for i in range(boft_m): + bi = r[i] + g = 2 + k = 2**i * r_b + if strength != 1: + bi = bi * strength + (1-strength) * I + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, boft_b)) + ) + inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi) + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + if rescale is not None: + inp = inp * rescale + + lora_diff = inp - org + lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py new file mode 100644 index 000000000..939abbba5 --- /dev/null +++ b/comfy/weight_adapter/glora.py @@ -0,0 +1,93 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class GLoRAAdapter(WeightAdapterBase): + name = "glora" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["GLoRAAdapter"]: + if loaded_keys is None: + loaded_keys = set() + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = "{}.b2.weight".format(x) + if a1_name in lora: + weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + dora_scale = v[5] + + old_glora = False + if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: + rank = v[0].shape[0] + old_glora = True + + if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: + if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + pass + else: + old_glora = False + rank = v[1].shape[0] + + a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) + a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) + b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) + b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + + if v[4] is not None: + alpha = v[4] / rank + else: + alpha = 1.0 + + try: + if old_glora: + lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + else: + if weight.dim() > 2: + lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + else: + lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff += torch.mm(b1, b2).reshape(weight.shape) + + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py new file mode 100644 index 000000000..ce79abad5 --- /dev/null +++ b/comfy/weight_adapter/loha.py @@ -0,0 +1,100 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class LoHaAdapter(WeightAdapterBase): + name = "loha" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["LoHaAdapter"]: + if loaded_keys is None: + loaded_keys = set() + + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) + if hada_w1_a_name in lora.keys(): + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + hada_t2 = lora[hada_t2_name] + loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha = v[2] / w1b.shape[0] + else: + alpha = 1.0 + + w2a = v[3] + w2b = v[4] + dora_scale = v[7] + if v[5] is not None: #cp decomposition + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + else: + m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) + m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) + + try: + lora_diff = (m1 * m2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py new file mode 100644 index 000000000..51233db2d --- /dev/null +++ b/comfy/weight_adapter/lokr.py @@ -0,0 +1,133 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class LoKrAdapter(WeightAdapterBase): + name = "lokr" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["LoKrAdapter"]: + if loaded_keys is None: + loaded_keys = set() + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dora_scale = v[8] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + else: + w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + else: + w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha = v[2] / dim + else: + alpha = 1.0 + + try: + lora_diff = torch.kron(w1, w2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py new file mode 100644 index 000000000..b2e623924 --- /dev/null +++ b/comfy/weight_adapter/lora.py @@ -0,0 +1,142 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape + + +class LoRAAdapter(WeightAdapterBase): + name = "lora" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["LoRAAdapter"]: + if loaded_keys is None: + loaded_keys = set() + + reshape_name = "{}.reshape_weight".format(x) + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + diffusers2_lora = "{}.lora_B.weight".format(x) + diffusers3_lora = "{}.lora.up.weight".format(x) + mochi_lora = "{}.lora_B".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None + + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif diffusers2_lora in lora.keys(): + A_name = diffusers2_lora + B_name = "{}.lora_A.weight".format(x) + mid_name = None + elif diffusers3_lora in lora.keys(): + A_name = diffusers3_lora + B_name = "{}.lora.down.weight".format(x) + mid_name = None + elif mochi_lora in lora.keys(): + A_name = mochi_lora + B_name = "{}.lora_A".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name = "{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: + mid = None + if mid_name is not None and mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + reshape = None + if reshape_name in lora.keys(): + try: + reshape = lora[reshape_name].tolist() + loaded_keys.add(reshape_name) + except: + pass + weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape) + loaded_keys.add(A_name) + loaded_keys.add(B_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + mat1 = comfy.model_management.cast_to_device( + v[0], weight.device, intermediate_dtype + ) + mat2 = comfy.model_management.cast_to_device( + v[1], weight.device, intermediate_dtype + ) + dora_scale = v[4] + reshape = v[5] + + if reshape is not None: + weight = pad_tensor_to_shape(weight, reshape) + + if v[2] is not None: + alpha = v[2] / mat2.shape[0] + else: + alpha = 1.0 + + if v[3] is not None: + # locon mid weights, hopefully the math is fine because I didn't properly test it + mat3 = comfy.model_management.cast_to_device( + v[3], weight.device, intermediate_dtype + ) + final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] + mat2 = ( + torch.mm( + mat2.transpose(0, 1).flatten(start_dim=1), + mat3.transpose(0, 1).flatten(start_dim=1), + ) + .reshape(final_shape) + .transpose(0, 1) + ) + try: + lora_diff = torch.mm( + mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) + ).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py new file mode 100644 index 000000000..0ea229b79 --- /dev/null +++ b/comfy/weight_adapter/oft.py @@ -0,0 +1,94 @@ +import logging +from typing import Optional + +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose + + +class OFTAdapter(WeightAdapterBase): + name = "oft" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> Optional["OFTAdapter"]: + if loaded_keys is None: + loaded_keys = set() + blocks_name = "{}.oft_blocks".format(x) + rescale_name = "{}.rescale".format(x) + + blocks = None + if blocks_name in lora.keys(): + blocks = lora[blocks_name] + if blocks.ndim == 3: + loaded_keys.add(blocks_name) + + rescale = None + if rescale_name in lora.keys(): + rescale = lora[rescale_name] + loaded_keys.add(rescale_name) + + if blocks is not None: + weights = (blocks, rescale, alpha, dora_scale) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights + blocks = v[0] + rescale = v[1] + alpha = v[2] + dora_scale = v[3] + + blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + if rescale is not None: + rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + + block_num, block_size, *_ = blocks.shape + + try: + # Get r + I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype) + # for Q = -Q^T + q = blocks - blocks.transpose(1, 2) + normed_q = q + if alpha > 0: # alpha in oft/boft is for constraint + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + # use float() to prevent unsupported type in .inverse() + r = (I + normed_q) @ (I - normed_q).float().inverse() + r = r.to(original_weight) + lora_diff = torch.einsum( + "k n m, k n ... -> k m ...", + (r * strength) - strength * I, + original_weight, + ) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py new file mode 100644 index 000000000..66667946f --- /dev/null +++ b/comfy_api/input/__init__.py @@ -0,0 +1,8 @@ +from .basic_types import ImageInput, AudioInput +from .video_types import VideoInput + +__all__ = [ + "ImageInput", + "AudioInput", + "VideoInput", +] diff --git a/comfy_api/input/basic_types.py b/comfy_api/input/basic_types.py new file mode 100644 index 000000000..033fb7e27 --- /dev/null +++ b/comfy_api/input/basic_types.py @@ -0,0 +1,20 @@ +import torch +from typing import TypedDict + +ImageInput = torch.Tensor +""" +An image in format [B, H, W, C] where B is the batch size, C is the number of channels, +""" + +class AudioInput(TypedDict): + """ + TypedDict representing audio input. + """ + + waveform: torch.Tensor + """ + Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, + """ + + sample_rate: int + diff --git a/comfy_api/input/video_types.py b/comfy_api/input/video_types.py new file mode 100644 index 000000000..0676e0e66 --- /dev/null +++ b/comfy_api/input/video_types.py @@ -0,0 +1,45 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Optional +from comfy_api.util import VideoContainer, VideoCodec, VideoComponents + +class VideoInput(ABC): + """ + Abstract base class for video input types. + """ + + @abstractmethod + def get_components(self) -> VideoComponents: + """ + Abstract method to get the video components (images, audio, and frame rate). + + Returns: + VideoComponents containing images, audio, and frame rate + """ + pass + + @abstractmethod + def save_to( + self, + path: str, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None + ): + """ + Abstract method to save the video input to a file. + """ + pass + + # Provide a default implementation, but subclasses can provide optimized versions + # if possible. + def get_dimensions(self) -> tuple[int, int]: + """ + Returns the dimensions of the video input. + + Returns: + Tuple of (width, height) + """ + components = self.get_components() + return components.images.shape[2], components.images.shape[1] + diff --git a/comfy_api/input_impl/__init__.py b/comfy_api/input_impl/__init__.py new file mode 100644 index 000000000..02901b8b9 --- /dev/null +++ b/comfy_api/input_impl/__init__.py @@ -0,0 +1,7 @@ +from .video_types import VideoFromFile, VideoFromComponents + +__all__ = [ + # Implementations + "VideoFromFile", + "VideoFromComponents", +] diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py new file mode 100644 index 000000000..12e5783db --- /dev/null +++ b/comfy_api/input_impl/video_types.py @@ -0,0 +1,224 @@ +from __future__ import annotations +from av.container import InputContainer +from av.subtitles.stream import SubtitleStream +from fractions import Fraction +from typing import Optional +from comfy_api.input import AudioInput +import av +import io +import json +import numpy as np +import torch +from comfy_api.input import VideoInput +from comfy_api.util import VideoContainer, VideoCodec, VideoComponents + +class VideoFromFile(VideoInput): + """ + Class representing video input from a file. + """ + + def __init__(self, file: str | io.BytesIO): + """ + Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object + containing the file contents. + """ + self.__file = file + + def get_dimensions(self) -> tuple[int, int]: + """ + Returns the dimensions of the video input. + + Returns: + Tuple of (width, height) + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode='r') as container: + for stream in container.streams: + if stream.type == 'video': + assert isinstance(stream, av.VideoStream) + return stream.width, stream.height + raise ValueError(f"No video stream found in file '{self.__file}'") + + def get_components_internal(self, container: InputContainer) -> VideoComponents: + # Get video frames + frames = [] + for frame in container.decode(video=0): + img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3) + img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3) + frames.append(img) + + images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) + + # Get frame rate + video_stream = next(s for s in container.streams if s.type == 'video') + frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1) + + # Get audio if available + audio = None + try: + container.seek(0) # Reset the container to the beginning + for stream in container.streams: + if stream.type != 'audio': + continue + assert isinstance(stream, av.AudioStream) + audio_frames = [] + for packet in container.demux(stream): + for frame in packet.decode(): + assert isinstance(frame, av.AudioFrame) + audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) + if len(audio_frames) > 0: + audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) + audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) + audio = AudioInput({ + "waveform": audio_tensor, + "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1, + }) + except StopIteration: + pass # No audio stream + + metadata = container.metadata + return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) + + def get_components(self) -> VideoComponents: + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode='r') as container: + return self.get_components_internal(container) + raise ValueError(f"No video stream found in file '{self.__file}'") + + def save_to( + self, + path: str, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None + ): + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode='r') as container: + container_format = container.format.name + video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None + reuse_streams = True + if format != VideoContainer.AUTO and format not in container_format.split(","): + reuse_streams = False + if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: + reuse_streams = False + + if not reuse_streams: + components = self.get_components_internal(container) + video = VideoFromComponents(components) + return video.save_to( + path, + format=format, + codec=codec, + metadata=metadata + ) + + streams = container.streams + with av.open(path, mode='w', options={"movflags": "use_metadata_tags"}) as output_container: + # Copy over the original metadata + for key, value in container.metadata.items(): + if metadata is None or key not in metadata: + output_container.metadata[key] = value + + # Add our new metadata + if metadata is not None: + for key, value in metadata.items(): + if isinstance(value, str): + output_container.metadata[key] = value + else: + output_container.metadata[key] = json.dumps(value) + + # Add streams to the new container + stream_map = {} + for stream in streams: + if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)): + out_stream = output_container.add_stream_from_template(template=stream, opaque=True) + stream_map[stream] = out_stream + + # Write packets to the new container + for packet in container.demux(): + if packet.stream in stream_map and packet.dts is not None: + packet.stream = stream_map[packet.stream] + output_container.mux(packet) + +class VideoFromComponents(VideoInput): + """ + Class representing video input from tensors. + """ + + def __init__(self, components: VideoComponents): + self.__components = components + + def get_components(self) -> VideoComponents: + return VideoComponents( + images=self.__components.images, + audio=self.__components.audio, + frame_rate=self.__components.frame_rate + ) + + def save_to( + self, + path: str, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None + ): + if format != VideoContainer.AUTO and format != VideoContainer.MP4: + raise ValueError("Only MP4 format is supported for now") + if codec != VideoCodec.AUTO and codec != VideoCodec.H264: + raise ValueError("Only H264 codec is supported for now") + with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: + # Add metadata before writing any streams + if metadata is not None: + for key, value in metadata.items(): + output.metadata[key] = json.dumps(value) + + frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) + # Create a video stream + video_stream = output.add_stream('h264', rate=frame_rate) + video_stream.width = self.__components.images.shape[2] + video_stream.height = self.__components.images.shape[1] + video_stream.pix_fmt = 'yuv420p' + + # Create an audio stream + audio_sample_rate = 1 + audio_stream: Optional[av.AudioStream] = None + if self.__components.audio: + audio_sample_rate = int(self.__components.audio['sample_rate']) + audio_stream = output.add_stream('aac', rate=audio_sample_rate) + audio_stream.sample_rate = audio_sample_rate + audio_stream.format = 'fltp' + + # Encode video + for i, frame in enumerate(self.__components.images): + img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 + packet = video_stream.encode(frame) + output.mux(packet) + + # Flush video + packet = video_stream.encode(None) + output.mux(packet) + + if audio_stream and self.__components.audio: + # Encode audio + samples_per_frame = int(audio_sample_rate / frame_rate) + num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame + for i in range(num_frames): + start = i * samples_per_frame + end = start + samples_per_frame + # TODO(Feature) - Add support for stereo audio + chunk = self.__components.audio['waveform'][0, 0, start:end].unsqueeze(0).numpy() + audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono') + audio_frame.sample_rate = audio_sample_rate + audio_frame.pts = i * samples_per_frame + for packet in audio_stream.encode(audio_frame): + output.mux(packet) + + # Flush audio + for packet in audio_stream.encode(None): + output.mux(packet) + diff --git a/comfy_api/util/__init__.py b/comfy_api/util/__init__.py new file mode 100644 index 000000000..9019c46db --- /dev/null +++ b/comfy_api/util/__init__.py @@ -0,0 +1,8 @@ +from .video_types import VideoContainer, VideoCodec, VideoComponents + +__all__ = [ + # Utility Types + "VideoContainer", + "VideoCodec", + "VideoComponents", +] diff --git a/comfy_api/util/video_types.py b/comfy_api/util/video_types.py new file mode 100644 index 000000000..d09663db9 --- /dev/null +++ b/comfy_api/util/video_types.py @@ -0,0 +1,51 @@ +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum +from fractions import Fraction +from typing import Optional +from comfy_api.input import ImageInput, AudioInput + +class VideoCodec(str, Enum): + AUTO = "auto" + H264 = "h264" + + @classmethod + def as_input(cls) -> list[str]: + """ + Returns a list of codec names that can be used as node input. + """ + return [member.value for member in cls] + +class VideoContainer(str, Enum): + AUTO = "auto" + MP4 = "mp4" + + @classmethod + def as_input(cls) -> list[str]: + """ + Returns a list of container names that can be used as node input. + """ + return [member.value for member in cls] + + @classmethod + def get_extension(cls, value) -> str: + """ + Returns the file extension for the container. + """ + if isinstance(value, str): + value = cls(value) + if value == VideoContainer.MP4 or value == VideoContainer.AUTO: + return "mp4" + return "" + +@dataclass +class VideoComponents: + """ + Dataclass representing the components of a video. + """ + + images: ImageInput + frame_rate: Fraction + audio: Optional[AudioInput] = None + metadata: Optional[dict] = None + diff --git a/comfy_api_nodes/__init__.py b/comfy_api_nodes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy_api_nodes/apis/PixverseController.py b/comfy_api_nodes/apis/PixverseController.py new file mode 100644 index 000000000..29a3ab33b --- /dev/null +++ b/comfy_api_nodes/apis/PixverseController.py @@ -0,0 +1,17 @@ +# generated by datamodel-codegen: +# filename: https://api.comfy.org/openapi +# timestamp: 2025-04-23T15:56:33+00:00 + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel + +from . import PixverseDto + + +class ResponseData(BaseModel): + ErrCode: Optional[int] = None + ErrMsg: Optional[str] = None + Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None diff --git a/comfy_api_nodes/apis/PixverseDto.py b/comfy_api_nodes/apis/PixverseDto.py new file mode 100644 index 000000000..399512214 --- /dev/null +++ b/comfy_api_nodes/apis/PixverseDto.py @@ -0,0 +1,57 @@ +# generated by datamodel-codegen: +# filename: https://api.comfy.org/openapi +# timestamp: 2025-04-23T15:56:33+00:00 + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel, Field, constr + + +class V2OpenAPII2VResp(BaseModel): + video_id: Optional[int] = Field(None, description='Video_id') + + +class V2OpenAPIT2VReq(BaseModel): + aspect_ratio: str = Field( + ..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9'] + ) + duration: int = Field( + ..., + description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)', + examples=[5], + ) + model: str = Field( + ..., description='Model version (only supports v3.5)', examples=['v3.5'] + ) + motion_mode: Optional[str] = Field( + 'normal', + description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)', + examples=['normal'], + ) + negative_prompt: Optional[constr(max_length=2048)] = Field( + None, description='Negative prompt\n' + ) + prompt: constr(max_length=2048) = Field(..., description='Prompt') + quality: str = Field( + ..., + description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")', + examples=['540p'], + ) + seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647') + style: Optional[str] = Field( + None, + description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed', + examples=['anime'], + ) + template_id: Optional[int] = Field( + None, + description='Template ID (template_id must be activated before use)', + examples=[302325299692608], + ) + water_mark: Optional[bool] = Field( + False, + description='Watermark (true: add watermark, false: no watermark)', + examples=[False], + ) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py new file mode 100644 index 000000000..e7ea9b332 --- /dev/null +++ b/comfy_api_nodes/apis/__init__.py @@ -0,0 +1,422 @@ +# generated by datamodel-codegen: +# filename: https://api.comfy.org/openapi +# timestamp: 2025-04-23T15:56:33+00:00 + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import AnyUrl, BaseModel, Field, confloat, conint + +class Customer(BaseModel): + createdAt: Optional[datetime] = Field( + None, description='The date and time the user was created' + ) + email: Optional[str] = Field(None, description='The email address for this user') + id: str = Field(..., description='The firebase UID of the user') + name: Optional[str] = Field(None, description='The name for this user') + updatedAt: Optional[datetime] = Field( + None, description='The date and time the user was last updated' + ) + + +class Error(BaseModel): + details: Optional[List[str]] = Field( + None, + description='Optional detailed information about the error or hints for resolving it.', + ) + message: Optional[str] = Field( + None, description='A clear and concise description of the error.' + ) + + +class ErrorResponse(BaseModel): + error: str + message: str + +class ImageRequest(BaseModel): + aspect_ratio: Optional[str] = Field( + None, + description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.", + ) + color_palette: Optional[Dict[str, Any]] = Field( + None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' + ) + magic_prompt_option: Optional[str] = Field( + None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')." + ) + model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") + negative_prompt: Optional[str] = Field( + None, + description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', + ) + num_images: Optional[conint(ge=1, le=8)] = Field( + 1, description='Optional. Number of images to generate (1-8). Defaults to 1.' + ) + prompt: str = Field( + ..., description='Required. The prompt to use to generate the image.' + ) + resolution: Optional[str] = Field( + None, + description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", + ) + seed: Optional[conint(ge=0, le=2147483647)] = Field( + None, description='Optional. A number between 0 and 2147483647.' + ) + style_type: Optional[str] = Field( + None, + description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.", + ) + + +class Datum(BaseModel): + is_image_safe: Optional[bool] = Field( + None, description='Indicates whether the image is considered safe.' + ) + prompt: Optional[str] = Field( + None, description='The prompt used to generate this image.' + ) + resolution: Optional[str] = Field( + None, description="The resolution of the generated image (e.g., '1024x1024')." + ) + seed: Optional[int] = Field( + None, description='The seed value used for this generation.' + ) + style_type: Optional[str] = Field( + None, + description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').", + ) + url: Optional[str] = Field(None, description='URL to the generated image.') + + +class Code(Enum): + int_1100 = 1100 + int_1101 = 1101 + int_1102 = 1102 + int_1103 = 1103 + + +class Code1(Enum): + int_1000 = 1000 + int_1001 = 1001 + int_1002 = 1002 + int_1003 = 1003 + int_1004 = 1004 + + +class AspectRatio(str, Enum): + field_16_9 = '16:9' + field_9_16 = '9:16' + field_1_1 = '1:1' + + +class Config(BaseModel): + horizontal: Optional[confloat(ge=-10.0, le=10.0)] = None + pan: Optional[confloat(ge=-10.0, le=10.0)] = None + roll: Optional[confloat(ge=-10.0, le=10.0)] = None + tilt: Optional[confloat(ge=-10.0, le=10.0)] = None + vertical: Optional[confloat(ge=-10.0, le=10.0)] = None + zoom: Optional[confloat(ge=-10.0, le=10.0)] = None + + +class Type(str, Enum): + simple = 'simple' + down_back = 'down_back' + forward_up = 'forward_up' + right_turn_forward = 'right_turn_forward' + left_turn_forward = 'left_turn_forward' + + +class CameraControl(BaseModel): + config: Optional[Config] = None + type: Optional[Type] = Field(None, description='Predefined camera movements type') + + +class Duration(str, Enum): + field_5 = 5 + field_10 = 10 + + +class Mode(str, Enum): + std = 'std' + pro = 'pro' + + +class TaskInfo(BaseModel): + external_task_id: Optional[str] = None + + +class Video(BaseModel): + duration: Optional[str] = Field(None, description='Total video duration') + id: Optional[str] = Field(None, description='Generated video ID') + url: Optional[AnyUrl] = Field(None, description='URL for generated video') + + +class TaskResult(BaseModel): + videos: Optional[List[Video]] = None + + +class TaskStatus(str, Enum): + submitted = 'submitted' + processing = 'processing' + succeed = 'succeed' + failed = 'failed' + + +class Data(BaseModel): + created_at: Optional[int] = Field(None, description='Task creation time') + task_id: Optional[str] = Field(None, description='Task ID') + task_info: Optional[TaskInfo] = None + task_result: Optional[TaskResult] = None + task_status: Optional[TaskStatus] = None + updated_at: Optional[int] = Field(None, description='Task update time') + + +class AspectRatio1(str, Enum): + field_16_9 = '16:9' + field_9_16 = '9:16' + field_1_1 = '1:1' + field_4_3 = '4:3' + field_3_4 = '3:4' + field_3_2 = '3:2' + field_2_3 = '2:3' + field_21_9 = '21:9' + + +class ImageReference(str, Enum): + subject = 'subject' + face = 'face' + + +class Image(BaseModel): + index: Optional[int] = Field(None, description='Image Number (0-9)') + url: Optional[AnyUrl] = Field(None, description='URL for generated image') + + +class TaskResult1(BaseModel): + images: Optional[List[Image]] = None + + +class Data1(BaseModel): + created_at: Optional[int] = Field(None, description='Task creation time') + task_id: Optional[str] = Field(None, description='Task ID') + task_result: Optional[TaskResult1] = None + task_status: Optional[TaskStatus] = None + task_status_msg: Optional[str] = Field(None, description='Task status information') + updated_at: Optional[int] = Field(None, description='Task update time') + + +class AspectRatio2(str, Enum): + field_16_9 = '16:9' + field_9_16 = '9:16' + field_1_1 = '1:1' + + +class CameraControl1(BaseModel): + config: Optional[Config] = None + type: Optional[Type] = Field(None, description='Predefined camera movements type') + + +class ModelName2(str, Enum): + kling_v1 = 'kling-v1' + kling_v1_6 = 'kling-v1-6' + + +class TaskResult2(BaseModel): + videos: Optional[List[Video]] = None + + +class Data2(BaseModel): + created_at: Optional[int] = Field(None, description='Task creation time') + task_id: Optional[str] = Field(None, description='Task ID') + task_info: Optional[TaskInfo] = None + task_result: Optional[TaskResult2] = None + task_status: Optional[TaskStatus] = None + updated_at: Optional[int] = Field(None, description='Task update time') + + +class Code2(Enum): + int_1200 = 1200 + int_1201 = 1201 + int_1202 = 1202 + int_1203 = 1203 + + +class ResourcePackType(str, Enum): + decreasing_total = 'decreasing_total' + constant_period = 'constant_period' + + +class Status(str, Enum): + toBeOnline = 'toBeOnline' + online = 'online' + expired = 'expired' + runOut = 'runOut' + + +class ResourcePackSubscribeInfo(BaseModel): + effective_time: Optional[int] = Field( + None, description='Effective time, Unix timestamp in ms' + ) + invalid_time: Optional[int] = Field( + None, description='Expiration time, Unix timestamp in ms' + ) + purchase_time: Optional[int] = Field( + None, description='Purchase time, Unix timestamp in ms' + ) + remaining_quantity: Optional[float] = Field( + None, description='Remaining quantity (updated with a 12-hour delay)' + ) + resource_pack_id: Optional[str] = Field(None, description='Resource package ID') + resource_pack_name: Optional[str] = Field(None, description='Resource package name') + resource_pack_type: Optional[ResourcePackType] = Field( + None, + description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)', + ) + status: Optional[Status] = Field(None, description='Resource Package Status') + total_quantity: Optional[float] = Field(None, description='Total quantity') + +class Background(str, Enum): + transparent = 'transparent' + opaque = 'opaque' + + +class Moderation(str, Enum): + low = 'low' + auto = 'auto' + + +class OutputFormat(str, Enum): + png = 'png' + webp = 'webp' + jpeg = 'jpeg' + + +class Quality(str, Enum): + low = 'low' + medium = 'medium' + high = 'high' + + +class OpenAIImageEditRequest(BaseModel): + background: Optional[str] = Field( + None, description='Background transparency', examples=['opaque'] + ) + model: str = Field( + ..., description='The model to use for image editing', examples=['gpt-image-1'] + ) + moderation: Optional[Moderation] = Field( + None, description='Content moderation setting', examples=['auto'] + ) + n: Optional[int] = Field( + None, description='The number of images to generate', examples=[1] + ) + output_compression: Optional[int] = Field( + None, description='Compression level for JPEG or WebP (0-100)', examples=[100] + ) + output_format: Optional[OutputFormat] = Field( + None, description='Format of the output image', examples=['png'] + ) + prompt: str = Field( + ..., + description='A text description of the desired edit', + examples=['Give the rocketship rainbow coloring'], + ) + quality: Optional[str] = Field( + None, description='The quality of the edited image', examples=['low'] + ) + size: Optional[str] = Field( + None, description='Size of the output image', examples=['1024x1024'] + ) + user: Optional[str] = Field( + None, + description='A unique identifier for end-user monitoring', + examples=['user-1234'], + ) + + +class Quality1(str, Enum): + low = 'low' + medium = 'medium' + high = 'high' + standard = 'standard' + hd = 'hd' + + +class ResponseFormat(str, Enum): + url = 'url' + b64_json = 'b64_json' + + +class Style(str, Enum): + vivid = 'vivid' + natural = 'natural' + + +class OpenAIImageGenerationRequest(BaseModel): + background: Optional[Background] = Field( + None, description='Background transparency', examples=['opaque'] + ) + model: Optional[str] = Field( + None, description='The model to use for image generation', examples=['dall-e-3'] + ) + moderation: Optional[Moderation] = Field( + None, description='Content moderation setting', examples=['auto'] + ) + n: Optional[int] = Field( + None, + description='The number of images to generate (1-10). Only 1 supported for dall-e-3.', + examples=[1], + ) + output_compression: Optional[int] = Field( + None, description='Compression level for JPEG or WebP (0-100)', examples=[100] + ) + output_format: Optional[OutputFormat] = Field( + None, description='Format of the output image', examples=['png'] + ) + prompt: str = Field( + ..., + description='A text description of the desired image', + examples=['Draw a rocket in front of a blackhole in deep space'], + ) + quality: Optional[Quality1] = Field( + None, description='The quality of the generated image', examples=['high'] + ) + response_format: Optional[ResponseFormat] = Field( + None, description='Response format of image data', examples=['b64_json'] + ) + size: Optional[str] = Field( + None, + description='Size of the image (e.g., 1024x1024, 1536x1024, auto)', + examples=['1024x1536'], + ) + style: Optional[Style] = Field( + None, description='Style of the image (only for dall-e-3)', examples=['vivid'] + ) + user: Optional[str] = Field( + None, + description='A unique identifier for end-user monitoring', + examples=['user-1234'], + ) + + +class Datum1(BaseModel): + b64_json: Optional[str] = Field(None, description='Base64 encoded image data') + revised_prompt: Optional[str] = Field(None, description='Revised prompt') + url: Optional[str] = Field(None, description='URL of the image') + + +class OpenAIImageGenerationResponse(BaseModel): + data: Optional[List[Datum1]] = None +class User(BaseModel): + email: Optional[str] = Field(None, description='The email address for this user.') + id: Optional[str] = Field(None, description='The unique id for this user.') + isAdmin: Optional[bool] = Field( + None, description='Indicates if the user has admin privileges.' + ) + isApproved: Optional[bool] = Field( + None, description='Indicates if the user is approved.' + ) + name: Optional[str] = Field(None, description='The name for this user.') diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py new file mode 100644 index 000000000..d3cd9ad2f --- /dev/null +++ b/comfy_api_nodes/apis/client.py @@ -0,0 +1,341 @@ +import logging + +""" +API Client Framework for api.comfy.org. + +This module provides a flexible framework for making API requests from ComfyUI nodes. +It supports both synchronous and asynchronous API operations with proper type validation. + +Key Components: +-------------- +1. ApiClient - Handles HTTP requests with authentication and error handling +2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models +3. ApiOperation - Executes a single synchronous API operation + +Usage Examples: +-------------- + +# Example 1: Synchronous API Operation +# ------------------------------------ +# For a simple API call that returns the result immediately: + +# 1. Create the API client +api_client = ApiClient( + base_url="https://api.example.com", + api_key="your_api_key_here", + timeout=30.0, + verify_ssl=True +) + +# 2. Define the endpoint +user_info_endpoint = ApiEndpoint( + path="/v1/users/me", + method=HttpMethod.GET, + request_model=EmptyRequest, # No request body needed + response_model=UserProfile, # Pydantic model for the response + query_params=None +) + +# 3. Create the request object +request = EmptyRequest() + +# 4. Create and execute the operation +operation = ApiOperation( + endpoint=user_info_endpoint, + request=request +) +user_profile = operation.execute(client=api_client) # Returns immediately with the result + +""" + +from typing import ( + Dict, + Type, + Optional, + Any, + TypeVar, + Generic, +) +from pydantic import BaseModel +from enum import Enum +import json +import requests +from urllib.parse import urljoin + +T = TypeVar("T", bound=BaseModel) +R = TypeVar("R", bound=BaseModel) + +class EmptyRequest(BaseModel): + """Base class for empty request bodies. + For GET requests, fields will be sent as query parameters.""" + + pass + + +class HttpMethod(str, Enum): + GET = "GET" + POST = "POST" + PUT = "PUT" + DELETE = "DELETE" + PATCH = "PATCH" + + +class ApiClient: + """ + Client for making HTTP requests to an API with authentication and error handling. + """ + + def __init__( + self, + base_url: str, + api_key: Optional[str] = None, + timeout: float = 30.0, + verify_ssl: bool = True, + ): + self.base_url = base_url + self.api_key = api_key + self.timeout = timeout + self.verify_ssl = verify_ssl + + def get_headers(self) -> Dict[str, str]: + """Get headers for API requests, including authentication if available""" + headers = {"Content-Type": "application/json", "Accept": "application/json"} + + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + return headers + + def request( + self, + method: str, + path: str, + params: Optional[Dict[str, Any]] = None, + json: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """ + Make an HTTP request to the API + + Args: + method: HTTP method (GET, POST, etc.) + path: API endpoint path (will be joined with base_url) + params: Query parameters + json: JSON body data + files: Files to upload + headers: Additional headers + + Returns: + Parsed JSON response + + Raises: + requests.RequestException: If the request fails + """ + url = urljoin(self.base_url, path) + self.check_auth_token(self.api_key) + # Combine default headers with any provided headers + request_headers = self.get_headers() + if headers: + request_headers.update(headers) + + # Let requests handle the content type when files are present. + if files: + del request_headers["Content-Type"] + + logging.debug(f"[DEBUG] Request Headers: {request_headers}") + logging.debug(f"[DEBUG] Files: {files}") + logging.debug(f"[DEBUG] Params: {params}") + logging.debug(f"[DEBUG] Json: {json}") + + try: + # If files are present, use data parameter instead of json + if files: + form_data = {} + if json: + form_data.update(json) + response = requests.request( + method=method, + url=url, + params=params, + data=form_data, # Use data instead of json + files=files, + headers=request_headers, + timeout=self.timeout, + verify=self.verify_ssl, + ) + else: + response = requests.request( + method=method, + url=url, + params=params, + json=json, + headers=request_headers, + timeout=self.timeout, + verify=self.verify_ssl, + ) + + # Raise exception for error status codes + response.raise_for_status() + except requests.ConnectionError: + raise Exception( + f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available." + ) + + except requests.Timeout: + raise Exception( + f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected." + ) + + except requests.HTTPError as e: + status_code = e.response.status_code if hasattr(e, "response") else None + error_message = f"HTTP Error: {str(e)}" + + # Try to extract detailed error message from JSON response + try: + if hasattr(e, "response") and e.response.content: + error_json = e.response.json() + if "error" in error_json and "message" in error_json["error"]: + error_message = f"API Error: {error_json['error']['message']}" + if "type" in error_json["error"]: + error_message += f" (Type: {error_json['error']['type']})" + else: + error_message = f"API Error: {error_json}" + except Exception as json_error: + # If we can't parse the JSON, fall back to the original error message + logging.debug(f"[DEBUG] Failed to parse error response: {str(json_error)}") + + logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})") + if hasattr(e, "response") and e.response.content: + logging.debug(f"[DEBUG] Response content: {e.response.content}") + if status_code == 401: + error_message = "Unauthorized: Please login first to use this node." + if status_code == 402: + error_message = "Payment Required: Please add credits to your account to use this node." + if status_code == 409: + error_message = "There is a problem with your account. Please contact support@comfy.org. " + if status_code == 429: + error_message = "Rate Limit Exceeded: Please try again later." + raise Exception(error_message) + + # Parse and return JSON response + if response.content: + return response.json() + return {} + + def check_auth_token(self, auth_token): + """Verify that an auth token is present.""" + if auth_token is None: + raise Exception("Unauthorized: Please login first to use this node.") + return auth_token + + +class ApiEndpoint(Generic[T, R]): + """Defines an API endpoint with its request and response types""" + + def __init__( + self, + path: str, + method: HttpMethod, + request_model: Type[T], + response_model: Type[R], + query_params: Optional[Dict[str, Any]] = None, + ): + """Initialize an API endpoint definition. + + Args: + path: The URL path for this endpoint, can include placeholders like {id} + method: The HTTP method to use (GET, POST, etc.) + request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint + response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint + query_params: Optional dictionary of query parameters to include in the request + """ + self.path = path + self.method = method + self.request_model = request_model + self.response_model = response_model + self.query_params = query_params or {} + + +class SynchronousOperation(Generic[T, R]): + """ + Represents a single synchronous API operation. + """ + + def __init__( + self, + endpoint: ApiEndpoint[T, R], + request: T, + files: Optional[Dict[str, Any]] = None, + api_base: str = "https://api.comfy.org", + auth_token: Optional[str] = None, + timeout: float = 604800.0, + verify_ssl: bool = True, + ): + self.endpoint = endpoint + self.request = request + self.response = None + self.error = None + self.api_base = api_base + self.auth_token = auth_token + self.timeout = timeout + self.verify_ssl = verify_ssl + self.files = files + def execute(self, client: Optional[ApiClient] = None) -> R: + """Execute the API operation using the provided client or create one""" + try: + # Create client if not provided + if client is None: + if self.api_base is None: + raise ValueError("Either client or api_base must be provided") + client = ApiClient( + base_url=self.api_base, + api_key=self.auth_token, + timeout=self.timeout, + verify_ssl=self.verify_ssl, + ) + + # Convert request model to dict, but use None for EmptyRequest + request_dict = None if isinstance(self.request, EmptyRequest) else self.request.model_dump(exclude_none=True) + if request_dict: + for key, value in request_dict.items(): + if isinstance(value, Enum): + request_dict[key] = value.value + + # Debug log for request + logging.debug(f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}") + logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") + logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") + + # Make the request + resp = client.request( + method=self.endpoint.method.value, + path=self.endpoint.path, + json=request_dict, + params=self.endpoint.query_params, + files=self.files, + ) + + # Debug log for response + logging.debug("=" * 50) + logging.debug("[DEBUG] RESPONSE DETAILS:") + logging.debug("[DEBUG] Status Code: 200 (Success)") + logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}") + logging.debug("=" * 50) + + # Parse and return the response + return self._parse_response(resp) + + except Exception as e: + logging.debug(f"[DEBUG] API Exception: {str(e)}") + raise Exception(str(e)) + + def _parse_response(self, resp): + """Parse response data - can be overridden by subclasses""" + # The response is already the complete object, don't extract just the "data" field + # as that would lose the outer structure (created timestamp, etc.) + + # Parse response using the provided model + self.response = self.endpoint.response_model.model_validate(resp) + logging.debug(f"[DEBUG] Parsed Response: {self.response}") + return self.response diff --git a/comfy_api_nodes/nodes_api.py b/comfy_api_nodes/nodes_api.py new file mode 100644 index 000000000..a977bb9b7 --- /dev/null +++ b/comfy_api_nodes/nodes_api.py @@ -0,0 +1,449 @@ +import base64 +import io +import math +from inspect import cleandoc + +import numpy as np +import requests +import torch +from PIL import Image + +from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict +from comfy.utils import common_upscale +from comfy_api_nodes.apis import ( + OpenAIImageEditRequest, + OpenAIImageGenerationRequest, + OpenAIImageGenerationResponse, +) +from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation + + +def downscale_input(image): + samples = image.movedim(-1,1) + #downscaling input images to roughly the same size as the outputs + total = int(1536 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + if scale_by >= 1: + return image + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = common_upscale(samples, width, height, "lanczos", "disabled") + s = s.movedim(1,-1) + return s + +def validate_and_cast_response(response): + # validate raw JSON response + data = response.data + if not data or len(data) == 0: + raise Exception("No images returned from API endpoint") + + # Initialize list to store image tensors + image_tensors = [] + + # Process each image in the data array + for image_data in data: + image_url = image_data.url + b64_data = image_data.b64_json + + if not image_url and not b64_data: + raise Exception("No image was generated in the response") + + if b64_data: + img_data = base64.b64decode(b64_data) + img = Image.open(io.BytesIO(img_data)) + + elif image_url: + img_response = requests.get(image_url) + if img_response.status_code != 200: + raise Exception("Failed to download the image") + img = Image.open(io.BytesIO(img_response.content)) + + img = img.convert("RGBA") + + # Convert to numpy array, normalize to float32 between 0 and 1 + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array) + + # Add to list of tensors + image_tensors.append(img_tensor) + + return torch.stack(image_tensors, dim=0) + +class OpenAIDalle2(ComfyNodeABC): + """ + Generates images synchronously via OpenAI's DALL·E 2 endpoint. + + Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived, + so download or cache results if you need to keep them. + """ + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": (IO.STRING, { + "multiline": True, + "default": "", + "tooltip": "Text prompt for DALL·E", + }), + }, + "optional": { + "seed": (IO.INT, { + "default": 0, + "min": 0, + "max": 2**31-1, + "step": 1, + "display": "number", + "tooltip": "not implemented yet in backend", + }), + "size": (IO.COMBO, { + "options": ["256x256", "512x512", "1024x1024"], + "default": "1024x1024", + "tooltip": "Image size", + }), + "n": (IO.INT, { + "default": 1, + "min": 1, + "max": 8, + "step": 1, + "display": "number", + "tooltip": "How many images to generate", + }), + "image": (IO.IMAGE, { + "default": None, + "tooltip": "Optional reference image for image editing.", + }), + "mask": (IO.MASK, { + "default": None, + "tooltip": "Optional mask for inpainting (white areas will be replaced)", + }), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG" + } + } + + RETURN_TYPES = (IO.IMAGE,) + FUNCTION = "api_call" + CATEGORY = "api node" + DESCRIPTION = cleandoc(__doc__ or "") + API_NODE = True + + def api_call(self, prompt, seed=0, image=None, mask=None, n=1, size="1024x1024", auth_token=None): + model = "dall-e-2" + path = "/proxy/openai/images/generations" + request_class = OpenAIImageGenerationRequest + img_binary = None + + if image is not None and mask is not None: + path = "/proxy/openai/images/edits" + request_class = OpenAIImageEditRequest + + input_tensor = image.squeeze().cpu() + height, width, channels = input_tensor.shape + rgba_tensor = torch.ones(height, width, 4, device="cpu") + rgba_tensor[:, :, :channels] = input_tensor + + if mask.shape[1:] != image.shape[1:-1]: + raise Exception("Mask and Image must be the same size") + rgba_tensor[:,:,3] = (1-mask.squeeze().cpu()) + + rgba_tensor = downscale_input(rgba_tensor.unsqueeze(0)).squeeze() + + image_np = (rgba_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format='PNG') + img_byte_arr.seek(0) + img_binary = img_byte_arr#.getvalue() + img_binary.name = "image.png" + elif image is not None or mask is not None: + raise Exception("Dall-E 2 image editing requires an image AND a mask") + + # Build the operation + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=request_class, + response_model=OpenAIImageGenerationResponse + ), + request=request_class( + model=model, + prompt=prompt, + n=n, + size=size, + seed=seed, + ), + files={ + "image": img_binary, + } if img_binary else None, + auth_token=auth_token + ) + + response = operation.execute() + + img_tensor = validate_and_cast_response(response) + return (img_tensor,) + +class OpenAIDalle3(ComfyNodeABC): + """ + Generates images synchronously via OpenAI's DALL·E 3 endpoint. + + Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived, + so download or cache results if you need to keep them. + """ + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": (IO.STRING, { + "multiline": True, + "default": "", + "tooltip": "Text prompt for DALL·E", + }), + }, + "optional": { + "seed": (IO.INT, { + "default": 0, + "min": 0, + "max": 2**31-1, + "step": 1, + "display": "number", + "tooltip": "not implemented yet in backend", + }), + "quality" : (IO.COMBO, { + "options": ["standard","hd"], + "default": "standard", + "tooltip": "Image quality", + }), + "style": (IO.COMBO, { + "options": ["natural","vivid"], + "default": "natural", + "tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.", + }), + "size": (IO.COMBO, { + "options": ["1024x1024", "1024x1792", "1792x1024"], + "default": "1024x1024", + "tooltip": "Image size", + }), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG" + } + } + + RETURN_TYPES = (IO.IMAGE,) + FUNCTION = "api_call" + CATEGORY = "api node" + DESCRIPTION = cleandoc(__doc__ or "") + API_NODE = True + + def api_call(self, prompt, seed=0, style="natural", quality="standard", size="1024x1024", auth_token=None): + model = "dall-e-3" + + # build the operation + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/openai/images/generations", + method=HttpMethod.POST, + request_model=OpenAIImageGenerationRequest, + response_model=OpenAIImageGenerationResponse + ), + request=OpenAIImageGenerationRequest( + model=model, + prompt=prompt, + quality=quality, + size=size, + style=style, + seed=seed, + ), + auth_token=auth_token + ) + + response = operation.execute() + + img_tensor = validate_and_cast_response(response) + return (img_tensor,) + +class OpenAIGPTImage1(ComfyNodeABC): + """ + Generates images synchronously via OpenAI's GPT Image 1 endpoint. + + Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived, + so download or cache results if you need to keep them. + """ + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": (IO.STRING, { + "multiline": True, + "default": "", + "tooltip": "Text prompt for GPT Image 1", + }), + }, + "optional": { + "seed": (IO.INT, { + "default": 0, + "min": 0, + "max": 2**31-1, + "step": 1, + "display": "number", + "tooltip": "not implemented yet in backend", + }), + "quality": (IO.COMBO, { + "options": ["low","medium","high"], + "default": "low", + "tooltip": "Image quality, affects cost and generation time.", + }), + "background": (IO.COMBO, { + "options": ["opaque","transparent"], + "default": "opaque", + "tooltip": "Return image with or without background", + }), + "size": (IO.COMBO, { + "options": ["auto", "1024x1024", "1024x1536", "1536x1024"], + "default": "auto", + "tooltip": "Image size", + }), + "n": (IO.INT, { + "default": 1, + "min": 1, + "max": 8, + "step": 1, + "display": "number", + "tooltip": "How many images to generate", + }), + "image": (IO.IMAGE, { + "default": None, + "tooltip": "Optional reference image for image editing.", + }), + "mask": (IO.MASK, { + "default": None, + "tooltip": "Optional mask for inpainting (white areas will be replaced)", + }), + "moderation": (IO.COMBO, { + "options": ["low","auto"], + "default": "low", + "tooltip": "Moderation level", + }), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG" + } + } + + RETURN_TYPES = (IO.IMAGE,) + FUNCTION = "api_call" + CATEGORY = "api node" + DESCRIPTION = cleandoc(__doc__ or "") + API_NODE = True + + def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None, moderation="low"): + model = "gpt-image-1" + path = "/proxy/openai/images/generations" + request_class = OpenAIImageGenerationRequest + img_binaries = [] + mask_binary = None + files = [] + + if image is not None: + path = "/proxy/openai/images/edits" + request_class = OpenAIImageEditRequest + + batch_size = image.shape[0] + + + for i in range(batch_size): + single_image = image[i:i+1] + scaled_image = downscale_input(single_image).squeeze() + + image_np = (scaled_image.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format='PNG') + img_byte_arr.seek(0) + img_binary = img_byte_arr + img_binary.name = f"image_{i}.png" + + img_binaries.append(img_binary) + if batch_size == 1: + files.append(("image", img_binary)) + else: + files.append(("image[]", img_binary)) + + if mask is not None: + if image.shape[0] != 1: + raise Exception("Cannot use a mask with multiple image") + if image is None: + raise Exception("Cannot use a mask without an input image") + if mask.shape[1:] != image.shape[1:-1]: + raise Exception("Mask and Image must be the same size") + batch, height, width = mask.shape + rgba_mask = torch.zeros(height, width, 4, device="cpu") + rgba_mask[:,:,3] = (1-mask.squeeze().cpu()) + + scaled_mask = downscale_input(rgba_mask.unsqueeze(0)).squeeze() + + mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) + mask_img = Image.fromarray(mask_np) + mask_img_byte_arr = io.BytesIO() + mask_img.save(mask_img_byte_arr, format='PNG') + mask_img_byte_arr.seek(0) + mask_binary = mask_img_byte_arr + mask_binary.name = "mask.png" + files.append(("mask", mask_binary)) + + + # Build the operation + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=request_class, + response_model=OpenAIImageGenerationResponse + ), + request=request_class( + model=model, + prompt=prompt, + quality=quality, + background=background, + n=n, + seed=seed, + size=size, + moderation=moderation, + ), + files=files if files else None, + auth_token=auth_token + ) + + response = operation.execute() + + img_tensor = validate_and_cast_response(response) + return (img_tensor,) + + +# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique +NODE_CLASS_MAPPINGS = { + "OpenAIDalle2": OpenAIDalle2, + "OpenAIDalle3": OpenAIDalle3, + "OpenAIGPTImage1": OpenAIGPTImage1, +} + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "OpenAIDalle2": "OpenAI DALL·E 2", + "OpenAIDalle3": "OpenAI DALL·E 3", + "OpenAIGPTImage1": "OpenAI GPT Image 1", +} diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 630f280fc..dbb37b89f 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -316,3 +316,156 @@ class LRUCache(BasicCache): self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self + +class DependencyAwareCache(BasicCache): + """ + A cache implementation that tracks dependencies between nodes and manages + their execution and caching accordingly. It extends the BasicCache class. + Nodes are removed from this cache once all of their descendants have been + executed. + """ + + def __init__(self, key_class): + """ + Initialize the DependencyAwareCache. + + Args: + key_class: The class used for generating cache keys. + """ + super().__init__(key_class) + self.descendants = {} # Maps node_id -> set of descendant node_ids + self.ancestors = {} # Maps node_id -> set of ancestor node_ids + self.executed_nodes = set() # Tracks nodes that have been executed + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + """ + Clear the entire cache and rebuild the dependency graph. + + Args: + dynprompt: The dynamic prompt object containing node information. + node_ids: List of node IDs to initialize the cache for. + is_changed_cache: Flag indicating if the cache has changed. + """ + # Clear all existing cache data + self.cache.clear() + self.subcaches.clear() + self.descendants.clear() + self.ancestors.clear() + self.executed_nodes.clear() + + # Call the parent method to initialize the cache with the new prompt + super().set_prompt(dynprompt, node_ids, is_changed_cache) + + # Rebuild the dependency graph + self._build_dependency_graph(dynprompt, node_ids) + + def _build_dependency_graph(self, dynprompt, node_ids): + """ + Build the dependency graph for all nodes. + + Args: + dynprompt: The dynamic prompt object containing node information. + node_ids: List of node IDs to build the graph for. + """ + self.descendants.clear() + self.ancestors.clear() + for node_id in node_ids: + self.descendants[node_id] = set() + self.ancestors[node_id] = set() + + for node_id in node_ids: + inputs = dynprompt.get_node(node_id)["inputs"] + for input_data in inputs.values(): + if is_link(input_data): # Check if the input is a link to another node + ancestor_id = input_data[0] + self.descendants[ancestor_id].add(node_id) + self.ancestors[node_id].add(ancestor_id) + + def set(self, node_id, value): + """ + Mark a node as executed and store its value in the cache. + + Args: + node_id: The ID of the node to store. + value: The value to store for the node. + """ + self._set_immediate(node_id, value) + self.executed_nodes.add(node_id) + self._cleanup_ancestors(node_id) + + def get(self, node_id): + """ + Retrieve the cached value for a node. + + Args: + node_id: The ID of the node to retrieve. + + Returns: + The cached value for the node. + """ + return self._get_immediate(node_id) + + def ensure_subcache_for(self, node_id, children_ids): + """ + Ensure a subcache exists for a node and update dependencies. + + Args: + node_id: The ID of the parent node. + children_ids: List of child node IDs to associate with the parent node. + + Returns: + The subcache object for the node. + """ + subcache = super()._ensure_subcache(node_id, children_ids) + for child_id in children_ids: + self.descendants[node_id].add(child_id) + self.ancestors[child_id].add(node_id) + return subcache + + def _cleanup_ancestors(self, node_id): + """ + Check if ancestors of a node can be removed from the cache. + + Args: + node_id: The ID of the node whose ancestors are to be checked. + """ + for ancestor_id in self.ancestors.get(node_id, []): + if ancestor_id in self.executed_nodes: + # Remove ancestor if all its descendants have been executed + if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): + self._remove_node(ancestor_id) + + def _remove_node(self, node_id): + """ + Remove a node from the cache. + + Args: + node_id: The ID of the node to remove. + """ + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + del self.cache[cache_key] + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + del self.subcaches[subcache_key] + + def clean_unused(self): + """ + Clean up unused nodes. This is a no-op for this cache implementation. + """ + pass + + def recursive_debug_dump(self): + """ + Dump the cache and dependency graph for debugging. + + Returns: + A list containing the cache state and dependency graph. + """ + result = super().recursive_debug_dump() + result.append({ + "descendants": self.descendants, + "ancestors": self.ancestors, + "executed_nodes": list(self.executed_nodes), + }) + return result diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 59b42b746..a2799b52e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -1,6 +1,9 @@ -import nodes +from __future__ import annotations +from typing import Type, Literal +import nodes from comfy_execution.graph_utils import is_link +from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions class DependencyCycleError(Exception): pass @@ -54,7 +57,22 @@ class DynamicPrompt: def get_original_prompt(self): return self.original_prompt -def get_input_info(class_def, input_name, valid_inputs=None): +def get_input_info( + class_def: Type[ComfyNodeABC], + input_name: str, + valid_inputs: InputTypeDict | None = None +) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]: + """Get the input type, category, and extra info for a given input name. + + Arguments: + class_def: The class definition of the node. + input_name: The name of the input to get info for. + valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES(). + + Returns: + tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name. + """ + valid_inputs = valid_inputs or class_def.INPUT_TYPES() input_info = None input_category = None @@ -126,7 +144,7 @@ class TopologicalSort: from_node_id, from_socket = value if subgraph_nodes is not None and from_node_id not in subgraph_nodes: continue - input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): node_ids.append(from_node_id) diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py new file mode 100644 index 000000000..1fb686644 --- /dev/null +++ b/comfy_extras/nodes_cfg.py @@ -0,0 +1,45 @@ +import torch + +# https://github.com/WeichenFan/CFG-Zero-star +def optimized_scale(positive, negative): + positive_flat = positive.reshape(positive.shape[0], -1) + negative_flat = negative.reshape(negative.shape[0], -1) + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) + +class CFGZeroStar: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("patched_model",) + FUNCTION = "patch" + CATEGORY = "advanced/guidance" + + def patch(self, model): + m = model.clone() + def cfg_zero_star(args): + guidance_scale = args['cond_scale'] + x = args['input'] + cond_p = args['cond_denoised'] + uncond_p = args['uncond_denoised'] + out = args["denoised"] + alpha = optimized_scale(x - cond_p, x - uncond_p) + + return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) + m.set_model_sampler_post_cfg_function(cfg_zero_star) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "CFGZeroStar": CFGZeroStar +} diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 4c3a1d5bf..574262178 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -20,6 +20,29 @@ class CLIPTextEncodeControlnet: c.append(n) return (c, ) +class T5TokenizerOptions: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "clip": ("CLIP", ), + "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), + "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), + } + } + + RETURN_TYPES = ("CLIP",) + FUNCTION = "set_options" + + def set_options(self, clip, min_padding, min_length): + clip = clip.clone() + for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]: + clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding) + clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length) + + return (clip, ) + NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet + "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet, + "T5TokenizerOptions": T5TokenizerOptions, } diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py new file mode 100644 index 000000000..ee310c874 --- /dev/null +++ b/comfy_extras/nodes_fresca.py @@ -0,0 +1,100 @@ +# Code based on https://github.com/WikiChao/FreSca (MIT License) +import torch +import torch.fft as fft + + +def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): + """ + Apply frequency-dependent scaling to an image tensor using Fourier transforms. + + Parameters: + x: Input tensor of shape (B, C, H, W) + scale_low: Scaling factor for low-frequency components (default: 1.0) + scale_high: Scaling factor for high-frequency components (default: 1.5) + freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20) + + Returns: + x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied. + """ + # Preserve input dtype and device + dtype, device = x.dtype, x.device + + # Convert to float32 for FFT computations + x = x.to(torch.float32) + + # 1) Apply FFT and shift low frequencies to center + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + # Initialize mask with high-frequency scaling factor + mask = torch.ones(x_freq.shape, device=device) * scale_high + m = mask + for d in range(len(x_freq.shape) - 2): + dim = d + 2 + cc = x_freq.shape[dim] // 2 + f_c = min(freq_cutoff, cc) + m = m.narrow(dim, cc - f_c, f_c * 2) + + # Apply low-frequency scaling factor to center region + m[:] = scale_low + + # 3) Apply frequency-specific scaling + x_freq = x_freq * mask + + # 4) Convert back to spatial domain + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + # 5) Restore original dtype + x_filtered = x_filtered.to(dtype) + + return x_filtered + + +class FreSca: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for low-frequency components"}), + "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for high-frequency components"}), + "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, + "tooltip": "Number of frequency indices around center to consider as low-frequency"}), + } + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "_for_testing" + DESCRIPTION = "Applies frequency-dependent scaling to the guidance" + def patch(self, model, scale_low, scale_high, freq_cutoff): + def custom_cfg_function(args): + cond = args["conds_out"][0] + uncond = args["conds_out"][1] + + guidance = cond - uncond + filtered_guidance = Fourier_filter( + guidance, + scale_low=scale_low, + scale_high=scale_high, + freq_cutoff=freq_cutoff, + ) + filtered_cond = filtered_guidance + uncond + + return [filtered_cond, uncond] + + m = model.clone() + m.set_model_sampler_pre_cfg_function(custom_cfg_function) + + return (m,) + + +NODE_CLASS_MAPPINGS = { + "FreSca": FreSca, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "FreSca": "FreSca", +} diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py new file mode 100644 index 000000000..dfb98597b --- /dev/null +++ b/comfy_extras/nodes_hidream.py @@ -0,0 +1,55 @@ +import folder_paths +import comfy.sd +import comfy.model_management + + +class QuadrupleCLIPLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name3": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name4": (folder_paths.get_filename_list("text_encoders"), ) + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "load_clip" + + CATEGORY = "advanced/loaders" + + DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct" + + def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4): + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) + clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) + clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) + clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (clip,) + +class CLIPTextEncodeHiDream: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "llama": ("STRING", {"multiline": True, "dynamicPrompts": True}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, clip_l, clip_g, t5xxl, llama): + + tokens = clip.tokenize(clip_g) + tokens["l"] = clip.tokenize(clip_l)["l"] + tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] + tokens["llama"] = clip.tokenize(llama)["llama"] + return (clip.encode_from_tokens_scheduled(tokens), ) + +NODE_CLASS_MAPPINGS = { + "QuadrupleCLIPLoader": QuadrupleCLIPLoader, + "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream, +} diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py new file mode 100644 index 000000000..51e45336a --- /dev/null +++ b/comfy_extras/nodes_hunyuan3d.py @@ -0,0 +1,634 @@ +import torch +import os +import json +import struct +import numpy as np +from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch +import folder_paths +import comfy.model_management +from comfy.cli_args import args + + +class EmptyLatentHunyuan3Dv2: + @classmethod + def INPUT_TYPES(s): + return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent/3d" + + def generate(self, resolution, batch_size): + latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) + return ({"samples": latent, "type": "hunyuan3dv2"}, ) + + +class Hunyuan3Dv2Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, clip_vision_output): + embeds = clip_vision_output.last_hidden_state + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class Hunyuan3Dv2ConditioningMultiView: + @classmethod + def INPUT_TYPES(s): + return {"required": {}, + "optional": {"front": ("CLIP_VISION_OUTPUT",), + "left": ("CLIP_VISION_OUTPUT",), + "back": ("CLIP_VISION_OUTPUT",), + "right": ("CLIP_VISION_OUTPUT",), }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, front=None, left=None, back=None, right=None): + all_embeds = [front, left, back, right] + out = [] + pos_embeds = None + for i, e in enumerate(all_embeds): + if e is not None: + if pos_embeds is None: + pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4)) + out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1)) + + embeds = torch.cat(out, dim=1) + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class VOXEL: + def __init__(self, data): + self.data = data + + +class VAEDecodeHunyuan3D: + @classmethod + def INPUT_TYPES(s): + return {"required": {"samples": ("LATENT", ), + "vae": ("VAE", ), + "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}), + "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}), + }} + RETURN_TYPES = ("VOXEL",) + FUNCTION = "decode" + + CATEGORY = "latent/3d" + + def decode(self, vae, samples, num_chunks, octree_resolution): + voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) + return (voxels, ) + + +def voxel_to_mesh(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + binary = (voxels > threshold).float() + padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0) + + D, H, W = binary.shape + + neighbors = torch.tensor([ + [0, 0, 1], + [0, 0, -1], + [0, 1, 0], + [0, -1, 0], + [1, 0, 0], + [-1, 0, 0] + ], device=device) + + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + solid_mask = binary.flatten() > 0 + solid_indices = voxel_indices[solid_mask] + + corner_offsets = [ + torch.tensor([ + [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0] + ], device=device), + torch.tensor([ + [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0] + ], device=device) + ] + + all_vertices = [] + all_indices = [] + + vertex_count = 0 + + for face_idx, offset in enumerate(neighbors): + neighbor_indices = solid_indices + offset + + padded_indices = neighbor_indices + 1 + + is_exposed = padded[ + padded_indices[:, 0], + padded_indices[:, 1], + padded_indices[:, 2] + ] == 0 + + if not is_exposed.any(): + continue + + exposed_indices = solid_indices[is_exposed] + + corners = corner_offsets[face_idx].unsqueeze(0) + + face_vertices = exposed_indices.unsqueeze(1) + corners + + all_vertices.append(face_vertices.reshape(-1, 3)) + + num_faces = exposed_indices.shape[0] + face_indices = torch.arange( + vertex_count, + vertex_count + 4 * num_faces, + device=device + ).reshape(-1, 4) + + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1)) + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1)) + + vertex_count += 4 * num_faces + + if len(all_vertices) > 0: + vertices = torch.cat(all_vertices, dim=0) + faces = torch.cat(all_indices, dim=0) + else: + vertices = torch.zeros((1, 3)) + faces = torch.zeros((1, 3)) + + v_min = 0 + v_max = max(voxels.shape) + + vertices = vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + vertices = vertices / scale + + vertices = torch.fliplr(vertices) + return vertices, faces + +def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + D, H, W = voxels.shape + + padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0) + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + corner_offsets = torch.tensor([ + [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1] + ], device=device) + + corner_values = torch.zeros((cell_positions.shape[0], 8), device=device) + for c, (dz, dy, dx) in enumerate(corner_offsets): + corner_values[:, c] = padded[ + cell_positions[:, 0] + dz, + cell_positions[:, 1] + dy, + cell_positions[:, 2] + dx + ] + + corner_signs = corner_values > threshold + has_inside = torch.any(corner_signs, dim=1) + has_outside = torch.any(~corner_signs, dim=1) + contains_surface = has_inside & has_outside + + active_cells = cell_positions[contains_surface] + active_signs = corner_signs[contains_surface] + active_values = corner_values[contains_surface] + + if active_cells.shape[0] == 0: + return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device) + + edges = torch.tensor([ + [0, 1], [0, 2], [0, 4], [1, 3], + [1, 5], [2, 3], [2, 6], [3, 7], + [4, 5], [4, 6], [5, 7], [6, 7] + ], device=device) + + cell_vertices = {} + progress = comfy.utils.ProgressBar(100) + + for edge_idx, (e1, e2) in enumerate(edges): + progress.update(1) + crossing = active_signs[:, e1] != active_signs[:, e2] + if not crossing.any(): + continue + + cell_indices = torch.nonzero(crossing, as_tuple=True)[0] + + v1 = active_values[cell_indices, e1] + v2 = active_values[cell_indices, e2] + + t = torch.zeros_like(v1, device=device) + denom = v2 - v1 + valid = denom != 0 + t[valid] = (threshold - v1[valid]) / denom[valid] + t[~valid] = 0.5 + + p1 = corner_offsets[e1].float() + p2 = corner_offsets[e2].float() + + intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0)) + + for i, point in zip(cell_indices.tolist(), intersection): + if i not in cell_vertices: + cell_vertices[i] = [] + cell_vertices[i].append(point) + + # Calculate the final vertices as the average of intersection points for each cell + vertices = [] + vertex_lookup = {} + + vert_progress_mod = round(len(cell_vertices)/50) + + for i, points in cell_vertices.items(): + if not i % vert_progress_mod: + progress.update(1) + + if points: + vertex = torch.stack(points).mean(dim=0) + vertex = vertex + active_cells[i].float() + vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices) + vertices.append(vertex) + + if not vertices: + return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device) + + final_vertices = torch.stack(vertices) + + inside_corners_mask = active_signs + outside_corners_mask = ~active_signs + + inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float() + outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float() + + inside_pos = torch.zeros((active_cells.shape[0], 3), device=device) + outside_pos = torch.zeros((active_cells.shape[0], 3), device=device) + + for i in range(8): + mask_inside = inside_corners_mask[:, i].unsqueeze(1) + mask_outside = outside_corners_mask[:, i].unsqueeze(1) + inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside + outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside + + inside_pos /= inside_counts + outside_pos /= outside_counts + gradients = inside_pos - outside_pos + + pos_dirs = torch.tensor([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1] + ], device=device) + + cross_products = [ + torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float()) + for i in range(3) for j in range(i+1, 3) + ] + + faces = [] + all_keys = set(vertex_lookup.keys()) + + face_progress_mod = round(len(active_cells)/38*3) + + for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]): + dir_i = pos_dirs[i] + dir_j = pos_dirs[j] + cross_product = cross_products[pair_idx] + + ni_positions = active_cells + dir_i + nj_positions = active_cells + dir_j + diag_positions = active_cells + dir_i + dir_j + + alignments = torch.matmul(gradients, cross_product) + + valid_quads = [] + quad_indices = [] + + for idx, active_cell in enumerate(active_cells): + if not idx % face_progress_mod: + progress.update(1) + cell_key = tuple(active_cell.tolist()) + ni_key = tuple(ni_positions[idx].tolist()) + nj_key = tuple(nj_positions[idx].tolist()) + diag_key = tuple(diag_positions[idx].tolist()) + + if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys: + v0 = vertex_lookup[cell_key] + v1 = vertex_lookup[ni_key] + v2 = vertex_lookup[nj_key] + v3 = vertex_lookup[diag_key] + + valid_quads.append((v0, v1, v2, v3)) + quad_indices.append(idx) + + for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads): + cell_idx = quad_indices[q_idx] + if alignments[cell_idx] > 0: + faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long)) + faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long)) + else: + faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long)) + faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long)) + + if faces: + faces = torch.stack(faces) + else: + faces = torch.zeros((0, 3), dtype=torch.long, device=device) + + v_min = 0 + v_max = max(D, H, W) + + final_vertices = final_vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + final_vertices = final_vertices / scale + + final_vertices = torch.fliplr(final_vertices) + + return final_vertices, faces + +class MESH: + def __init__(self, vertices, faces): + self.vertices = vertices + self.faces = faces + + +class VoxelToMeshBasic: + @classmethod + def INPUT_TYPES(s): + return {"required": {"voxel": ("VOXEL", ), + "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MESH",) + FUNCTION = "decode" + + CATEGORY = "3d" + + def decode(self, voxel, threshold): + vertices = [] + faces = [] + for x in voxel.data: + v, f = voxel_to_mesh(x, threshold=threshold, device=None) + vertices.append(v) + faces.append(f) + + return (MESH(torch.stack(vertices), torch.stack(faces)), ) + +class VoxelToMesh: + @classmethod + def INPUT_TYPES(s): + return {"required": {"voxel": ("VOXEL", ), + "algorithm": (["surface net", "basic"], ), + "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MESH",) + FUNCTION = "decode" + + CATEGORY = "3d" + + def decode(self, voxel, algorithm, threshold): + vertices = [] + faces = [] + + if algorithm == "basic": + mesh_function = voxel_to_mesh + elif algorithm == "surface net": + mesh_function = voxel_to_mesh_surfnet + + for x in voxel.data: + v, f = mesh_function(x, threshold=threshold, device=None) + vertices.append(v) + faces.append(f) + + return (MESH(torch.stack(vertices), torch.stack(faces)), ) + + +def save_glb(vertices, faces, filepath, metadata=None): + """ + Save PyTorch tensor vertices and faces as a GLB file without external dependencies. + + Parameters: + vertices: torch.Tensor of shape (N, 3) - The vertex coordinates + faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces) + filepath: str - Output filepath (should end with .glb) + """ + + # Convert tensors to numpy arrays + vertices_np = vertices.cpu().numpy().astype(np.float32) + faces_np = faces.cpu().numpy().astype(np.uint32) + + vertices_buffer = vertices_np.tobytes() + indices_buffer = faces_np.tobytes() + + def pad_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b'\x00' * padding_length + + vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) + indices_buffer_padded = pad_to_4_bytes(indices_buffer) + + buffer_data = vertices_buffer_padded + indices_buffer_padded + + vertices_byte_length = len(vertices_buffer) + vertices_byte_offset = 0 + indices_byte_length = len(indices_buffer) + indices_byte_offset = len(vertices_buffer_padded) + + gltf = { + "asset": {"version": "2.0", "generator": "ComfyUI"}, + "buffers": [ + { + "byteLength": len(buffer_data) + } + ], + "bufferViews": [ + { + "buffer": 0, + "byteOffset": vertices_byte_offset, + "byteLength": vertices_byte_length, + "target": 34962 # ARRAY_BUFFER + }, + { + "buffer": 0, + "byteOffset": indices_byte_offset, + "byteLength": indices_byte_length, + "target": 34963 # ELEMENT_ARRAY_BUFFER + } + ], + "accessors": [ + { + "bufferView": 0, + "byteOffset": 0, + "componentType": 5126, # FLOAT + "count": len(vertices_np), + "type": "VEC3", + "max": vertices_np.max(axis=0).tolist(), + "min": vertices_np.min(axis=0).tolist() + }, + { + "bufferView": 1, + "byteOffset": 0, + "componentType": 5125, # UNSIGNED_INT + "count": faces_np.size, + "type": "SCALAR" + } + ], + "meshes": [ + { + "primitives": [ + { + "attributes": { + "POSITION": 0 + }, + "indices": 1, + "mode": 4 # TRIANGLES + } + ] + } + ], + "nodes": [ + { + "mesh": 0 + } + ], + "scenes": [ + { + "nodes": [0] + } + ], + "scene": 0 + } + + if metadata is not None: + gltf["asset"]["extras"] = metadata + + # Convert the JSON to bytes + gltf_json = json.dumps(gltf).encode('utf8') + + def pad_json_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b' ' * padding_length + + gltf_json_padded = pad_json_to_4_bytes(gltf_json) + + # Create the GLB header + # Magic glTF + glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data)) + + # Create JSON chunk header (chunk type 0) + json_chunk_header = struct.pack(' 0: - output_images = [] - for i in range(image.shape[0]): - output_images.append(preprocess(image[i], img_compression)) + output_images = [] + for i in range(image.shape[0]): + output_images.append(preprocess(image[i], img_compression)) return (torch.stack(output_images),) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 63fd13b9a..99b264a32 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -2,7 +2,11 @@ import numpy as np import scipy.ndimage import torch import comfy.utils +import node_helpers +import folder_paths +import random +import nodes from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): @@ -87,6 +91,7 @@ class ImageCompositeMasked: CATEGORY = "image" def composite(self, destination, source, x, y, resize_source, mask = None): + destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) return (output,) @@ -360,6 +365,30 @@ class ThresholdMask: mask = (mask > value).float() return (mask,) +# Mask Preview - original implement from +# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 +# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes +class MaskPreview(nodes.SaveImage): + def __init__(self): + self.output_dir = folder_paths.get_temp_directory() + self.type = "temp" + self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) + self.compress_level = 4 + + @classmethod + def INPUT_TYPES(s): + return { + "required": {"mask": ("MASK",), }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + FUNCTION = "execute" + CATEGORY = "mask" + + def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + return self.save_images(preview, filename_prefix, prompt, extra_pnginfo) + NODE_CLASS_MAPPINGS = { "LatentCompositeMasked": LatentCompositeMasked, @@ -374,6 +403,7 @@ NODE_CLASS_MAPPINGS = { "FeatherMask": FeatherMask, "GrowMask": GrowMask, "ThresholdMask": ThresholdMask, + "MaskPreview": MaskPreview } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index ceac5654b..71a652ffa 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -20,10 +20,6 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class X0(comfy.model_sampling.EPS): - def calculate_denoised(self, sigma, model_output, model_input): - return model_output - class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 @@ -56,7 +52,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -77,7 +73,9 @@ class ModelSamplingDiscrete: sampling_type = LCM sampling_base = ModelSamplingDiscreteDistilled elif sampling == "x0": - sampling_type = X0 + sampling_type = comfy.model_sampling.X0 + elif sampling == "img_to_img": + sampling_type = comfy.model_sampling.IMG_TO_IMG class ModelSamplingAdvanced(sampling_base, sampling_type): pass diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index ccf601158..78d284889 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -209,6 +209,9 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi metadata["modelspec.predict_key"] = "epsilon" elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: metadata["modelspec.predict_key"] = "v" + extra_keys["v_pred"] = torch.tensor([]) + if getattr(model_sampling, "zsnr", False): + extra_keys["ztsnr"] = torch.tensor([]) if not args.disable_metadata: metadata["prompt"] = prompt_info diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py index 3e37f70d4..dc3411947 100644 --- a/comfy_extras/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks): return {"required": arg_dict} +class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb." + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["patch_embedding."] = argument + arg_dict["time_embedding."] = argument + arg_dict["time_projection."] = argument + arg_dict["text_embedding."] = argument + arg_dict["img_emb."] = argument + + for i in range(40): + arg_dict["blocks.{}.".format(i)] = argument + + arg_dict["head."] = argument + + return {"required": arg_dict} + NODE_CLASS_MAPPINGS = { "ModelMergeSD1": ModelMergeSD1, "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks @@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeLTXV": ModelMergeLTXV, "ModelMergeCosmos7B": ModelMergeCosmos7B, "ModelMergeCosmos14B": ModelMergeCosmos14B, + "ModelMergeWAN2_1": ModelMergeWAN2_1, } diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index b1372b8ce..075b26c40 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -2,6 +2,7 @@ import torch import comfy.model_management from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat +import kornia.color class Morphology: @@ -40,8 +41,45 @@ class Morphology: img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) return (img_out,) + +class ImageRGBToYUV: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + }} + + RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") + RETURN_NAMES = ("Y", "U", "V") + FUNCTION = "execute" + + CATEGORY = "image/batch" + + def execute(self, image): + out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1) + return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) + +class ImageYUVToRGB: + @classmethod + def INPUT_TYPES(s): + return {"required": {"Y": ("IMAGE",), + "U": ("IMAGE",), + "V": ("IMAGE",), + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "execute" + + CATEGORY = "image/batch" + + def execute(self, Y, U, V): + image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1) + out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1) + return (out,) + NODE_CLASS_MAPPINGS = { "Morphology": Morphology, + "ImageRGBToYUV": ImageRGBToYUV, + "ImageYUVToRGB": ImageYUVToRGB, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/comfy_extras/nodes_optimalsteps.py b/comfy_extras/nodes_optimalsteps.py new file mode 100644 index 000000000..e7c851ca2 --- /dev/null +++ b/comfy_extras/nodes_optimalsteps.py @@ -0,0 +1,57 @@ +# from https://github.com/bebebe666/OptimalSteps + + +import numpy as np +import torch + +def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + + +NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001], +"Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001], +"Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001], +} + +class OptimalStepsScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model_type": (["FLUX", "Wan", "Chroma"], ), + "steps": ("INT", {"default": 20, "min": 3, "max": 1000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model_type, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) + total_steps = round(steps * denoise) + + sigmas = NOISE_LEVELS[model_type][:] + if (steps + 1) != len(sigmas): + sigmas = loglinear_interp(sigmas, steps + 1) + + sigmas = sigmas[-(total_steps + 1):] + sigmas[-1] = 0 + return (torch.FloatTensor(sigmas), ) + +NODE_CLASS_MAPPINGS = { + "OptimalStepsScheduler": OptimalStepsScheduler, +} diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 68f6ef51e..5b9542015 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -6,7 +6,7 @@ import math import comfy.utils import comfy.model_management - +import node_helpers class Blend: def __init__(self): @@ -34,6 +34,7 @@ class Blend: CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image1, image2 = node_helpers.image_alpha_fix(image1, image2) image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py new file mode 100644 index 000000000..184b990c3 --- /dev/null +++ b/comfy_extras/nodes_primitive.py @@ -0,0 +1,81 @@ +# Primitive nodes that are evaluated at backend. +from __future__ import annotations + +import sys + +from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO + + +class String(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.STRING, {})}, + } + + RETURN_TYPES = (IO.STRING,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: str) -> tuple[str]: + return (value,) + + +class Int(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})}, + } + + RETURN_TYPES = (IO.INT,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: int) -> tuple[int]: + return (value,) + + +class Float(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})}, + } + + RETURN_TYPES = (IO.FLOAT,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: float) -> tuple[float]: + return (value,) + + +class Boolean(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.BOOLEAN, {})}, + } + + RETURN_TYPES = (IO.BOOLEAN,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: bool) -> tuple[bool]: + return (value,) + + +NODE_CLASS_MAPPINGS = { + "PrimitiveString": String, + "PrimitiveInt": Int, + "PrimitiveFloat": Float, + "PrimitiveBoolean": Boolean, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PrimitiveString": "String", + "PrimitiveInt": "Int", + "PrimitiveFloat": "Float", + "PrimitiveBoolean": "Boolean", +} diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 97ca513d8..61f7171b2 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -5,9 +5,13 @@ import av import torch import folder_paths import json +from typing import Optional, Literal from fractions import Fraction -from comfy.comfy_types import FileLocator - +from comfy.comfy_types import IO, FileLocator, ComfyNodeABC +from comfy_api.input import ImageInput, AudioInput, VideoInput +from comfy_api.util import VideoContainer, VideoCodec, VideoComponents +from comfy_api.input_impl import VideoFromFile, VideoFromComponents +from comfy.cli_args import args class SaveWEBM: def __init__(self): @@ -50,13 +54,15 @@ class SaveWEBM: for x in extra_pnginfo: container.metadata[x] = json.dumps(extra_pnginfo[x]) - codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"} + codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"} stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) stream.width = images.shape[-2] stream.height = images.shape[-3] - stream.pix_fmt = "yuv420p" + stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p" stream.bit_rate = 0 stream.options = {'crf': str(crf)} + if codec == "av1": + stream.options["preset"] = "6" for frame in images: frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24") @@ -73,7 +79,163 @@ class SaveWEBM: return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side +class SaveVideo(ComfyNodeABC): + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type: Literal["output"] = "output" + self.prefix_append = "" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "video": (IO.VIDEO, {"tooltip": "The video to save."}), + "filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}), + "format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}), + "codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}), + }, + "hidden": { + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = () + FUNCTION = "save_video" + + OUTPUT_NODE = True + + CATEGORY = "image/video" + DESCRIPTION = "Saves the input images to your ComfyUI output directory." + + def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None): + filename_prefix += self.prefix_append + width, height = video.get_dimensions() + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + filename_prefix, + self.output_dir, + width, + height + ) + results: list[FileLocator] = list() + saved_metadata = None + if not args.disable_metadata: + metadata = {} + if extra_pnginfo is not None: + metadata.update(extra_pnginfo) + if prompt is not None: + metadata["prompt"] = prompt + if len(metadata) > 0: + saved_metadata = metadata + file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" + video.save_to( + os.path.join(full_output_folder, file), + format=format, + codec=codec, + metadata=saved_metadata + ) + + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + counter += 1 + + return { "ui": { "images": results, "animated": (True,) } } + +class CreateVideo(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": (IO.IMAGE, {"tooltip": "The images to create a video from."}), + "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}), + }, + "optional": { + "audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}), + } + } + + RETURN_TYPES = (IO.VIDEO,) + FUNCTION = "create_video" + + CATEGORY = "image/video" + DESCRIPTION = "Create a video from images." + + def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None): + return (VideoFromComponents( + VideoComponents( + images=images, + audio=audio, + frame_rate=Fraction(fps), + ) + ),) + +class GetVideoComponents(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "video": (IO.VIDEO, {"tooltip": "The video to extract components from."}), + } + } + RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT) + RETURN_NAMES = ("images", "audio", "fps") + FUNCTION = "get_components" + + CATEGORY = "image/video" + DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate." + + def get_components(self, video: VideoInput): + components = video.get_components() + + return (components.images, components.audio, float(components.frame_rate)) + +class LoadVideo(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls): + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = folder_paths.filter_files_content_types(files, ["video"]) + return {"required": + {"file": (sorted(files), {"video_upload": True})}, + } + + CATEGORY = "image/video" + + RETURN_TYPES = (IO.VIDEO,) + FUNCTION = "load_video" + def load_video(self, file): + video_path = folder_paths.get_annotated_filepath(file) + return (VideoFromFile(video_path),) + + @classmethod + def IS_CHANGED(cls, file): + video_path = folder_paths.get_annotated_filepath(file) + mod_time = os.path.getmtime(video_path) + # Instead of hashing the file, we can just use the modification time to avoid + # rehashing large files. + return mod_time + + @classmethod + def VALIDATE_INPUTS(cls, file): + if not folder_paths.exists_annotated_filepath(file): + return "Invalid video file: {}".format(file) + + return True NODE_CLASS_MAPPINGS = { "SaveWEBM": SaveWEBM, + "SaveVideo": SaveVideo, + "CreateVideo": CreateVideo, + "GetVideoComponents": GetVideoComponents, + "LoadVideo": LoadVideo, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SaveVideo": "Save Video", + "CreateVideo": "Create Video", + "GetVideoComponents": "Get Video Components", + "LoadVideo": "Load Video", } diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index dc30eb546..9dda64597 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -3,6 +3,8 @@ import node_helpers import torch import comfy.model_management import comfy.utils +import comfy.latent_formats +import comfy.clip_vision class WanImageToVideo: @@ -49,6 +51,258 @@ class WanImageToVideo: return (positive, negative, out_latent) +class WanFunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + +class WanFirstLastFrameToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), + "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if end_image is not None: + end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + image = torch.ones((length, height, width, 3)) * 0.5 + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + image[:start_image.shape[0]] = start_image + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + if end_image is not None: + image[-end_image.shape[0]:] = end_image + mask[:, :, -end_image.shape[0]:] = 0.0 + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_start_image is not None: + clip_vision_output = clip_vision_start_image + + if clip_vision_end_image is not None: + if clip_vision_output is not None: + states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2) + clip_vision_output = comfy.clip_vision.Output() + clip_vision_output.penultimate_hidden_states = states + else: + clip_vision_output = clip_vision_end_image + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + + +class WanFunInpaintToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): + flfv = WanFirstLastFrameToVideo() + return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) + + +class WanVaceToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}), + }, + "optional": {"control_video": ("IMAGE", ), + "control_masks": ("MASK", ), + "reference_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT") + RETURN_NAMES = ("positive", "negative", "latent", "trim_latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + EXPERIMENTAL = True + + def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None): + latent_length = ((length - 1) // 4) + 1 + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if control_video.shape[0] < length: + control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5) + else: + control_video = torch.ones((length, height, width, 3)) * 0.5 + + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + reference_image = vae.encode(reference_image[:, :, :, :3]) + reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1) + + if control_masks is None: + mask = torch.ones((length, height, width, 1)) + else: + mask = control_masks + if mask.ndim == 3: + mask = mask.unsqueeze(1) + mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1) + if mask.shape[0] < length: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0) + + control_video = control_video - 0.5 + inactive = (control_video * (1 - mask)) + 0.5 + reactive = (control_video * mask) + 0.5 + + inactive = vae.encode(inactive[:, :, :, :3]) + reactive = vae.encode(reactive[:, :, :, :3]) + control_video_latent = torch.cat((inactive, reactive), dim=1) + if reference_image is not None: + control_video_latent = torch.cat((reference_image, control_video_latent), dim=2) + + vae_stride = 8 + height_mask = height // vae_stride + width_mask = width // vae_stride + mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride) + mask = mask.permute(2, 4, 0, 1, 3) + mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask) + mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0) + + trim_latent = 0 + if reference_image is not None: + mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + latent_length += reference_image.shape[2] + trim_latent = reference_image.shape[2] + + mask = mask.unsqueeze(0) + positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength}) + negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength}) + + latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent, trim_latent) + +class TrimVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/video" + + EXPERIMENTAL = True + + def op(self, samples, trim_amount): + samples_out = samples.copy() + + s1 = samples["samples"] + samples_out["samples"] = s1[:, :, trim_amount:] + return (samples_out,) + + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, + "WanFunControlToVideo": WanFunControlToVideo, + "WanFunInpaintToVideo": WanFunInpaintToVideo, + "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, + "WanVaceToVideo": WanVaceToVideo, + "TrimVideoLatent": TrimVideoLatent, } diff --git a/comfyui_version.py b/comfyui_version.py index b5e6fbead..67d27f942 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.26" +__version__ = "0.3.30" diff --git a/execution.py b/execution.py index 2c979205b..feb61ae82 100644 --- a/execution.py +++ b/execution.py @@ -15,7 +15,7 @@ import nodes import comfy.model_management from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder -from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID from comfy_execution.validation import validate_node_input class ExecutionResult(Enum): @@ -59,20 +59,27 @@ class IsChangedCache: self.is_changed[node_id] = node["is_changed"] return self.is_changed[node_id] -class CacheSet: - def __init__(self, lru_size=None): - if lru_size is None or lru_size == 0: - self.init_classic_cache() - else: - self.init_lru_cache(lru_size) - self.all = [self.outputs, self.ui, self.objects] - # Useful for those with ample RAM/VRAM -- allows experimenting without - # blowing away the cache every time - def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.objects = HierarchicalCache(CacheKeySetID) +class CacheType(Enum): + CLASSIC = 0 + LRU = 1 + DEPENDENCY_AWARE = 2 + + +class CacheSet: + def __init__(self, cache_type=None, cache_size=None): + if cache_type == CacheType.DEPENDENCY_AWARE: + self.init_dependency_aware_cache() + logging.info("Disabling intermediate node cache.") + elif cache_type == CacheType.LRU: + if cache_size is None: + cache_size = 0 + self.init_lru_cache(cache_size) + logging.info("Using LRU cache") + else: + self.init_classic_cache() + + self.all = [self.outputs, self.ui, self.objects] # Performs like the old cache -- dump data ASAP def init_classic_cache(self): @@ -80,6 +87,17 @@ class CacheSet: self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) + def init_lru_cache(self, cache_size): + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + # only hold cached items while the decendents have not executed + def init_dependency_aware_cache(self): + self.outputs = DependencyAwareCache(CacheKeySetInputSignature) + self.ui = DependencyAwareCache(CacheKeySetInputSignature) + self.objects = DependencyAwareCache(CacheKeySetID) + def recursive_debug_dump(self): result = { "outputs": self.outputs.recursive_debug_dump(), @@ -93,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e missing_keys = {} for x in inputs: input_data = inputs[x] - input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) + _, input_category, input_info = get_input_info(class_def, x, valid_inputs) def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) @@ -126,6 +144,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] + if h[x] == "AUTH_TOKEN_COMFY_ORG": + input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] return input_data_all, missing_keys map_node_over_list = None #Don't hook this please @@ -414,13 +434,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server, lru_size=None): - self.lru_size = lru_size + def __init__(self, server, cache_type=False, cache_size=None): + self.cache_size = cache_size + self.cache_type = cache_type self.server = server self.reset() def reset(self): - self.caches = CacheSet(self.lru_size) + self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) self.status_messages = [] self.success = True @@ -555,7 +576,7 @@ def validate_inputs(prompt, item, validated): received_types = {} for x in valid_inputs: - type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs) + input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) assert extra_info is not None if x not in inputs: if input_category == "required": @@ -571,7 +592,7 @@ def validate_inputs(prompt, item, validated): continue val = inputs[x] - info = (type_input, extra_info) + info = (input_type, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -592,8 +613,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): - details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" + if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type): + details = f"{x}, received_type({received_type}) mismatch input_type({input_type})" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", @@ -634,22 +655,29 @@ def validate_inputs(prompt, item, validated): continue else: try: - if type_input == "INT": + # Unwraps values wrapped in __value__ key. This is used to pass + # list widget value to execution, as by default list value is + # reserved to represent the connection between nodes. + if isinstance(val, dict) and "__value__" in val: + val = val["__value__"] + inputs[x] = val + + if input_type == "INT": val = int(val) inputs[x] = val - if type_input == "FLOAT": + if input_type == "FLOAT": val = float(val) inputs[x] = val - if type_input == "STRING": + if input_type == "STRING": val = str(val) inputs[x] = val - if type_input == "BOOLEAN": + if input_type == "BOOLEAN": val = bool(val) inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", - "message": f"Failed to convert an input value to a {type_input} value", + "message": f"Failed to convert an input value to a {input_type} value", "details": f"{x}, {val}, {ex}", "extra_info": { "input_name": x, @@ -689,18 +717,19 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if isinstance(type_input, list): - if val not in type_input: + if isinstance(input_type, list): + combo_options = input_type + if val not in combo_options: input_config = info list_info = "" # Don't send back gigantic lists like if they're lots of # scanned model filepaths - if len(type_input) > 20: - list_info = f"(list of length {len(type_input)})" + if len(combo_options) > 20: + list_info = f"(list of length {len(combo_options)})" input_config = None else: - list_info = str(type_input) + list_info = str(combo_options) error = { "type": "value_not_in_list", @@ -768,7 +797,7 @@ def validate_prompt(prompt): "details": f"Node ID '#{x}'", "extra_info": {} } - return (False, error, [], []) + return (False, error, [], {}) class_type = prompt[x]['class_type'] class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) @@ -779,7 +808,7 @@ def validate_prompt(prompt): "details": f"Node ID '#{x}'", "extra_info": {} } - return (False, error, [], []) + return (False, error, [], {}) if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: outputs.add(x) @@ -791,7 +820,7 @@ def validate_prompt(prompt): "details": "", "extra_info": {} } - return (False, error, [], []) + return (False, error, [], {}) good_outputs = set() errors = [] diff --git a/folder_paths.py b/folder_paths.py index 72c70f594..f0b3fd103 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -4,7 +4,7 @@ import os import time import mimetypes import logging -from typing import Literal +from typing import Literal, List from collections.abc import Collection from comfy.cli_args import args @@ -85,6 +85,7 @@ cache_helper = CacheHelper() extension_mimetypes_cache = { "webp" : "image", + "fbx" : "model", } def map_legacy(folder_name: str) -> str: @@ -140,11 +141,14 @@ def get_directory_by_type(type_name: str) -> str | None: return get_input_directory() return None -def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]: +def filter_files_content_types(files: list[str], content_types: List[Literal["image", "video", "audio", "model"]]) -> list[str]: """ Example: files = os.listdir(folder_paths.get_input_directory()) - filter_files_content_types(files, ["image", "audio", "video"]) + videos = filter_files_content_types(files, ["video"]) + + Note: + - 'model' in MIME context refers to 3D models, not files containing trained weights and parameters """ global extension_mimetypes_cache result = [] diff --git a/hook_breaker_ac10a0.py b/hook_breaker_ac10a0.py new file mode 100644 index 000000000..c3e1c0633 --- /dev/null +++ b/hook_breaker_ac10a0.py @@ -0,0 +1,17 @@ +# Prevent custom nodes from hooking anything important +import comfy.model_management + +HOOK_BREAK = [(comfy.model_management, "cast_to")] + + +SAVED_FUNCTIONS = [] + + +def save_functions(): + for f in HOOK_BREAK: + SAVED_FUNCTIONS.append((f[0], f[1], getattr(f[0], f[1]))) + + +def restore_functions(): + for f in SAVED_FUNCTIONS: + setattr(f[0], f[1], f[2]) diff --git a/main.py b/main.py index c6f5c3c1e..f3f56597a 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ from app.logger import setup_logger import itertools import utils.extra_config import logging +import sys if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes. @@ -139,7 +140,8 @@ from server import BinaryEventTypes import nodes import comfy.model_management import comfyui_version - +import app.logger +import hook_breaker_ac10a0 def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() @@ -155,7 +157,13 @@ def cuda_malloc_warning(): def prompt_worker(q, server_instance): current_time: float = 0.0 - e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru) + cache_type = execution.CacheType.CLASSIC + if args.cache_lru > 0: + cache_type = execution.CacheType.LRU + elif args.cache_none: + cache_type = execution.CacheType.DEPENDENCY_AWARE + + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -207,6 +215,7 @@ def prompt_worker(q, server_instance): comfy.model_management.soft_empty_cache() last_gc_collect = current_time need_gc = False + hook_breaker_ac10a0.restore_functions() async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): @@ -260,7 +269,9 @@ def start_comfyui(asyncio_loop=None): prompt_server = server.PromptServer(asyncio_loop) q = execution.PromptQueue(prompt_server) + hook_breaker_ac10a0.save_functions() nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) + hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() @@ -294,11 +305,13 @@ def start_comfyui(asyncio_loop=None): if __name__ == "__main__": # Running directly, just start ComfyUI. + logging.info("Python version: {}".format(sys.version)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) event_loop, _, start_all_func = start_comfyui() try: x = start_all_func() + app.logger.print_startup_warnings() event_loop.run_until_complete(x) except KeyboardInterrupt: logging.info("\nStopped server") diff --git a/node_helpers.py b/node_helpers.py index 48da3b099..c3e1a14ca 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -44,3 +44,11 @@ def string_to_torch_dtype(string): return torch.float16 if string == "bf16": return torch.bfloat16 + +def image_alpha_fix(destination, source): + if destination.shape[-1] < source.shape[-1]: + source = source[...,:destination.shape[-1]] + elif destination.shape[-1] > source.shape[-1]: + destination = torch.nn.functional.pad(destination, (0, 1)) + destination[..., -1] = 1.0 + return destination, source diff --git a/nodes.py b/nodes.py index 4ef13203f..2e85304cc 100644 --- a/nodes.py +++ b/nodes.py @@ -770,6 +770,7 @@ class VAELoader: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) + vae.throw_exception_if_invalid() return (vae,) class ControlNetLoader: @@ -785,6 +786,8 @@ class ControlNetLoader: def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) controlnet = comfy.controlnet.load_controlnet(controlnet_path) + if controlnet is None: + raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.") return (controlnet,) class DiffControlNetLoader: @@ -914,7 +917,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -924,29 +927,10 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5" def load_clip(self, clip_name, type="stable_diffusion", device="default"): - if type == "stable_cascade": - clip_type = comfy.sd.CLIPType.STABLE_CASCADE - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "stable_audio": - clip_type = comfy.sd.CLIPType.STABLE_AUDIO - elif type == "mochi": - clip_type = comfy.sd.CLIPType.MOCHI - elif type == "ltxv": - clip_type = comfy.sd.CLIPType.LTXV - elif type == "pixart": - clip_type = comfy.sd.CLIPType.PIXART - elif type == "cosmos": - clip_type = comfy.sd.CLIPType.COSMOS - elif type == "lumina2": - clip_type = comfy.sd.CLIPType.LUMINA2 - elif type == "wan": - clip_type = comfy.sd.CLIPType.WAN - else: - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) model_options = {} if device == "cpu": @@ -961,7 +945,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -971,19 +955,13 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama" def load_clip(self, clip_name1, clip_name2, type, device="default"): + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) - if type == "sdxl": - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "flux": - clip_type = comfy.sd.CLIPType.FLUX - elif type == "hunyuan_video": - clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO model_options = {} if device == "cpu": @@ -1005,6 +983,8 @@ class CLIPVisionLoader: def load_clip(self, clip_name): clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name) clip_vision = comfy.clip_vision.load(clip_path) + if clip_vision is None: + raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.") return (clip_vision,) class CLIPVisionEncode: @@ -1650,6 +1630,7 @@ class LoadImage: input_dir = folder_paths.get_input_directory() strip_inputdir = lambda path: path[len(input_dir)+1:] files = [strip_inputdir(os.path.join(dir,f)) for (dir, subdirs, files) in os.walk(input_dir, followlinks=True) for f in files] + files = folder_paths.filter_files_content_types(files, ["image"]) return {"required": {"image": (sorted(files), {"image_upload": True})}, } @@ -1688,6 +1669,9 @@ class LoadImage: if 'A' in i.getbands(): mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) + elif i.mode == 'P' and 'transparency' in i.info: + mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) else: mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") output_images.append(image) @@ -2123,21 +2107,25 @@ def get_module_name(module_path: str) -> str: def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: - module_name = os.path.basename(module_path) + module_name = get_module_name(module_path) if os.path.isfile(module_path): sp = os.path.splitext(module_path) module_name = sp[0] + sys_module_name = module_name + elif os.path.isdir(module_path): + sys_module_name = module_path.replace(".", "_x_") + try: logging.debug("Trying to load custom node {}".format(module_path)) if os.path.isfile(module_path): - module_spec = importlib.util.spec_from_file_location(module_name, module_path) + module_spec = importlib.util.spec_from_file_location(sys_module_name, module_path) module_dir = os.path.split(module_path)[0] else: - module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py")) + module_spec = importlib.util.spec_from_file_location(sys_module_name, os.path.join(module_path, "__init__.py")) module_dir = module_path module = importlib.util.module_from_spec(module_spec) - sys.modules[module_name] = module + sys.modules[sys_module_name] = module module_spec.loader.exec_module(module) LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir) @@ -2264,6 +2252,18 @@ def init_builtin_extra_nodes(): "nodes_video.py", "nodes_lumina2.py", "nodes_wan.py", + "nodes_lotus.py", + "nodes_hunyuan3d.py", + "nodes_primitive.py", + "nodes_cfg.py", + "nodes_optimalsteps.py", + "nodes_hidream.py", + "nodes_fresca.py", + ] + + api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes") + api_nodes_files = [ + "nodes_api.py", ] import_failed = [] @@ -2271,6 +2271,10 @@ def init_builtin_extra_nodes(): if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"): import_failed.append(node_file) + for node_file in api_nodes_files: + if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"): + import_failed.append(node_file) + return import_failed diff --git a/pyproject.toml b/pyproject.toml index f13fed8dc..eadca662e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.26" +version = "0.3.30" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index e1316ccff..f64a05947 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -comfyui-frontend-package==1.11.8 +comfyui-frontend-package==1.17.11 +comfyui-workflow-templates==0.1.3 torch torchsde torchvision @@ -21,4 +22,5 @@ psutil kornia>=0.7.1 spandrel soundfile -av +av>=14.2.0 +pydantic~=2.0 diff --git a/server.py b/server.py index 76a99167d..f64ec27d4 100644 --- a/server.py +++ b/server.py @@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message): @web.middleware async def cache_control(request: web.Request, handler): response: web.Response = await handler(request) - if request.path.endswith('.js') or request.path.endswith('.css'): + if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'): response.headers.setdefault('Cache-Control', 'no-cache') return response @@ -580,6 +580,9 @@ class PromptServer(): info['deprecated'] = True if getattr(obj_class, "EXPERIMENTAL", False): info['experimental'] = True + + if hasattr(obj_class, 'API_NODE'): + info['api_node'] = obj_class.API_NODE return info @routes.get("/object_info") @@ -657,7 +660,13 @@ class PromptServer(): logging.warning("invalid prompt: {}".format(valid[1])) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: - return web.json_response({"error": "no prompt", "node_errors": []}, status=400) + error = { + "type": "no_prompt", + "message": "No prompt provided", + "details": "No prompt provided", + "extra_info": {} + } + return web.json_response({"error": error, "node_errors": {}}, status=400) @routes.post("/queue") async def post_queue(request): @@ -730,6 +739,12 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) + workflow_templates_path = FrontendManager.templates_path() + if workflow_templates_path: + self.app.add_routes([ + web.static('/templates', workflow_templates_path) + ]) + self.app.add_routes([ web.static('/', self.web_root), ]) diff --git a/tests-unit/folder_paths_test/filter_by_content_types_test.py b/tests-unit/folder_paths_test/filter_by_content_types_test.py index 423677a60..683f9fc11 100644 --- a/tests-unit/folder_paths_test/filter_by_content_types_test.py +++ b/tests-unit/folder_paths_test/filter_by_content_types_test.py @@ -1,14 +1,17 @@ import pytest import os import tempfile -from folder_paths import filter_files_content_types +from folder_paths import filter_files_content_types, extension_mimetypes_cache +from unittest.mock import patch + @pytest.fixture(scope="module") def file_extensions(): return { 'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'], 'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'], - 'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'] + 'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'], + 'model': ['gltf', 'glb', 'obj', 'fbx', 'stl'] } @@ -22,7 +25,18 @@ def mock_dir(file_extensions): yield directory -def test_categorizes_all_correctly(mock_dir, file_extensions): +@pytest.fixture +def patched_mimetype_cache(file_extensions): + # Mock model file extensions since they may not be in the test-runner system's mimetype cache + new_cache = extension_mimetypes_cache.copy() + for extension in file_extensions["model"]: + new_cache[extension] = "model" + + with patch("folder_paths.extension_mimetypes_cache", new_cache): + yield + + +def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache): files = os.listdir(mock_dir) for content_type, extensions in file_extensions.items(): filtered_files = filter_files_content_types(files, [content_type]) @@ -30,7 +44,7 @@ def test_categorizes_all_correctly(mock_dir, file_extensions): assert f"sample_{content_type}.{extension}" in filtered_files -def test_categorizes_all_uniquely(mock_dir, file_extensions): +def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache): files = os.listdir(mock_dir) for content_type, extensions in file_extensions.items(): filtered_files = filter_files_content_types(files, [content_type]) diff --git a/tests-unit/prompt_server_test/user_manager_test.py b/tests-unit/prompt_server_test/user_manager_test.py index 7e523cbf4..b939d8e68 100644 --- a/tests-unit/prompt_server_test/user_manager_test.py +++ b/tests-unit/prompt_server_test/user_manager_test.py @@ -229,3 +229,61 @@ async def test_move_userdata_full_info(aiohttp_client, app, tmp_path): assert not os.path.exists(tmp_path / "source.txt") with open(tmp_path / "dest.txt", "r") as f: assert f.read() == "test content" + + +async def test_listuserdata_v2_empty_root(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata") + assert resp.status == 200 + assert await resp.json() == [] + + +async def test_listuserdata_v2_nonexistent_subdirectory(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata?path=does_not_exist") + assert resp.status == 404 + + +async def test_listuserdata_v2_default(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir" / "subdir") + (tmp_path / "test_dir" / "file1.txt").write_text("content") + (tmp_path / "test_dir" / "subdir" / "file2.txt").write_text("content") + + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata?path=test_dir") + assert resp.status == 200 + data = await resp.json() + file_paths = {item["path"] for item in data if item["type"] == "file"} + assert file_paths == {"test_dir/file1.txt", "test_dir/subdir/file2.txt"} + + +async def test_listuserdata_v2_normalized_separators(aiohttp_client, app, tmp_path, monkeypatch): + # Force backslash as os separator + monkeypatch.setattr(os, 'sep', '\\') + monkeypatch.setattr(os.path, 'sep', '\\') + os.makedirs(tmp_path / "test_dir" / "subdir") + (tmp_path / "test_dir" / "subdir" / "file1.txt").write_text("x") + + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata?path=test_dir") + assert resp.status == 200 + data = await resp.json() + for item in data: + assert "/" in item["path"] + assert "\\" not in item["path"]\ + +async def test_listuserdata_v2_url_encoded_path(aiohttp_client, app, tmp_path): + # Create a directory with a space in its name and a file inside + os.makedirs(tmp_path / "my dir") + (tmp_path / "my dir" / "file.txt").write_text("content") + + client = await aiohttp_client(app) + # Use URL-encoded space in path parameter + resp = await client.get("/v2/userdata?path=my%20dir&recurse=false") + assert resp.status == 200 + data = await resp.json() + assert len(data) == 1 + entry = data[0] + assert entry["name"] == "file.txt" + # Ensure the path is correctly decoded and uses forward slash + assert entry["path"] == "my dir/file.txt"